Diffstat (limited to 'setuptools')
83 files changed, 7188 insertions, 2090 deletions
diff --git a/setuptools/__init__.py b/setuptools/__init__.py index 4d9b8357..9d6f0bc0 100644 --- a/setuptools/__init__.py +++ b/setuptools/__init__.py @@ -21,21 +21,20 @@ from . import monkey __all__ = [ - 'setup', 'Distribution', 'Command', 'Extension', 'Require', + 'setup', + 'Distribution', + 'Command', + 'Extension', + 'Require', 'SetuptoolsDeprecationWarning', - 'find_packages', 'find_namespace_packages', + 'find_packages', + 'find_namespace_packages', ] __version__ = setuptools.version.__version__ bootstrap_install_from = None -# If we run 2to3 on .py files, should we also convert docstrings? -# Default: yes; assume that we can detect doctests reliably -run_2to3_on_doctests = True -# Standard package names for fixer packages -lib2to3_fixer_packages = ['lib2to3.fixes'] - class PackageFinder: """ @@ -60,10 +59,13 @@ class PackageFinder: shell style wildcard patterns just like 'exclude'. """ - return list(cls._find_packages_iter( - convert_path(where), - cls._build_filter('ez_setup', '*__pycache__', *exclude), - cls._build_filter(*include))) + return list( + cls._find_packages_iter( + convert_path(where), + cls._build_filter('ez_setup', '*__pycache__', *exclude), + cls._build_filter(*include), + ) + ) @classmethod def _find_packages_iter(cls, where, exclude, include): @@ -82,7 +84,7 @@ class PackageFinder: package = rel_path.replace(os.path.sep, '.') # Skip directory trees that are not valid packages - if ('.' in dir or not cls._looks_like_package(full_path)): + if '.' in dir or not cls._looks_like_package(full_path): continue # Should this package be included? @@ -125,12 +127,10 @@ def _install_setup_requires(attrs): A minimal version of a distribution for supporting the fetch_build_eggs interface. """ + def __init__(self, attrs): _incl = 'dependency_links', 'setup_requires' - filtered = { - k: attrs[k] - for k in set(_incl) & set(attrs) - } + filtered = {k: attrs[k] for k in set(_incl) & set(attrs)} distutils.core.Distribution.__init__(self, filtered) def finalize_options(self): @@ -178,8 +178,9 @@ class Command(_Command): setattr(self, option, default) return default elif not isinstance(val, str): - raise DistutilsOptionError("'%s' must be a %s (got `%s`)" - % (option, what, val)) + raise DistutilsOptionError( + "'%s' must be a %s (got `%s`)" % (option, what, val) + ) return val def ensure_string_list(self, option): @@ -200,8 +201,8 @@ class Command(_Command): ok = False if not ok: raise DistutilsOptionError( - "'%s' must be a list of strings (got %r)" - % (option, val)) + "'%s' must be a list of strings (got %r)" % (option, val) + ) def reinitialize_command(self, command, reinit_subcommands=0, **kw): cmd = _Command.reinitialize_command(self, command, reinit_subcommands) diff --git a/setuptools/_distutils/_msvccompiler.py b/setuptools/_distutils/_msvccompiler.py index e9af4cf5..b7a06082 100644 --- a/setuptools/_distutils/_msvccompiler.py +++ b/setuptools/_distutils/_msvccompiler.py @@ -248,7 +248,7 @@ class MSVCCompiler(CCompiler) : # Future releases of Python 3.x will include all past # versions of vcruntime*.dll for compatibility. 
self.compile_options = [ - '/nologo', '/Ox', '/W3', '/GL', '/DNDEBUG', '/MD' + '/nologo', '/O2', '/W3', '/GL', '/DNDEBUG', '/MD' ] self.compile_options_debug = [ diff --git a/setuptools/_distutils/ccompiler.py b/setuptools/_distutils/ccompiler.py index 57bb94e8..48d160d2 100644 --- a/setuptools/_distutils/ccompiler.py +++ b/setuptools/_distutils/ccompiler.py @@ -392,7 +392,7 @@ class CCompiler: return output_dir, macros, include_dirs def _prep_compile(self, sources, output_dir, depends=None): - """Decide which souce files must be recompiled. + """Decide which source files must be recompiled. Determine the list of object files corresponding to 'sources', and figure out which ones really need to be recompiled. @@ -792,6 +792,8 @@ int main (int argc, char **argv) { objects = self.compile([fname], include_dirs=include_dirs) except CompileError: return False + finally: + os.remove(fname) try: self.link_executable(objects, "a.out", @@ -799,6 +801,11 @@ int main (int argc, char **argv) { library_dirs=library_dirs) except (LinkError, TypeError): return False + else: + os.remove("a.out") + finally: + for fn in objects: + os.remove(fn) return True def find_library_file (self, dirs, lib, debug=0): diff --git a/setuptools/_distutils/command/build.py b/setuptools/_distutils/command/build.py index a86df0bc..4355a632 100644 --- a/setuptools/_distutils/command/build.py +++ b/setuptools/_distutils/command/build.py @@ -102,7 +102,7 @@ class build(Command): # particular module distribution -- if user didn't supply it, pick # one of 'build_purelib' or 'build_platlib'. if self.build_lib is None: - if self.distribution.ext_modules: + if self.distribution.has_ext_modules(): self.build_lib = self.build_platlib else: self.build_lib = self.build_purelib diff --git a/setuptools/_distutils/command/build_ext.py b/setuptools/_distutils/command/build_ext.py index bbb34833..22628baf 100644 --- a/setuptools/_distutils/command/build_ext.py +++ b/setuptools/_distutils/command/build_ext.py @@ -220,7 +220,7 @@ class build_ext(Command): # For extensions under Cygwin, Python's library directory must be # appended to library_dirs if sys.platform[:6] == 'cygwin': - if sys.executable.startswith(os.path.join(sys.exec_prefix, "bin")): + if not sysconfig.python_build: # building third party extensions self.library_dirs.append(os.path.join(sys.prefix, "lib", "python" + get_python_version(), @@ -690,13 +690,15 @@ class build_ext(Command): provided, "PyInit_" + module_name. Only relevant on Windows, where the .pyd file (DLL) must export the module "PyInit_" function. 
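The get_export_symbols() hunk that follows implements PEP 489's export hook naming: the bare module name is punycode-encoded first, and the 'U_' marker is prepended afterwards, instead of encoding the name with the '_' separator already attached. A quick sketch of the resulting symbol, using the same non-ASCII name as the updated test further down:

>>> name = 'föö'
>>> 'PyInitU_' + name.encode('punycode').replace(b'-', b'_').decode('ascii')
'PyInitU_f_1gaa'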
""" - suffix = '_' + ext.name.split('.')[-1] + name = ext.name.split('.')[-1] try: # Unicode module name support as defined in PEP-489 # https://www.python.org/dev/peps/pep-0489/#export-hook-name - suffix.encode('ascii') + name.encode('ascii') except UnicodeEncodeError: - suffix = 'U' + suffix.encode('punycode').replace(b'-', b'_').decode('ascii') + suffix = 'U_' + name.encode('punycode').replace(b'-', b'_').decode('ascii') + else: + suffix = "_" + name initfunc_name = "PyInit" + suffix if initfunc_name not in ext.export_symbols: diff --git a/setuptools/_distutils/command/build_py.py b/setuptools/_distutils/command/build_py.py index edc2171c..7ef9bcef 100644 --- a/setuptools/_distutils/command/build_py.py +++ b/setuptools/_distutils/command/build_py.py @@ -9,7 +9,7 @@ import glob from distutils.core import Command from distutils.errors import * -from distutils.util import convert_path, Mixin2to3 +from distutils.util import convert_path from distutils import log class build_py (Command): @@ -390,27 +390,3 @@ class build_py (Command): if self.optimize > 0: byte_compile(files, optimize=self.optimize, force=self.force, prefix=prefix, dry_run=self.dry_run) - -class build_py_2to3(build_py, Mixin2to3): - def run(self): - self.updated_files = [] - - # Base class code - if self.py_modules: - self.build_modules() - if self.packages: - self.build_packages() - self.build_package_data() - - # 2to3 - self.run_2to3(self.updated_files) - - # Remaining base class code - self.byte_compile(self.get_outputs(include_bytecode=0)) - - def build_module(self, module, module_file, package): - res = build_py.build_module(self, module, module_file, package) - if res[1]: - # file was copied - self.updated_files.append(res[0]) - return res diff --git a/setuptools/_distutils/command/build_scripts.py b/setuptools/_distutils/command/build_scripts.py index ccc70e64..e3312cf0 100644 --- a/setuptools/_distutils/command/build_scripts.py +++ b/setuptools/_distutils/command/build_scripts.py @@ -7,7 +7,7 @@ from stat import ST_MODE from distutils import sysconfig from distutils.core import Command from distutils.dep_util import newer -from distutils.util import convert_path, Mixin2to3 +from distutils.util import convert_path from distutils import log import tokenize @@ -150,11 +150,3 @@ class build_scripts(Command): os.chmod(file, newmode) # XXX should we modify self.outfiles? return outfiles, updated_files - -class build_scripts_2to3(build_scripts, Mixin2to3): - - def copy_scripts(self): - outfiles, updated_files = build_scripts.copy_scripts(self) - if not self.dry_run: - self.run_2to3(updated_files) - return outfiles, updated_files diff --git a/setuptools/_distutils/command/install.py b/setuptools/_distutils/command/install.py index 13feeb89..866e2d59 100644 --- a/setuptools/_distutils/command/install.py +++ b/setuptools/_distutils/command/install.py @@ -348,7 +348,7 @@ class install(Command): # module distribution is pure or not. Of course, if the user # already specified install_lib, use their selection. if self.install_lib is None: - if self.distribution.ext_modules: # has extensions: non-pure + if self.distribution.has_ext_modules(): # has extensions: non-pure self.install_lib = self.install_platlib else: self.install_lib = self.install_purelib @@ -470,6 +470,7 @@ class install(Command): """Sets the install directories by applying the install schemes.""" # it's the caller's problem if they supply a bad name! 
if (hasattr(sys, 'pypy_version_info') and + sys.version_info < (3, 8) and not name.endswith(('_user', '_home'))): if os.name == 'nt': name = 'pypy_nt' diff --git a/setuptools/_distutils/cygwinccompiler.py b/setuptools/_distutils/cygwinccompiler.py index 66c12dd3..f1c38e39 100644 --- a/setuptools/_distutils/cygwinccompiler.py +++ b/setuptools/_distutils/cygwinccompiler.py @@ -44,6 +44,8 @@ cygwin in no-cygwin mode). # (ld supports -shared) # * mingw gcc 3.2/ld 2.13 works # (ld supports -shared) +# * llvm-mingw with Clang 11 works +# (lld supports -shared) import os import sys @@ -109,41 +111,46 @@ class CygwinCCompiler(UnixCCompiler): "Compiling may fail because of undefined preprocessor macros." % details) - self.gcc_version, self.ld_version, self.dllwrap_version = \ - get_versions() - self.debug_print(self.compiler_type + ": gcc %s, ld %s, dllwrap %s\n" % - (self.gcc_version, - self.ld_version, - self.dllwrap_version) ) - - # ld_version >= "2.10.90" and < "2.13" should also be able to use - # gcc -mdll instead of dllwrap - # Older dllwraps had own version numbers, newer ones use the - # same as the rest of binutils ( also ld ) - # dllwrap 2.10.90 is buggy - if self.ld_version >= "2.10.90": - self.linker_dll = "gcc" - else: - self.linker_dll = "dllwrap" + self.cc = os.environ.get('CC', 'gcc') + self.cxx = os.environ.get('CXX', 'g++') + + if ('gcc' in self.cc): # Start gcc workaround + self.gcc_version, self.ld_version, self.dllwrap_version = \ + get_versions() + self.debug_print(self.compiler_type + ": gcc %s, ld %s, dllwrap %s\n" % + (self.gcc_version, + self.ld_version, + self.dllwrap_version) ) + + # ld_version >= "2.10.90" and < "2.13" should also be able to use + # gcc -mdll instead of dllwrap + # Older dllwraps had own version numbers, newer ones use the + # same as the rest of binutils ( also ld ) + # dllwrap 2.10.90 is buggy + if self.ld_version >= "2.10.90": + self.linker_dll = self.cc + else: + self.linker_dll = "dllwrap" - # ld_version >= "2.13" support -shared so use it instead of - # -mdll -static - if self.ld_version >= "2.13": + # ld_version >= "2.13" support -shared so use it instead of + # -mdll -static + if self.ld_version >= "2.13": + shared_option = "-shared" + else: + shared_option = "-mdll -static" + else: # Assume linker is up to date + self.linker_dll = self.cc shared_option = "-shared" - else: - shared_option = "-mdll -static" - # Hard-code GCC because that's what this is all about. - # XXX optimization, warnings etc. should be customizable. 
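This cygwinccompiler change stops hard-coding gcc/g++: the driver names now come from the CC and CXX environment variables, the gcc/ld/dllwrap version probing runs only when CC actually looks like gcc, and is_cygwingcc() becomes is_cygwincc() so the Cygwin check works for whichever compiler is configured. A sketch of that check with an assumed -dumpmachine output (the target triplet is illustrative, not taken from this patch):

>>> out_string = b'x86_64-pc-cygwin\n'  # assumed output of `gcc -dumpmachine` on Cygwin
>>> out_string.strip().endswith(b'cygwin')
True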
- self.set_executables(compiler='gcc -mcygwin -O -Wall', - compiler_so='gcc -mcygwin -mdll -O -Wall', - compiler_cxx='g++ -mcygwin -O -Wall', - linker_exe='gcc -mcygwin', + self.set_executables(compiler='%s -mcygwin -O -Wall' % self.cc, + compiler_so='%s -mcygwin -mdll -O -Wall' % self.cc, + compiler_cxx='%s -mcygwin -O -Wall' % self.cxx, + linker_exe='%s -mcygwin' % self.cc, linker_so=('%s -mcygwin %s' % (self.linker_dll, shared_option))) # cygwin and mingw32 need different sets of libraries - if self.gcc_version == "2.91.57": + if ('gcc' in self.cc and self.gcc_version == "2.91.57"): # cygwin shouldn't need msvcrt, but without the dlls will crash # (gcc version 2.91.57) -- perhaps something about initialization self.dll_libraries=["msvcrt"] @@ -281,26 +288,26 @@ class Mingw32CCompiler(CygwinCCompiler): # ld_version >= "2.13" support -shared so use it instead of # -mdll -static - if self.ld_version >= "2.13": - shared_option = "-shared" - else: + if ('gcc' in self.cc and self.ld_version < "2.13"): shared_option = "-mdll -static" + else: + shared_option = "-shared" # A real mingw32 doesn't need to specify a different entry point, # but cygwin 2.91.57 in no-cygwin-mode needs it. - if self.gcc_version <= "2.91.57": + if ('gcc' in self.cc and self.gcc_version <= "2.91.57"): entry_point = '--entry _DllMain@12' else: entry_point = '' - if is_cygwingcc(): + if is_cygwincc(self.cc): raise CCompilerError( 'Cygwin gcc cannot be used with --compiler=mingw32') - self.set_executables(compiler='gcc -O -Wall', - compiler_so='gcc -mdll -O -Wall', - compiler_cxx='g++ -O -Wall', - linker_exe='gcc', + self.set_executables(compiler='%s -O -Wall' % self.cc, + compiler_so='%s -mdll -O -Wall' % self.cc, + compiler_cxx='%s -O -Wall' % self.cxx, + linker_exe='%s' % self.cc, linker_so='%s %s %s' % (self.linker_dll, shared_option, entry_point)) @@ -351,6 +358,10 @@ def check_config_h(): if "GCC" in sys.version: return CONFIG_H_OK, "sys.version mentions 'GCC'" + # Clang would also work + if "Clang" in sys.version: + return CONFIG_H_OK, "sys.version mentions 'Clang'" + # let's see if __GNUC__ is mentioned in python.h fn = sysconfig.get_config_h_filename() try: @@ -397,7 +408,7 @@ def get_versions(): commands = ['gcc -dumpversion', 'ld -v', 'dllwrap --version'] return tuple([_find_exe_version(cmd) for cmd in commands]) -def is_cygwingcc(): - '''Try to determine if the gcc that would be used is from cygwin.''' - out_string = check_output(['gcc', '-dumpmachine']) +def is_cygwincc(cc): + '''Try to determine if the compiler that would be used is from cygwin.''' + out_string = check_output([cc, '-dumpmachine']) return out_string.strip().endswith(b'cygwin') diff --git a/setuptools/_distutils/filelist.py b/setuptools/_distutils/filelist.py index c92d5fdb..82a77384 100644 --- a/setuptools/_distutils/filelist.py +++ b/setuptools/_distutils/filelist.py @@ -4,13 +4,16 @@ Provides the FileList class, used for poking about the filesystem and building lists of files. """ -import os, re +import os +import re import fnmatch import functools + from distutils.util import convert_path from distutils.errors import DistutilsTemplateError, DistutilsInternalError from distutils import log + class FileList: """A list of files built by on exploring the filesystem and filtered by applying various patterns to what we find there. 
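check_config_h() above decides whether pyconfig.h can be trusted by sniffing sys.version for the compiler that built the interpreter; this patch adds a Clang branch next to the existing GCC one. A minimal illustration, with a made-up version string of the usual shape:

>>> version = '3.9.6 (default, Jun 29 2021, 05:25:02) [Clang 12.0.5]'  # hypothetical
>>> 'Clang' in version
True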
@@ -46,7 +49,7 @@ class FileList: if DEBUG: print(msg) - # -- List-like methods --------------------------------------------- + # Collection methods def append(self, item): self.files.append(item) @@ -61,8 +64,7 @@ class FileList: for sort_tuple in sortable_files: self.files.append(os.path.join(*sort_tuple)) - - # -- Other miscellaneous utility methods --------------------------- + # Other miscellaneous utility methods def remove_duplicates(self): # Assumes list has been sorted! @@ -70,8 +72,7 @@ class FileList: if self.files[i] == self.files[i - 1]: del self.files[i] - - # -- "File template" methods --------------------------------------- + # "File template" methods def _parse_template_line(self, line): words = line.split() @@ -146,9 +147,11 @@ class FileList: (dir, ' '.join(patterns))) for pattern in patterns: if not self.include_pattern(pattern, prefix=dir): - log.warn(("warning: no files found matching '%s' " - "under directory '%s'"), - pattern, dir) + msg = ( + "warning: no files found matching '%s' " + "under directory '%s'" + ) + log.warn(msg, pattern, dir) elif action == 'recursive-exclude': self.debug_print("recursive-exclude %s %s" % @@ -174,8 +177,7 @@ class FileList: raise DistutilsInternalError( "this cannot happen: invalid action '%s'" % action) - - # -- Filtering/selection methods ----------------------------------- + # Filtering/selection methods def include_pattern(self, pattern, anchor=1, prefix=None, is_regex=0): """Select strings (presumably filenames) from 'self.files' that @@ -219,9 +221,8 @@ class FileList: files_found = True return files_found - - def exclude_pattern (self, pattern, - anchor=1, prefix=None, is_regex=0): + def exclude_pattern( + self, pattern, anchor=1, prefix=None, is_regex=0): """Remove strings (presumably filenames) from 'files' that match 'pattern'. Other parameters are the same as for 'include_pattern()', above. @@ -240,21 +241,47 @@ class FileList: return files_found -# ---------------------------------------------------------------------- # Utility functions def _find_all_simple(path): """ Find all files under 'path' """ + all_unique = _UniqueDirs.filter(os.walk(path, followlinks=True)) results = ( os.path.join(base, file) - for base, dirs, files in os.walk(path, followlinks=True) + for base, dirs, files in all_unique for file in files ) return filter(os.path.isfile, results) +class _UniqueDirs(set): + """ + Exclude previously-seen dirs from walk results, + avoiding infinite recursion. + Ref https://bugs.python.org/issue44497. + """ + def __call__(self, walk_item): + """ + Given an item from an os.walk result, determine + if the item represents a unique dir for this instance + and if not, prevent further traversal. + """ + base, dirs, files = walk_item + stat = os.stat(base) + candidate = stat.st_dev, stat.st_ino + found = candidate in self + if found: + del dirs[:] + self.add(candidate) + return not found + + @classmethod + def filter(cls, items): + return filter(cls(), items) + + def findall(dir=os.curdir): """ Find all files under 'dir' and return the list of full filenames. 
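_UniqueDirs above protects os.walk(..., followlinks=True) from symlink loops, per bpo-44497: each directory is identified by its (st_dev, st_ino) pair, and when a directory turns up a second time the walk is pruned by clearing dirs in place. The same idea as a standalone sketch:

import os

seen = set()

def first_visit(walk_item):
    base, dirs, files = walk_item
    st = os.stat(base)
    key = (st.st_dev, st.st_ino)  # identity that survives symlinked paths
    if key in seen:
        del dirs[:]  # already visited: stop descending into this directory
        return False
    seen.add(key)
    return True

unique_walk = filter(first_visit, os.walk('.', followlinks=True))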
@@ -319,7 +346,8 @@ def translate_pattern(pattern, anchor=1, prefix=None, is_regex=0): if os.sep == '\\': sep = r'\\' pattern_re = pattern_re[len(start): len(pattern_re) - len(end)] - pattern_re = r'%s\A%s%s.*%s%s' % (start, prefix_re, sep, pattern_re, end) + pattern_re = r'%s\A%s%s.*%s%s' % ( + start, prefix_re, sep, pattern_re, end) else: # no prefix -- respect anchor flag if anchor: pattern_re = r'%s\A%s' % (start, pattern_re[len(start):]) diff --git a/setuptools/_distutils/msvc9compiler.py b/setuptools/_distutils/msvc9compiler.py index 6934e964..a1b3b02f 100644 --- a/setuptools/_distutils/msvc9compiler.py +++ b/setuptools/_distutils/msvc9compiler.py @@ -399,13 +399,13 @@ class MSVCCompiler(CCompiler) : self.preprocess_options = None if self.__arch == "x86": - self.compile_options = [ '/nologo', '/Ox', '/MD', '/W3', + self.compile_options = [ '/nologo', '/O2', '/MD', '/W3', '/DNDEBUG'] self.compile_options_debug = ['/nologo', '/Od', '/MDd', '/W3', '/Z7', '/D_DEBUG'] else: # Win64 - self.compile_options = [ '/nologo', '/Ox', '/MD', '/W3', '/GS-' , + self.compile_options = [ '/nologo', '/O2', '/MD', '/W3', '/GS-' , '/DNDEBUG'] self.compile_options_debug = ['/nologo', '/Od', '/MDd', '/W3', '/GS-', '/Z7', '/D_DEBUG'] diff --git a/setuptools/_distutils/msvccompiler.py b/setuptools/_distutils/msvccompiler.py index d5857cb1..2d447b85 100644 --- a/setuptools/_distutils/msvccompiler.py +++ b/setuptools/_distutils/msvccompiler.py @@ -283,13 +283,13 @@ class MSVCCompiler(CCompiler) : self.preprocess_options = None if self.__arch == "Intel": - self.compile_options = [ '/nologo', '/Ox', '/MD', '/W3', '/GX' , + self.compile_options = [ '/nologo', '/O2', '/MD', '/W3', '/GX' , '/DNDEBUG'] self.compile_options_debug = ['/nologo', '/Od', '/MDd', '/W3', '/GX', '/Z7', '/D_DEBUG'] else: # Win64 - self.compile_options = [ '/nologo', '/Ox', '/MD', '/W3', '/GS-' , + self.compile_options = [ '/nologo', '/O2', '/MD', '/W3', '/GS-' , '/DNDEBUG'] self.compile_options_debug = ['/nologo', '/Od', '/MDd', '/W3', '/GS-', '/Z7', '/D_DEBUG'] diff --git a/setuptools/_distutils/spawn.py b/setuptools/_distutils/spawn.py index fc592d4a..6e1c89f1 100644 --- a/setuptools/_distutils/spawn.py +++ b/setuptools/_distutils/spawn.py @@ -15,11 +15,6 @@ from distutils.debug import DEBUG from distutils import log -if sys.platform == 'darwin': - _cfg_target = None - _cfg_target_split = None - - def spawn(cmd, search_path=1, verbose=0, dry_run=0, env=None): """Run another program, specified as a command list 'cmd', in a new process. @@ -40,7 +35,7 @@ def spawn(cmd, search_path=1, verbose=0, dry_run=0, env=None): # in, protect our %-formatting code against horrible death cmd = list(cmd) - log.info(' '.join(cmd)) + log.info(subprocess.list2cmdline(cmd)) if dry_run: return @@ -52,24 +47,10 @@ def spawn(cmd, search_path=1, verbose=0, dry_run=0, env=None): env = env if env is not None else dict(os.environ) if sys.platform == 'darwin': - global _cfg_target, _cfg_target_split - if _cfg_target is None: - from distutils import sysconfig - _cfg_target = sysconfig.get_config_var( - 'MACOSX_DEPLOYMENT_TARGET') or '' - if _cfg_target: - _cfg_target_split = [int(x) for x in _cfg_target.split('.')] - if _cfg_target: - # ensure that the deployment target of build process is not less - # than that used when the interpreter was built. 
This ensures - # extension modules are built with correct compatibility values - cur_target = os.environ.get('MACOSX_DEPLOYMENT_TARGET', _cfg_target) - if _cfg_target_split > [int(x) for x in cur_target.split('.')]: - my_msg = ('$MACOSX_DEPLOYMENT_TARGET mismatch: ' - 'now "%s" but "%s" during configure' - % (cur_target, _cfg_target)) - raise DistutilsPlatformError(my_msg) - env.update(MACOSX_DEPLOYMENT_TARGET=cur_target) + from distutils.util import MACOSX_VERSION_VAR, get_macosx_target_ver + macosx_target_ver = get_macosx_target_ver() + if macosx_target_ver: + env[MACOSX_VERSION_VAR] = macosx_target_ver try: proc = subprocess.Popen(cmd, env=env) diff --git a/setuptools/_distutils/sysconfig.py b/setuptools/_distutils/sysconfig.py index 879b6981..8832b3ec 100644 --- a/setuptools/_distutils/sysconfig.py +++ b/setuptools/_distutils/sysconfig.py @@ -99,9 +99,9 @@ def get_python_inc(plat_specific=0, prefix=None): """ if prefix is None: prefix = plat_specific and BASE_EXEC_PREFIX or BASE_PREFIX - if IS_PYPY: - return os.path.join(prefix, 'include') - elif os.name == "posix": + if os.name == "posix": + if IS_PYPY and sys.version_info < (3, 8): + return os.path.join(prefix, 'include') if python_build: # Assume the executable is in the build directory. The # pyconfig.h file should be in the same directory. Since @@ -113,7 +113,8 @@ def get_python_inc(plat_specific=0, prefix=None): else: incdir = os.path.join(get_config_var('srcdir'), 'Include') return os.path.normpath(incdir) - python_dir = 'python' + get_python_version() + build_flags + implementation = 'pypy' if IS_PYPY else 'python' + python_dir = implementation + get_python_version() + build_flags return os.path.join(prefix, "include", python_dir) elif os.name == "nt": if python_build: @@ -142,7 +143,8 @@ def get_python_lib(plat_specific=0, standard_lib=0, prefix=None): If 'prefix' is supplied, use it instead of sys.base_prefix or sys.base_exec_prefix -- i.e., ignore 'plat_specific'. 
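spawn() above now logs the command with subprocess.list2cmdline() instead of ' '.join(cmd), so arguments containing spaces stay unambiguous in the log output:

>>> import subprocess
>>> subprocess.list2cmdline(['gcc', '-o', 'my prog', 'main.c'])
'gcc -o "my prog" main.c'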
""" - if IS_PYPY: + + if IS_PYPY and sys.version_info < (3, 8): # PyPy-specific schema if prefix is None: prefix = PREFIX @@ -164,8 +166,9 @@ def get_python_lib(plat_specific=0, standard_lib=0, prefix=None): else: # Pure Python libdir = "lib" + implementation = 'pypy' if IS_PYPY else 'python' libpython = os.path.join(prefix, libdir, - "python" + get_python_version()) + implementation + get_python_version()) if standard_lib: return libpython else: @@ -211,10 +214,9 @@ def customize_compiler(compiler): if 'CC' in os.environ: newcc = os.environ['CC'] - if (sys.platform == 'darwin' - and 'LDSHARED' not in os.environ + if('LDSHARED' not in os.environ and ldshared.startswith(cc)): - # On OS X, if CC is overridden, use that as the default + # If CC is overridden, use that as the default # command for LDSHARED as well ldshared = newcc + ldshared[len(cc):] cc = newcc @@ -252,6 +254,9 @@ def customize_compiler(compiler): linker_exe=cc, archiver=archiver) + if 'RANLIB' in os.environ and compiler.executables.get('ranlib', None): + compiler.set_executables(ranlib=os.environ['RANLIB']) + compiler.shared_lib_extension = shlib_suffix diff --git a/setuptools/_distutils/tests/test_build_ext.py b/setuptools/_distutils/tests/test_build_ext.py index 5a72458c..85ecf4b7 100644 --- a/setuptools/_distutils/tests/test_build_ext.py +++ b/setuptools/_distutils/tests/test_build_ext.py @@ -316,7 +316,7 @@ class BuildExtTestCase(TempdirManager, self.assertRegex(cmd.get_ext_filename(modules[0].name), r'foo(_d)?\..*') self.assertRegex(cmd.get_ext_filename(modules[1].name), r'föö(_d)?\..*') self.assertEqual(cmd.get_export_symbols(modules[0]), ['PyInit_foo']) - self.assertEqual(cmd.get_export_symbols(modules[1]), ['PyInitU_f_gkaa']) + self.assertEqual(cmd.get_export_symbols(modules[1]), ['PyInitU_f_1gaa']) def test_compiler_option(self): # cmd.compiler is an option and diff --git a/setuptools/_distutils/tests/test_filelist.py b/setuptools/_distutils/tests/test_filelist.py index d8e4b39f..9ec507b5 100644 --- a/setuptools/_distutils/tests/test_filelist.py +++ b/setuptools/_distutils/tests/test_filelist.py @@ -331,6 +331,16 @@ class FindAllTestCase(unittest.TestCase): expected = [file1] self.assertEqual(filelist.findall(temp_dir), expected) + @os_helper.skip_unless_symlink + def test_symlink_loop(self): + with os_helper.temp_dir() as temp_dir: + link = os.path.join(temp_dir, 'link-to-parent') + content = os.path.join(temp_dir, 'somefile') + os_helper.create_empty_file(content) + os.symlink('.', link) + files = filelist.findall(temp_dir) + assert len(files) == 1 + def test_suite(): return unittest.TestSuite([ diff --git a/setuptools/_distutils/tests/test_sysconfig.py b/setuptools/_distutils/tests/test_sysconfig.py index c7571942..80cd1599 100644 --- a/setuptools/_distutils/tests/test_sysconfig.py +++ b/setuptools/_distutils/tests/test_sysconfig.py @@ -9,6 +9,7 @@ import unittest from distutils import sysconfig from distutils.ccompiler import get_default_compiler +from distutils.unixccompiler import UnixCCompiler from distutils.tests import support from test.support import run_unittest, swap_item @@ -84,9 +85,14 @@ class SysconfigTestCase(support.EnvironGuard, unittest.TestCase): # make sure AR gets caught class compiler: compiler_type = 'unix' + executables = UnixCCompiler.executables + + def __init__(self): + self.exes = {} def set_executables(self, **kw): - self.exes = kw + for k, v in kw.items(): + self.exes[k] = v sysconfig_vars = { 'AR': 'sc_ar', @@ -125,6 +131,7 @@ class SysconfigTestCase(support.EnvironGuard, 
unittest.TestCase): os.environ['ARFLAGS'] = '--env-arflags' os.environ['CFLAGS'] = '--env-cflags' os.environ['CPPFLAGS'] = '--env-cppflags' + os.environ['RANLIB'] = 'env_ranlib' comp = self.customize_compiler() self.assertEqual(comp.exes['archiver'], @@ -145,6 +152,12 @@ class SysconfigTestCase(support.EnvironGuard, unittest.TestCase): ' --env-cppflags')) self.assertEqual(comp.shared_lib_extension, 'sc_shutil_suffix') + if sys.platform == "darwin": + self.assertEqual(comp.exes['ranlib'], + 'env_ranlib') + else: + self.assertTrue('ranlib' not in comp.exes) + del os.environ['AR'] del os.environ['CC'] del os.environ['CPP'] @@ -154,6 +167,7 @@ class SysconfigTestCase(support.EnvironGuard, unittest.TestCase): del os.environ['ARFLAGS'] del os.environ['CFLAGS'] del os.environ['CPPFLAGS'] + del os.environ['RANLIB'] comp = self.customize_compiler() self.assertEqual(comp.exes['archiver'], @@ -171,6 +185,7 @@ class SysconfigTestCase(support.EnvironGuard, unittest.TestCase): self.assertEqual(comp.exes['linker_so'], 'sc_ldshared') self.assertEqual(comp.shared_lib_extension, 'sc_shutil_suffix') + self.assertTrue('ranlib' not in comp.exes) def test_parse_makefile_base(self): self.makefile = TESTFN diff --git a/setuptools/_distutils/tests/test_unixccompiler.py b/setuptools/_distutils/tests/test_unixccompiler.py index 1828ba1a..ee2fe99c 100644 --- a/setuptools/_distutils/tests/test_unixccompiler.py +++ b/setuptools/_distutils/tests/test_unixccompiler.py @@ -1,4 +1,5 @@ """Tests for distutils.unixccompiler.""" +import os import sys import unittest from test.support import run_unittest @@ -6,7 +7,9 @@ from test.support import run_unittest from .py38compat import EnvironmentVarGuard from distutils import sysconfig +from distutils.errors import DistutilsPlatformError from distutils.unixccompiler import UnixCCompiler +from distutils.util import _clear_cached_macosx_ver class UnixCCompilerTestCase(unittest.TestCase): @@ -26,18 +29,90 @@ class UnixCCompilerTestCase(unittest.TestCase): @unittest.skipIf(sys.platform == 'win32', "can't test on Windows") def test_runtime_libdir_option(self): - # Issue#5900 + # Issue #5900; GitHub Issue #37 # # Ensure RUNPATH is added to extension modules with RPATH if # GNU ld is used # darwin sys.platform = 'darwin' - self.assertEqual(self.cc.rpath_foo(), '-L/foo') + darwin_ver_var = 'MACOSX_DEPLOYMENT_TARGET' + darwin_rpath_flag = '-Wl,-rpath,/foo' + darwin_lib_flag = '-L/foo' + + # (macOS version from syscfg, macOS version from env var) -> flag + # Version value of None generates two tests: as None and as empty string + # Expected flag value of None means an mismatch exception is expected + darwin_test_cases = [ + ((None , None ), darwin_lib_flag), + ((None , '11' ), darwin_rpath_flag), + (('10' , None ), darwin_lib_flag), + (('10.3' , None ), darwin_lib_flag), + (('10.3.1', None ), darwin_lib_flag), + (('10.5' , None ), darwin_rpath_flag), + (('10.5.1', None ), darwin_rpath_flag), + (('10.3' , '10.3' ), darwin_lib_flag), + (('10.3' , '10.5' ), darwin_rpath_flag), + (('10.5' , '10.3' ), darwin_lib_flag), + (('10.5' , '11' ), darwin_rpath_flag), + (('10.4' , '10' ), None), + ] + + def make_darwin_gcv(syscfg_macosx_ver): + def gcv(var): + if var == darwin_ver_var: + return syscfg_macosx_ver + return "xxx" + return gcv + + def do_darwin_test(syscfg_macosx_ver, env_macosx_ver, expected_flag): + env = os.environ + msg = "macOS version = (sysconfig=%r, env=%r)" % \ + (syscfg_macosx_ver, env_macosx_ver) + + # Save + old_gcv = sysconfig.get_config_var + old_env_macosx_ver = 
env.get(darwin_ver_var) + + # Setup environment + _clear_cached_macosx_ver() + sysconfig.get_config_var = make_darwin_gcv(syscfg_macosx_ver) + if env_macosx_ver is not None: + env[darwin_ver_var] = env_macosx_ver + elif darwin_ver_var in env: + env.pop(darwin_ver_var) + + # Run the test + if expected_flag is not None: + self.assertEqual(self.cc.rpath_foo(), expected_flag, msg=msg) + else: + with self.assertRaisesRegex(DistutilsPlatformError, + darwin_ver_var + r' mismatch', msg=msg): + self.cc.rpath_foo() + + # Restore + if old_env_macosx_ver is not None: + env[darwin_ver_var] = old_env_macosx_ver + elif darwin_ver_var in env: + env.pop(darwin_ver_var) + sysconfig.get_config_var = old_gcv + _clear_cached_macosx_ver() + + for macosx_vers, expected_flag in darwin_test_cases: + syscfg_macosx_ver, env_macosx_ver = macosx_vers + do_darwin_test(syscfg_macosx_ver, env_macosx_ver, expected_flag) + # Bonus test cases with None interpreted as empty string + if syscfg_macosx_ver is None: + do_darwin_test("", env_macosx_ver, expected_flag) + if env_macosx_ver is None: + do_darwin_test(syscfg_macosx_ver, "", expected_flag) + if syscfg_macosx_ver is None and env_macosx_ver is None: + do_darwin_test("", "", expected_flag) + + old_gcv = sysconfig.get_config_var # hp-ux sys.platform = 'hp-ux' - old_gcv = sysconfig.get_config_var def gcv(v): return 'xxx' sysconfig.get_config_var = gcv @@ -65,6 +140,14 @@ class UnixCCompilerTestCase(unittest.TestCase): sysconfig.get_config_var = gcv self.assertEqual(self.cc.rpath_foo(), '-Wl,--enable-new-dtags,-R/foo') + def gcv(v): + if v == 'CC': + return 'gcc -pthread -B /bar' + elif v == 'GNULD': + return 'yes' + sysconfig.get_config_var = gcv + self.assertEqual(self.cc.rpath_foo(), '-Wl,--enable-new-dtags,-R/foo') + # GCC non-GNULD sys.platform = 'bar' def gcv(v): @@ -94,7 +177,7 @@ class UnixCCompilerTestCase(unittest.TestCase): elif v == 'GNULD': return 'yes' sysconfig.get_config_var = gcv - self.assertEqual(self.cc.rpath_foo(), '-R/foo') + self.assertEqual(self.cc.rpath_foo(), '-Wl,--enable-new-dtags,-R/foo') # non-GCC non-GNULD sys.platform = 'bar' @@ -104,10 +187,10 @@ class UnixCCompilerTestCase(unittest.TestCase): elif v == 'GNULD': return 'no' sysconfig.get_config_var = gcv - self.assertEqual(self.cc.rpath_foo(), '-R/foo') + self.assertEqual(self.cc.rpath_foo(), '-Wl,-R/foo') - @unittest.skipUnless(sys.platform == 'darwin', 'test only relevant for OS X') - def test_osx_cc_overrides_ldshared(self): + @unittest.skipIf(sys.platform == 'win32', "can't test on Windows") + def test_cc_overrides_ldshared(self): # Issue #18080: # ensure that setting CC env variable also changes default linker def gcv(v): @@ -127,8 +210,8 @@ class UnixCCompilerTestCase(unittest.TestCase): sysconfig.customize_compiler(self.cc) self.assertEqual(self.cc.linker_so[0], 'my_cc') - @unittest.skipUnless(sys.platform == 'darwin', 'test only relevant for OS X') - def test_osx_explicit_ldshared(self): + @unittest.skipIf(sys.platform == 'win32', "can't test on Windows") + def test_explicit_ldshared(self): # Issue #18080: # ensure that setting CC env variable does not change # explicit LDSHARED setting for linker diff --git a/setuptools/_distutils/unixccompiler.py b/setuptools/_distutils/unixccompiler.py index 4d7a6de7..a07e5988 100644 --- a/setuptools/_distutils/unixccompiler.py +++ b/setuptools/_distutils/unixccompiler.py @@ -13,7 +13,7 @@ the "typical" Unix-style command-line C compiler: * link shared library handled by 'cc -shared' """ -import os, sys, re +import os, sys, re, shlex from distutils 
import sysconfig from distutils.dep_util import newer @@ -231,33 +231,30 @@ class UnixCCompiler(CCompiler): # this time, there's no way to determine this information from # the configuration data stored in the Python installation, so # we use this hack. - compiler = os.path.basename(sysconfig.get_config_var("CC")) + compiler = os.path.basename(shlex.split(sysconfig.get_config_var("CC"))[0]) if sys.platform[:6] == "darwin": - # MacOSX's linker doesn't understand the -R flag at all - return "-L" + dir + from distutils.util import get_macosx_target_ver, split_version + macosx_target_ver = get_macosx_target_ver() + if macosx_target_ver and split_version(macosx_target_ver) >= [10, 5]: + return "-Wl,-rpath," + dir + else: # no support for -rpath on earlier macOS versions + return "-L" + dir elif sys.platform[:7] == "freebsd": return "-Wl,-rpath=" + dir elif sys.platform[:5] == "hp-ux": if self._is_gcc(compiler): return ["-Wl,+s", "-L" + dir] return ["+s", "-L" + dir] + + # For all compilers, `-Wl` is the presumed way to + # pass a compiler option to the linker and `-R` is + # the way to pass an RPATH. + if sysconfig.get_config_var("GNULD") == "yes": + # GNU ld needs an extra option to get a RUNPATH + # instead of just an RPATH. + return "-Wl,--enable-new-dtags,-R" + dir else: - if self._is_gcc(compiler): - # gcc on non-GNU systems does not need -Wl, but can - # use it anyway. Since distutils has always passed in - # -Wl whenever gcc was used in the past it is probably - # safest to keep doing so. - if sysconfig.get_config_var("GNULD") == "yes": - # GNU ld needs an extra option to get a RUNPATH - # instead of just an RPATH. - return "-Wl,--enable-new-dtags,-R" + dir - else: - return "-Wl,-R" + dir - else: - # No idea how --enable-new-dtags would be passed on to - # ld if this system was using GNU ld. Don't know if a - # system like this even exists. - return "-R" + dir + return "-Wl,-R" + dir def library_option(self, lib): return "-l" + lib diff --git a/setuptools/_distutils/util.py b/setuptools/_distutils/util.py index f5aca794..afc23c4e 100644 --- a/setuptools/_distutils/util.py +++ b/setuptools/_distutils/util.py @@ -103,11 +103,66 @@ def get_platform(): 'x86' : 'win32', 'x64' : 'win-amd64', 'arm' : 'win-arm32', + 'arm64': 'win-arm64', } return TARGET_TO_PLAT.get(os.environ.get('VSCMD_ARG_TGT_ARCH')) or get_host_platform() else: return get_host_platform() + +if sys.platform == 'darwin': + _syscfg_macosx_ver = None # cache the version pulled from sysconfig +MACOSX_VERSION_VAR = 'MACOSX_DEPLOYMENT_TARGET' + +def _clear_cached_macosx_ver(): + """For testing only. Do not call.""" + global _syscfg_macosx_ver + _syscfg_macosx_ver = None + +def get_macosx_target_ver_from_syscfg(): + """Get the version of macOS latched in the Python interpreter configuration. + Returns the version as a string or None if can't obtain one. Cached.""" + global _syscfg_macosx_ver + if _syscfg_macosx_ver is None: + from distutils import sysconfig + ver = sysconfig.get_config_var(MACOSX_VERSION_VAR) or '' + if ver: + _syscfg_macosx_ver = ver + return _syscfg_macosx_ver + +def get_macosx_target_ver(): + """Return the version of macOS for which we are building. + + The target version defaults to the version in sysconfig latched at time + the Python interpreter was built, unless overridden by an environment + variable. 
If neither source has a value, then None is returned""" + + syscfg_ver = get_macosx_target_ver_from_syscfg() + env_ver = os.environ.get(MACOSX_VERSION_VAR) + + if env_ver: + # Validate overridden version against sysconfig version, if have both. + # Ensure that the deployment target of the build process is not less + # than 10.3 if the interpreter was built for 10.3 or later. This + # ensures extension modules are built with correct compatibility + # values, specifically LDSHARED which can use + # '-undefined dynamic_lookup' which only works on >= 10.3. + if syscfg_ver and split_version(syscfg_ver) >= [10, 3] and \ + split_version(env_ver) < [10, 3]: + my_msg = ('$' + MACOSX_VERSION_VAR + ' mismatch: ' + 'now "%s" but "%s" during configure; ' + 'must use 10.3 or later' + % (env_ver, syscfg_ver)) + raise DistutilsPlatformError(my_msg) + return env_ver + return syscfg_ver + + +def split_version(s): + """Convert a dot-separated string into a list of numbers for comparisons""" + return [int(n) for n in s.split('.')] + + def convert_path (pathname): """Return 'pathname' as a name that will work on the native filesystem, i.e. split it on '/' and put it back together again using the current @@ -478,84 +533,3 @@ def rfc822_escape (header): lines = header.split('\n') sep = '\n' + 8 * ' ' return sep.join(lines) - -# 2to3 support - -def run_2to3(files, fixer_names=None, options=None, explicit=None): - """Invoke 2to3 on a list of Python files. - The files should all come from the build area, as the - modification is done in-place. To reduce the build time, - only files modified since the last invocation of this - function should be passed in the files argument.""" - - if not files: - return - - # Make this class local, to delay import of 2to3 - from lib2to3.refactor import RefactoringTool, get_fixers_from_package - class DistutilsRefactoringTool(RefactoringTool): - def log_error(self, msg, *args, **kw): - log.error(msg, *args) - - def log_message(self, msg, *args): - log.info(msg, *args) - - def log_debug(self, msg, *args): - log.debug(msg, *args) - - if fixer_names is None: - fixer_names = get_fixers_from_package('lib2to3.fixes') - r = DistutilsRefactoringTool(fixer_names, options=options) - r.refactor(files, write=True) - -def copydir_run_2to3(src, dest, template=None, fixer_names=None, - options=None, explicit=None): - """Recursively copy a directory, only copying new and changed files, - running run_2to3 over all newly copied Python modules afterward. - - If you give a template string, it's parsed like a MANIFEST.in. - """ - from distutils.dir_util import mkpath - from distutils.file_util import copy_file - from distutils.filelist import FileList - filelist = FileList() - curdir = os.getcwd() - os.chdir(src) - try: - filelist.findall() - finally: - os.chdir(curdir) - filelist.files[:] = filelist.allfiles - if template: - for line in template.splitlines(): - line = line.strip() - if not line: continue - filelist.process_template_line(line) - copied = [] - for filename in filelist.files: - outname = os.path.join(dest, filename) - mkpath(os.path.dirname(outname)) - res = copy_file(os.path.join(src, filename), outname, update=1) - if res[1]: copied.append(outname) - run_2to3([fn for fn in copied if fn.lower().endswith('.py')], - fixer_names=fixer_names, options=options, explicit=explicit) - return copied - -class Mixin2to3: - '''Mixin class for commands that run 2to3. 
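get_macosx_target_ver() and split_version() above centralize the deployment-target check that spawn() previously performed inline, comparing versions as lists of integers rather than as strings; plain string comparison would mis-order multi-digit components:

>>> split_version('10.10') >= [10, 3]
True
>>> '10.10' >= '10.3'  # the lexicographic comparison split_version() avoids
False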
- To configure 2to3, setup scripts may either change - the class variables, or inherit from individual commands - to override how 2to3 is invoked.''' - - # provide list of fixers to run; - # defaults to all from lib2to3.fixers - fixer_names = None - - # options dictionary - options = None - - # list of fixers to invoke even though they are marked as explicit - explicit = None - - def run_2to3(self, files): - return run_2to3(files, self.fixer_names, self.options, self.explicit) diff --git a/setuptools/_imp.py b/setuptools/_imp.py index 451e45a8..47efd792 100644 --- a/setuptools/_imp.py +++ b/setuptools/_imp.py @@ -41,12 +41,12 @@ def find_module(module, paths=None): spec.loader, importlib.machinery.FrozenImporter): kind = PY_FROZEN path = None # imp compabilty - suffix = mode = '' # imp compability + suffix = mode = '' # imp compatibility elif spec.origin == 'built-in' or static and issubclass( spec.loader, importlib.machinery.BuiltinImporter): kind = C_BUILTIN path = None # imp compabilty - suffix = mode = '' # imp compability + suffix = mode = '' # imp compatibility elif spec.has_location: path = spec.origin suffix = os.path.splitext(path)[1] diff --git a/setuptools/_vendor/more_itertools/__init__.py b/setuptools/_vendor/more_itertools/__init__.py new file mode 100644 index 00000000..19a169fc --- /dev/null +++ b/setuptools/_vendor/more_itertools/__init__.py @@ -0,0 +1,4 @@ +from .more import * # noqa +from .recipes import * # noqa + +__version__ = '8.8.0' diff --git a/setuptools/_vendor/more_itertools/__init__.pyi b/setuptools/_vendor/more_itertools/__init__.pyi new file mode 100644 index 00000000..96f6e36c --- /dev/null +++ b/setuptools/_vendor/more_itertools/__init__.pyi @@ -0,0 +1,2 @@ +from .more import * +from .recipes import * diff --git a/setuptools/_vendor/more_itertools/more.py b/setuptools/_vendor/more_itertools/more.py new file mode 100644 index 00000000..0f7d282a --- /dev/null +++ b/setuptools/_vendor/more_itertools/more.py @@ -0,0 +1,3825 @@ +import warnings + +from collections import Counter, defaultdict, deque, abc +from collections.abc import Sequence +from concurrent.futures import ThreadPoolExecutor +from functools import partial, reduce, wraps +from heapq import merge, heapify, heapreplace, heappop +from itertools import ( + chain, + compress, + count, + cycle, + dropwhile, + groupby, + islice, + repeat, + starmap, + takewhile, + tee, + zip_longest, +) +from math import exp, factorial, floor, log +from queue import Empty, Queue +from random import random, randrange, uniform +from operator import itemgetter, mul, sub, gt, lt +from sys import hexversion, maxsize +from time import monotonic + +from .recipes import ( + consume, + flatten, + pairwise, + powerset, + take, + unique_everseen, +) + +__all__ = [ + 'AbortThread', + 'adjacent', + 'always_iterable', + 'always_reversible', + 'bucket', + 'callback_iter', + 'chunked', + 'circular_shifts', + 'collapse', + 'collate', + 'consecutive_groups', + 'consumer', + 'countable', + 'count_cycle', + 'mark_ends', + 'difference', + 'distinct_combinations', + 'distinct_permutations', + 'distribute', + 'divide', + 'exactly_n', + 'filter_except', + 'first', + 'groupby_transform', + 'ilen', + 'interleave_longest', + 'interleave', + 'intersperse', + 'islice_extended', + 'iterate', + 'ichunked', + 'is_sorted', + 'last', + 'locate', + 'lstrip', + 'make_decorator', + 'map_except', + 'map_reduce', + 'nth_or_last', + 'nth_permutation', + 'nth_product', + 'numeric_range', + 'one', + 'only', + 'padded', + 'partitions', + 'set_partitions', + 
'peekable', + 'repeat_last', + 'replace', + 'rlocate', + 'rstrip', + 'run_length', + 'sample', + 'seekable', + 'SequenceView', + 'side_effect', + 'sliced', + 'sort_together', + 'split_at', + 'split_after', + 'split_before', + 'split_when', + 'split_into', + 'spy', + 'stagger', + 'strip', + 'substrings', + 'substrings_indexes', + 'time_limited', + 'unique_to_each', + 'unzip', + 'windowed', + 'with_iter', + 'UnequalIterablesError', + 'zip_equal', + 'zip_offset', + 'windowed_complete', + 'all_unique', + 'value_chain', + 'product_index', + 'combination_index', + 'permutation_index', +] + +_marker = object() + + +def chunked(iterable, n, strict=False): + """Break *iterable* into lists of length *n*: + + >>> list(chunked([1, 2, 3, 4, 5, 6], 3)) + [[1, 2, 3], [4, 5, 6]] + + By default, the last yielded list will have fewer than *n* elements + if the length of *iterable* is not divisible by *n*: + + >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3)) + [[1, 2, 3], [4, 5, 6], [7, 8]] + + To use a fill-in value instead, see the :func:`grouper` recipe. + + If the length of *iterable* is not divisible by *n* and *strict* is + ``True``, then ``ValueError`` will be raised before the last + list is yielded. + + """ + iterator = iter(partial(take, n, iter(iterable)), []) + if strict: + + def ret(): + for chunk in iterator: + if len(chunk) != n: + raise ValueError('iterable is not divisible by n.') + yield chunk + + return iter(ret()) + else: + return iterator + + +def first(iterable, default=_marker): + """Return the first item of *iterable*, or *default* if *iterable* is + empty. + + >>> first([0, 1, 2, 3]) + 0 + >>> first([], 'some default') + 'some default' + + If *default* is not provided and there are no items in the iterable, + raise ``ValueError``. + + :func:`first` is useful when you have a generator of expensive-to-retrieve + values and want any arbitrary one. It is marginally shorter than + ``next(iter(iterable), default)``. + + """ + try: + return next(iter(iterable)) + except StopIteration as e: + if default is _marker: + raise ValueError( + 'first() was called on an empty iterable, and no ' + 'default value was provided.' + ) from e + return default + + +def last(iterable, default=_marker): + """Return the last item of *iterable*, or *default* if *iterable* is + empty. + + >>> last([0, 1, 2, 3]) + 3 + >>> last([], 'some default') + 'some default' + + If *default* is not provided and there are no items in the iterable, + raise ``ValueError``. + """ + try: + if isinstance(iterable, Sequence): + return iterable[-1] + # Work around https://bugs.python.org/issue38525 + elif hasattr(iterable, '__reversed__') and (hexversion != 0x030800F0): + return next(reversed(iterable)) + else: + return deque(iterable, maxlen=1)[-1] + except (IndexError, TypeError, StopIteration): + if default is _marker: + raise ValueError( + 'last() was called on an empty iterable, and no default was ' + 'provided.' + ) + return default + + +def nth_or_last(iterable, n, default=_marker): + """Return the nth or the last item of *iterable*, + or *default* if *iterable* is empty. + + >>> nth_or_last([0, 1, 2, 3], 2) + 2 + >>> nth_or_last([0, 1], 2) + 1 + >>> nth_or_last([], 0, 'some default') + 'some default' + + If *default* is not provided and there are no items in the iterable, + raise ``ValueError``. + """ + return last(islice(iterable, n + 1), default=default) + + +class peekable: + """Wrap an iterator to allow lookahead and prepending elements.
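chunked() above builds its iterator with the two-argument form of iter(), calling take(n, it) until it returns the sentinel [] that marks exhaustion. The same trick in isolation, with take() re-declared for the sketch (in more_itertools it lives in .recipes):

>>> from functools import partial
>>> from itertools import islice
>>> def take(n, iterable):
...     return list(islice(iterable, n))
>>> it = iter(range(7))
>>> list(iter(partial(take, 3, it), []))
[[0, 1, 2], [3, 4, 5], [6]]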
+ + Call :meth:`peek` on the result to get the value that will be returned + by :func:`next`. This won't advance the iterator: + + >>> p = peekable(['a', 'b']) + >>> p.peek() + 'a' + >>> next(p) + 'a' + + Pass :meth:`peek` a default value to return that instead of raising + ``StopIteration`` when the iterator is exhausted. + + >>> p = peekable([]) + >>> p.peek('hi') + 'hi' + + peekables also offer a :meth:`prepend` method, which "inserts" items + at the head of the iterable: + + >>> p = peekable([1, 2, 3]) + >>> p.prepend(10, 11, 12) + >>> next(p) + 10 + >>> p.peek() + 11 + >>> list(p) + [11, 12, 1, 2, 3] + + peekables can be indexed. Index 0 is the item that will be returned by + :func:`next`, index 1 is the item after that, and so on: + The values up to the given index will be cached. + + >>> p = peekable(['a', 'b', 'c', 'd']) + >>> p[0] + 'a' + >>> p[1] + 'b' + >>> next(p) + 'a' + + Negative indexes are supported, but be aware that they will cache the + remaining items in the source iterator, which may require significant + storage. + + To check whether a peekable is exhausted, check its truth value: + + >>> p = peekable(['a', 'b']) + >>> if p: # peekable has items + ... list(p) + ['a', 'b'] + >>> if not p: # peekable is exhausted + ... list(p) + [] + + """ + + def __init__(self, iterable): + self._it = iter(iterable) + self._cache = deque() + + def __iter__(self): + return self + + def __bool__(self): + try: + self.peek() + except StopIteration: + return False + return True + + def peek(self, default=_marker): + """Return the item that will be next returned from ``next()``. + + Return ``default`` if there are no items left. If ``default`` is not + provided, raise ``StopIteration``. + + """ + if not self._cache: + try: + self._cache.append(next(self._it)) + except StopIteration: + if default is _marker: + raise + return default + return self._cache[0] + + def prepend(self, *items): + """Stack up items to be the next ones returned from ``next()`` or + ``self.peek()``. The items will be returned in + first in, first out order:: + + >>> p = peekable([1, 2, 3]) + >>> p.prepend(10, 11, 12) + >>> next(p) + 10 + >>> list(p) + [11, 12, 1, 2, 3] + + It is possible, by prepending items, to "resurrect" a peekable that + previously raised ``StopIteration``. + + >>> p = peekable([]) + >>> next(p) + Traceback (most recent call last): + ... + StopIteration + >>> p.prepend(1) + >>> next(p) + 1 + >>> next(p) + Traceback (most recent call last): + ... + StopIteration + + """ + self._cache.extendleft(reversed(items)) + + def __next__(self): + if self._cache: + return self._cache.popleft() + + return next(self._it) + + def _get_slice(self, index): + # Normalize the slice's arguments + step = 1 if (index.step is None) else index.step + if step > 0: + start = 0 if (index.start is None) else index.start + stop = maxsize if (index.stop is None) else index.stop + elif step < 0: + start = -1 if (index.start is None) else index.start + stop = (-maxsize - 1) if (index.stop is None) else index.stop + else: + raise ValueError('slice step cannot be zero') + + # If either the start or stop index is negative, we'll need to cache + # the rest of the iterable in order to slice from the right side. + if (start < 0) or (stop < 0): + self._cache.extend(self._it) + # Otherwise we'll need to find the rightmost index and cache to that + # point. 
+ else: + n = min(max(start, stop) + 1, maxsize) + cache_len = len(self._cache) + if n >= cache_len: + self._cache.extend(islice(self._it, n - cache_len)) + + return list(self._cache)[index] + + def __getitem__(self, index): + if isinstance(index, slice): + return self._get_slice(index) + + cache_len = len(self._cache) + if index < 0: + self._cache.extend(self._it) + elif index >= cache_len: + self._cache.extend(islice(self._it, index + 1 - cache_len)) + + return self._cache[index] + + +def collate(*iterables, **kwargs): + """Return a sorted merge of the items from each of several already-sorted + *iterables*. + + >>> list(collate('ACDZ', 'AZ', 'JKL')) + ['A', 'A', 'C', 'D', 'J', 'K', 'L', 'Z', 'Z'] + + Works lazily, keeping only the next value from each iterable in memory. Use + :func:`collate` to, for example, perform an n-way mergesort of items that + don't fit in memory. + + If a *key* function is specified, the iterables will be sorted according + to its result: + + >>> key = lambda s: int(s) # Sort by numeric value, not by string + >>> list(collate(['1', '10'], ['2', '11'], key=key)) + ['1', '2', '10', '11'] + + + If the *iterables* are sorted in descending order, set *reverse* to + ``True``: + + >>> list(collate([5, 3, 1], [4, 2, 0], reverse=True)) + [5, 4, 3, 2, 1, 0] + + If the elements of the passed-in iterables are out of order, you might get + unexpected results. + + On Python 3.5+, this function is an alias for :func:`heapq.merge`. + + """ + warnings.warn( + "collate is no longer part of more_itertools, use heapq.merge", + DeprecationWarning, + ) + return merge(*iterables, **kwargs) + + +def consumer(func): + """Decorator that automatically advances a PEP-342-style "reverse iterator" + to its first yield point so you don't have to call ``next()`` on it + manually. + + >>> @consumer + ... def tally(): + ... i = 0 + ... while True: + ... print('Thing number %s is %s.' % (i, (yield))) + ... i += 1 + ... + >>> t = tally() + >>> t.send('red') + Thing number 0 is red. + >>> t.send('fish') + Thing number 1 is fish. + + Without the decorator, you would have to call ``next(t)`` before + ``t.send()`` could be used. + + """ + + @wraps(func) + def wrapper(*args, **kwargs): + gen = func(*args, **kwargs) + next(gen) + return gen + + return wrapper + + +def ilen(iterable): + """Return the number of items in *iterable*. + + >>> ilen(x for x in range(1000000) if x % 3 == 0) + 333334 + + This consumes the iterable, so handle with care. + + """ + # This approach was selected because benchmarks showed it's likely the + # fastest of the known implementations at the time of writing. + # See GitHub tracker: #236, #230. + counter = count() + deque(zip(iterable, counter), maxlen=0) + return next(counter) + + +def iterate(func, start): + """Return ``start``, ``func(start)``, ``func(func(start))``, ... + + >>> from itertools import islice + >>> list(islice(iterate(lambda x: 2*x, 1), 10)) + [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + + """ + while True: + yield start + start = func(start) + + +def with_iter(context_manager): + """Wrap an iterable in a ``with`` statement, so it closes once exhausted. + + For example, this will close the file when the iterator is exhausted:: + + upper_lines = (line.upper() for line in with_iter(open('foo'))) + + Any context manager which returns an iterable is a candidate for + ``with_iter``.
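ilen() above counts items by zipping the iterable against an itertools.count() and draining the pairs into a zero-length deque; the counter then sits one past the last index consumed, which is exactly the length:

>>> from collections import deque
>>> from itertools import count
>>> counter = count()
>>> deque(zip(iter('abcde'), counter), maxlen=0)  # consume everything, store nothing
deque([], maxlen=0)
>>> next(counter)  # count() yielded 0..4 inside zip, so the length is the next value
5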
+ + """ + with context_manager as iterable: + yield from iterable + + +def one(iterable, too_short=None, too_long=None): + """Return the first item from *iterable*, which is expected to contain only + that item. Raise an exception if *iterable* is empty or has more than one + item. + + :func:`one` is useful for ensuring that an iterable contains only one item. + For example, it can be used to retrieve the result of a database query + that is expected to return a single row. + + If *iterable* is empty, ``ValueError`` will be raised. You may specify a + different exception with the *too_short* keyword: + + >>> it = [] + >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: too many items in iterable (expected 1)' + >>> too_short = IndexError('too few items') + >>> one(it, too_short=too_short) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + IndexError: too few items + + Similarly, if *iterable* contains more than one item, ``ValueError`` will + be raised. You may specify a different exception with the *too_long* + keyword: + + >>> it = ['too', 'many'] + >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: Expected exactly one item in iterable, but got 'too', + 'many', and perhaps more. + >>> too_long = RuntimeError + >>> one(it, too_long=too_long) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + RuntimeError + + Note that :func:`one` attempts to advance *iterable* twice to ensure there + is only one item. See :func:`spy` or :func:`peekable` to check iterable + contents less destructively. + + """ + it = iter(iterable) + + try: + first_value = next(it) + except StopIteration as e: + raise ( + too_short or ValueError('too few items in iterable (expected 1)') + ) from e + + try: + second_value = next(it) + except StopIteration: + pass + else: + msg = ( + 'Expected exactly one item in iterable, but got {!r}, {!r}, ' + 'and perhaps more.'.format(first_value, second_value) + ) + raise too_long or ValueError(msg) + + return first_value + + +def distinct_permutations(iterable, r=None): + """Yield successive distinct permutations of the elements in *iterable*. + + >>> sorted(distinct_permutations([1, 0, 1])) + [(0, 1, 1), (1, 0, 1), (1, 1, 0)] + + Equivalent to ``set(permutations(iterable))``, except duplicates are not + generated and thrown away. For larger input sequences this is much more + efficient. + + Duplicate permutations arise when there are duplicated elements in the + input iterable. The number of items returned is + `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of + items input, and each `x_i` is the count of a distinct item in the input + sequence. + + If *r* is given, only the *r*-length permutations are yielded. 
+
+ >>> sorted(distinct_permutations([1, 0, 1], r=2))
+ [(0, 1), (1, 0), (1, 1)]
+ >>> sorted(distinct_permutations(range(3), r=2))
+ [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
+
+ """
+ # Algorithm: https://w.wiki/Qai
+ def _full(A):
+ while True:
+ # Yield the permutation we have
+ yield tuple(A)
+
+ # Find the largest index i such that A[i] < A[i + 1]
+ for i in range(size - 2, -1, -1):
+ if A[i] < A[i + 1]:
+ break
+ # If no such index exists, this permutation is the last one
+ else:
+ return
+
+ # Find the largest index j greater than i such that A[i] < A[j]
+ for j in range(size - 1, i, -1):
+ if A[i] < A[j]:
+ break
+
+ # Swap the value of A[i] with that of A[j], then reverse the
+ # sequence from A[i + 1] to form the new permutation
+ A[i], A[j] = A[j], A[i]
+ A[i + 1 :] = A[: i - size : -1] # A[i + 1:][::-1]
+
+ # Algorithm: modified from the above
+ def _partial(A, r):
+ # Split A into the first r items and the last r items
+ head, tail = A[:r], A[r:]
+ right_head_indexes = range(r - 1, -1, -1)
+ left_tail_indexes = range(len(tail))
+
+ while True:
+ # Yield the permutation we have
+ yield tuple(head)
+
+ # Starting from the right, find the first index of the head with
+ # value smaller than the maximum value of the tail - call it i.
+ pivot = tail[-1]
+ for i in right_head_indexes:
+ if head[i] < pivot:
+ break
+ pivot = head[i]
+ else:
+ return
+
+ # Starting from the left, find the first value of the tail
+ # with a value greater than head[i] and swap.
+ for j in left_tail_indexes:
+ if tail[j] > head[i]:
+ head[i], tail[j] = tail[j], head[i]
+ break
+ # If we didn't find one, start from the right and find the first
+ # index of the head with a value greater than head[i] and swap.
+ else:
+ for j in right_head_indexes:
+ if head[j] > head[i]:
+ head[i], head[j] = head[j], head[i]
+ break
+
+ # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)]
+ tail += head[: i - r : -1] # head[i + 1:][::-1]
+ i += 1
+ head[i:], tail[:] = tail[: r - i], tail[r - i :]
+
+ items = sorted(iterable)
+
+ size = len(items)
+ if r is None:
+ r = size
+
+ if 0 < r <= size:
+ return _full(items) if (r == size) else _partial(items, r)
+
+ return iter(() if r else ((),))
+
+
+def intersperse(e, iterable, n=1):
+ """Intersperse filler element *e* among the items in *iterable*, leaving
+ *n* items between each filler element.
+
+ >>> list(intersperse('!', [1, 2, 3, 4, 5]))
+ [1, '!', 2, '!', 3, '!', 4, '!', 5]
+
+ >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2))
+ [1, 2, None, 3, 4, None, 5]
+
+ """
+ if n == 0:
+ raise ValueError('n must be > 0')
+ elif n == 1:
+ # interleave(repeat(e), iterable) -> e, x_0, e, x_1, e, x_2...
+ # islice(..., 1, None) -> x_0, e, x_1, e, x_2...
+ return islice(interleave(repeat(e), iterable), 1, None)
+ else:
+ # interleave(filler, chunks) -> [e], [x_0, x_1], [e], [x_2, x_3]...
+ # islice(..., 1, None) -> [x_0, x_1], [e], [x_2, x_3]...
+ # flatten(...) -> x_0, x_1, e, x_2, x_3...
+ filler = repeat([e])
+ chunks = chunked(iterable, n)
+ return flatten(islice(interleave(filler, chunks), 1, None))
+
+
+def unique_to_each(*iterables):
+ """Return the elements from each of the input iterables that aren't in the
+ other input iterables.
+
+ For example, suppose you have a set of packages, each with a set of
+ dependencies::
+
+ {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}}
+
+ If you remove one package, which dependencies can also be removed?
+ + If ``pkg_1`` is removed, then ``A`` is no longer necessary - it is not + associated with ``pkg_2`` or ``pkg_3``. Similarly, ``C`` is only needed for + ``pkg_2``, and ``D`` is only needed for ``pkg_3``:: + + >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'}) + [['A'], ['C'], ['D']] + + If there are duplicates in one input iterable that aren't in the others + they will be duplicated in the output. Input order is preserved:: + + >>> unique_to_each("mississippi", "missouri") + [['p', 'p'], ['o', 'u', 'r']] + + It is assumed that the elements of each iterable are hashable. + + """ + pool = [list(it) for it in iterables] + counts = Counter(chain.from_iterable(map(set, pool))) + uniques = {element for element in counts if counts[element] == 1} + return [list(filter(uniques.__contains__, it)) for it in pool] + + +def windowed(seq, n, fillvalue=None, step=1): + """Return a sliding window of width *n* over the given iterable. + + >>> all_windows = windowed([1, 2, 3, 4, 5], 3) + >>> list(all_windows) + [(1, 2, 3), (2, 3, 4), (3, 4, 5)] + + When the window is larger than the iterable, *fillvalue* is used in place + of missing values: + + >>> list(windowed([1, 2, 3], 4)) + [(1, 2, 3, None)] + + Each window will advance in increments of *step*: + + >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2)) + [(1, 2, 3), (3, 4, 5), (5, 6, '!')] + + To slide into the iterable's items, use :func:`chain` to add filler items + to the left: + + >>> iterable = [1, 2, 3, 4] + >>> n = 3 + >>> padding = [None] * (n - 1) + >>> list(windowed(chain(padding, iterable), 3)) + [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)] + """ + if n < 0: + raise ValueError('n must be >= 0') + if n == 0: + yield tuple() + return + if step < 1: + raise ValueError('step must be >= 1') + + window = deque(maxlen=n) + i = n + for _ in map(window.append, seq): + i -= 1 + if not i: + i = step + yield tuple(window) + + size = len(window) + if size < n: + yield tuple(chain(window, repeat(fillvalue, n - size))) + elif 0 < i < min(step, n): + window += (fillvalue,) * i + yield tuple(window) + + +def substrings(iterable): + """Yield all of the substrings of *iterable*. + + >>> [''.join(s) for s in substrings('more')] + ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more'] + + Note that non-string iterables can also be subdivided. + + >>> list(substrings([0, 1, 2])) + [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)] + + """ + # The length-1 substrings + seq = [] + for item in iter(iterable): + seq.append(item) + yield (item,) + seq = tuple(seq) + item_count = len(seq) + + # And the rest + for n in range(2, item_count + 1): + for i in range(item_count - n + 1): + yield seq[i : i + n] + + +def substrings_indexes(seq, reverse=False): + """Yield all substrings and their positions in *seq* + + The items yielded will be a tuple of the form ``(substr, i, j)``, where + ``substr == seq[i:j]``. + + This function only works for iterables that support slicing, such as + ``str`` objects. + + >>> for item in substrings_indexes('more'): + ... print(item) + ('m', 0, 1) + ('o', 1, 2) + ('r', 2, 3) + ('e', 3, 4) + ('mo', 0, 2) + ('or', 1, 3) + ('re', 2, 4) + ('mor', 0, 3) + ('ore', 1, 4) + ('more', 0, 4) + + Set *reverse* to ``True`` to yield the same items in the opposite order. 
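+
+ For instance (an illustrative doctest of the *reverse* flag, reusing the
+ example above):
+
+ >>> next(substrings_indexes('more', reverse=True))
+ ('more', 0, 4)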
+
+ """
+ r = range(1, len(seq) + 1)
+ if reverse:
+ r = reversed(r)
+ return (
+ (seq[i : i + L], i, i + L) for L in r for i in range(len(seq) - L + 1)
+ )
+
+
+class bucket:
+ """Wrap *iterable* and return an object that buckets it into child
+ iterables based on a *key* function.
+
+ >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
+ >>> s = bucket(iterable, key=lambda x: x[0]) # Bucket by 1st character
+ >>> sorted(list(s)) # Get the keys
+ ['a', 'b', 'c']
+ >>> a_iterable = s['a']
+ >>> next(a_iterable)
+ 'a1'
+ >>> next(a_iterable)
+ 'a2'
+ >>> list(s['b'])
+ ['b1', 'b2', 'b3']
+
+ The original iterable will be advanced and its items will be cached until
+ they are used by the child iterables. This may require significant storage.
+
+ By default, attempting to select a bucket to which no items belong will
+ exhaust the iterable and cache all values.
+ If you specify a *validator* function, selected buckets will instead be
+ checked against it.
+
+ >>> from itertools import count
+ >>> it = count(1, 2) # Infinite sequence of odd numbers
+ >>> key = lambda x: x % 10 # Bucket by last digit
+ >>> validator = lambda x: x in {1, 3, 5, 7, 9} # Odd digits only
+ >>> s = bucket(it, key=key, validator=validator)
+ >>> 2 in s
+ False
+ >>> list(s[2])
+ []
+
+ """
+
+ def __init__(self, iterable, key, validator=None):
+ self._it = iter(iterable)
+ self._key = key
+ self._cache = defaultdict(deque)
+ self._validator = validator or (lambda x: True)
+
+ def __contains__(self, value):
+ if not self._validator(value):
+ return False
+
+ try:
+ item = next(self[value])
+ except StopIteration:
+ return False
+ else:
+ self._cache[value].appendleft(item)
+
+ return True
+
+ def _get_values(self, value):
+ """
+ Helper to yield items from the parent iterator that match *value*.
+ Items that don't match are stored in the local cache as they
+ are encountered.
+ """
+ while True:
+ # If we've cached some items that match the target value, emit
+ # the first one and evict it from the cache.
+ if self._cache[value]:
+ yield self._cache[value].popleft()
+ # Otherwise we need to advance the parent iterator to search for
+ # a matching item, caching the rest.
+ else:
+ while True:
+ try:
+ item = next(self._it)
+ except StopIteration:
+ return
+ item_value = self._key(item)
+ if item_value == value:
+ yield item
+ break
+ elif self._validator(item_value):
+ self._cache[item_value].append(item)
+
+ def __iter__(self):
+ for item in self._it:
+ item_value = self._key(item)
+ if self._validator(item_value):
+ self._cache[item_value].append(item)
+
+ yield from self._cache.keys()
+
+ def __getitem__(self, value):
+ if not self._validator(value):
+ return iter(())
+
+ return self._get_values(value)
+
+
+def spy(iterable, n=1):
+ """Return a 2-tuple with a list containing the first *n* elements of
+ *iterable*, and an iterator with the same items as *iterable*.
+ This allows you to "look ahead" at the items in the iterable without
+ advancing it.
+ + There is one item in the list by default: + + >>> iterable = 'abcdefg' + >>> head, iterable = spy(iterable) + >>> head + ['a'] + >>> list(iterable) + ['a', 'b', 'c', 'd', 'e', 'f', 'g'] + + You may use unpacking to retrieve items instead of lists: + + >>> (head,), iterable = spy('abcdefg') + >>> head + 'a' + >>> (first, second), iterable = spy('abcdefg', 2) + >>> first + 'a' + >>> second + 'b' + + The number of items requested can be larger than the number of items in + the iterable: + + >>> iterable = [1, 2, 3, 4, 5] + >>> head, iterable = spy(iterable, 10) + >>> head + [1, 2, 3, 4, 5] + >>> list(iterable) + [1, 2, 3, 4, 5] + + """ + it = iter(iterable) + head = take(n, it) + + return head.copy(), chain(head, it) + + +def interleave(*iterables): + """Return a new iterable yielding from each iterable in turn, + until the shortest is exhausted. + + >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8])) + [1, 4, 6, 2, 5, 7] + + For a version that doesn't terminate after the shortest iterable is + exhausted, see :func:`interleave_longest`. + + """ + return chain.from_iterable(zip(*iterables)) + + +def interleave_longest(*iterables): + """Return a new iterable yielding from each iterable in turn, + skipping any that are exhausted. + + >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8])) + [1, 4, 6, 2, 5, 7, 3, 8] + + This function produces the same output as :func:`roundrobin`, but may + perform better for some inputs (in particular when the number of iterables + is large). + + """ + i = chain.from_iterable(zip_longest(*iterables, fillvalue=_marker)) + return (x for x in i if x is not _marker) + + +def collapse(iterable, base_type=None, levels=None): + """Flatten an iterable with multiple levels of nesting (e.g., a list of + lists of tuples) into non-iterable types. + + >>> iterable = [(1, 2), ([3, 4], [[5], [6]])] + >>> list(collapse(iterable)) + [1, 2, 3, 4, 5, 6] + + Binary and text strings are not considered iterable and + will not be collapsed. + + To avoid collapsing other types, specify *base_type*: + + >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']] + >>> list(collapse(iterable, base_type=tuple)) + ['ab', ('cd', 'ef'), 'gh', 'ij'] + + Specify *levels* to stop flattening after a certain level: + + >>> iterable = [('a', ['b']), ('c', ['d'])] + >>> list(collapse(iterable)) # Fully flattened + ['a', 'b', 'c', 'd'] + >>> list(collapse(iterable, levels=1)) # Only one level flattened + ['a', ['b'], 'c', ['d']] + + """ + + def walk(node, level): + if ( + ((levels is not None) and (level > levels)) + or isinstance(node, (str, bytes)) + or ((base_type is not None) and isinstance(node, base_type)) + ): + yield node + return + + try: + tree = iter(node) + except TypeError: + yield node + return + else: + for child in tree: + yield from walk(child, level + 1) + + yield from walk(iterable, 0) + + +def side_effect(func, iterable, chunk_size=None, before=None, after=None): + """Invoke *func* on each item in *iterable* (or on each *chunk_size* group + of items) before yielding the item. + + `func` must be a function that takes a single argument. Its return value + will be discarded. + + *before* and *after* are optional functions that take no arguments. They + will be executed before iteration starts and after it ends, respectively. + + `side_effect` can be used for logging, updating progress bars, or anything + that is not functionally "pure." 
+
+ Emitting a status message:
+
+ >>> from more_itertools import consume
+ >>> func = lambda item: print('Received {}'.format(item))
+ >>> consume(side_effect(func, range(2)))
+ Received 0
+ Received 1
+
+ Operating on chunks of items:
+
+ >>> pair_sums = []
+ >>> func = lambda chunk: pair_sums.append(sum(chunk))
+ >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2))
+ [0, 1, 2, 3, 4, 5]
+ >>> list(pair_sums)
+ [1, 5, 9]
+
+ Writing to a file-like object:
+
+ >>> from io import StringIO
+ >>> from more_itertools import consume
+ >>> f = StringIO()
+ >>> func = lambda x: print(x, file=f)
+ >>> before = lambda: print(u'HEADER', file=f)
+ >>> after = f.close
+ >>> it = [u'a', u'b', u'c']
+ >>> consume(side_effect(func, it, before=before, after=after))
+ >>> f.closed
+ True
+
+ """
+ try:
+ if before is not None:
+ before()
+
+ if chunk_size is None:
+ for item in iterable:
+ func(item)
+ yield item
+ else:
+ for chunk in chunked(iterable, chunk_size):
+ func(chunk)
+ yield from chunk
+ finally:
+ if after is not None:
+ after()
+
+
+def sliced(seq, n, strict=False):
+ """Yield slices of length *n* from the sequence *seq*.
+
+ >>> list(sliced((1, 2, 3, 4, 5, 6), 3))
+ [(1, 2, 3), (4, 5, 6)]
+
+ By default, the last yielded slice will have fewer than *n* elements
+ if the length of *seq* is not divisible by *n*:
+
+ >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3))
+ [(1, 2, 3), (4, 5, 6), (7, 8)]
+
+ If the length of *seq* is not divisible by *n* and *strict* is
+ ``True``, then ``ValueError`` will be raised before the last
+ slice is yielded.
+
+ This function will only work for iterables that support slicing.
+ For non-sliceable iterables, see :func:`chunked`.
+
+ """
+ iterator = takewhile(len, (seq[i : i + n] for i in count(0, n)))
+ if strict:
+
+ def ret():
+ for _slice in iterator:
+ if len(_slice) != n:
+ raise ValueError("seq is not divisible by n.")
+ yield _slice
+
+ return iter(ret())
+ else:
+ return iterator
+
+
+def split_at(iterable, pred, maxsplit=-1, keep_separator=False):
+ """Yield lists of items from *iterable*, where each list is delimited by
+ an item where callable *pred* returns ``True``.
+
+ >>> list(split_at('abcdcba', lambda x: x == 'b'))
+ [['a'], ['c', 'd', 'c'], ['a']]
+
+ >>> list(split_at(range(10), lambda n: n % 2 == 1))
+ [[0], [2], [4], [6], [8], []]
+
+ At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
+ then there is no limit on the number of splits:
+
+ >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2))
+ [[0], [2], [4, 5, 6, 7, 8, 9]]
+
+ By default, the delimiting items are not included in the output.
+ To include them, set *keep_separator* to ``True``.
+
+ >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True))
+ [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']]
+
+ """
+ if maxsplit == 0:
+ yield list(iterable)
+ return
+
+ buf = []
+ it = iter(iterable)
+ for item in it:
+ if pred(item):
+ yield buf
+ if keep_separator:
+ yield [item]
+ if maxsplit == 1:
+ yield list(it)
+ return
+ buf = []
+ maxsplit -= 1
+ else:
+ buf.append(item)
+ yield buf
+
+
+def split_before(iterable, pred, maxsplit=-1):
+ """Yield lists of items from *iterable*, where each list ends just before
+ an item for which callable *pred* returns ``True``:
+
+ >>> list(split_before('OneTwo', lambda s: s.isupper()))
+ [['O', 'n', 'e'], ['T', 'w', 'o']]
+
+ >>> list(split_before(range(10), lambda n: n % 3 == 0))
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
+
+ At most *maxsplit* splits are done. 
If *maxsplit* is not specified or -1,
+ then there is no limit on the number of splits:
+
+ >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2))
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
+ """
+ if maxsplit == 0:
+ yield list(iterable)
+ return
+
+ buf = []
+ it = iter(iterable)
+ for item in it:
+ if pred(item) and buf:
+ yield buf
+ if maxsplit == 1:
+ yield [item] + list(it)
+ return
+ buf = []
+ maxsplit -= 1
+ buf.append(item)
+ if buf:
+ yield buf
+
+
+def split_after(iterable, pred, maxsplit=-1):
+ """Yield lists of items from *iterable*, where each list ends with an
+ item where callable *pred* returns ``True``:
+
+ >>> list(split_after('one1two2', lambda s: s.isdigit()))
+ [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']]
+
+ >>> list(split_after(range(10), lambda n: n % 3 == 0))
+ [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]
+
+ At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
+ then there is no limit on the number of splits:
+
+ >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
+ [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]
+
+ """
+ if maxsplit == 0:
+ yield list(iterable)
+ return
+
+ buf = []
+ it = iter(iterable)
+ for item in it:
+ buf.append(item)
+ if pred(item) and buf:
+ yield buf
+ if maxsplit == 1:
+ yield list(it)
+ return
+ buf = []
+ maxsplit -= 1
+ if buf:
+ yield buf
+
+
+def split_when(iterable, pred, maxsplit=-1):
+ """Split *iterable* into pieces based on the output of *pred*.
+ *pred* should be a function that takes successive pairs of items and
+ returns ``True`` if the iterable should be split in between them.
+
+ For example, to find runs of increasing numbers, split the iterable when
+ element ``i`` is larger than element ``i + 1``:
+
+ >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y))
+ [[1, 2, 3, 3], [2, 5], [2, 4], [2]]
+
+ At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
+ then there is no limit on the number of splits:
+
+ >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2],
+ ... lambda x, y: x > y, maxsplit=2))
+ [[1, 2, 3, 3], [2, 5], [2, 4, 2]]
+
+ """
+ if maxsplit == 0:
+ yield list(iterable)
+ return
+
+ it = iter(iterable)
+ try:
+ cur_item = next(it)
+ except StopIteration:
+ return
+
+ buf = [cur_item]
+ for next_item in it:
+ if pred(cur_item, next_item):
+ yield buf
+ if maxsplit == 1:
+ yield [next_item] + list(it)
+ return
+ buf = []
+ maxsplit -= 1
+
+ buf.append(next_item)
+ cur_item = next_item
+
+ yield buf
+
+
+def split_into(iterable, sizes):
+ """Yield a list of sequential items from *iterable* of length 'n' for each
+ integer 'n' in *sizes*.
+
+ >>> list(split_into([1,2,3,4,5,6], [1,2,3]))
+ [[1], [2, 3], [4, 5, 6]]
+
+ If the sum of *sizes* is smaller than the length of *iterable*, then the
+ remaining items of *iterable* will not be returned.
+
+ >>> list(split_into([1,2,3,4,5,6], [2,3]))
+ [[1, 2], [3, 4, 5]]
+
+ If the sum of *sizes* is larger than the length of *iterable*, fewer items
+ will be returned in the iteration that overruns *iterable* and further
+ lists will be empty:
+
+ >>> list(split_into([1,2,3,4], [1,2,3,4]))
+ [[1], [2, 3], [4], []]
+
+ When a ``None`` object is encountered in *sizes*, the returned list will
+ contain items up to the end of *iterable* the same way that
+ :func:`itertools.islice` does:
+
+ >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None]))
+ [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]]
+
+ :func:`split_into` can be useful for grouping a series of items where the
+ sizes of the groups are not uniform. 
An example would be where a row
+ from a table contains multiple columns that represent elements of the same
+ feature (e.g. a point represented by x, y, z), but the format is not the
+ same for all columns.
+ """
+ # convert the iterable argument into an iterator so its contents can
+ # be consumed by islice in case it is a generator
+ it = iter(iterable)
+
+ for size in sizes:
+ if size is None:
+ yield list(it)
+ return
+ else:
+ yield list(islice(it, size))
+
+
+def padded(iterable, fillvalue=None, n=None, next_multiple=False):
+ """Yield the elements from *iterable*, followed by *fillvalue*, such that
+ at least *n* items are emitted.
+
+ >>> list(padded([1, 2, 3], '?', 5))
+ [1, 2, 3, '?', '?']
+
+ If *next_multiple* is ``True``, *fillvalue* will be emitted until the
+ number of items emitted is a multiple of *n*::
+
+ >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True))
+ [1, 2, 3, 4, None, None]
+
+ If *n* is ``None``, *fillvalue* will be emitted indefinitely.
+
+ """
+ it = iter(iterable)
+ if n is None:
+ yield from chain(it, repeat(fillvalue))
+ elif n < 1:
+ raise ValueError('n must be at least 1')
+ else:
+ item_count = 0
+ for item in it:
+ yield item
+ item_count += 1
+
+ remaining = (n - item_count) % n if next_multiple else n - item_count
+ for _ in range(remaining):
+ yield fillvalue
+
+
+def repeat_last(iterable, default=None):
+ """After the *iterable* is exhausted, keep yielding its last element.
+
+ >>> list(islice(repeat_last(range(3)), 5))
+ [0, 1, 2, 2, 2]
+
+ If the iterable is empty, yield *default* forever::
+
+ >>> list(islice(repeat_last(range(0), 42), 5))
+ [42, 42, 42, 42, 42]
+
+ """
+ item = _marker
+ for item in iterable:
+ yield item
+ final = default if item is _marker else item
+ yield from repeat(final)
+
+
+def distribute(n, iterable):
+ """Distribute the items from *iterable* among *n* smaller iterables.
+
+ >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
+ >>> list(group_1)
+ [1, 3, 5]
+ >>> list(group_2)
+ [2, 4, 6]
+
+ If the length of *iterable* is not evenly divisible by *n*, then the
+ length of the returned iterables will not be identical:
+
+ >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
+ >>> [list(c) for c in children]
+ [[1, 4, 7], [2, 5], [3, 6]]
+
+ If the length of *iterable* is smaller than *n*, then the last returned
+ iterables will be empty:
+
+ >>> children = distribute(5, [1, 2, 3])
+ >>> [list(c) for c in children]
+ [[1], [2], [3], [], []]
+
+ This function uses :func:`itertools.tee` and may require significant
+ storage. If you need the order of items in the smaller iterables to match
+ the original iterable, see :func:`divide`.
+
+ """
+ if n < 1:
+ raise ValueError('n must be at least 1')
+
+ children = tee(iterable, n)
+ return [islice(it, index, None, n) for index, it in enumerate(children)]
+
+
+def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
+ """Yield tuples whose elements are offset from *iterable*.
+ The amount by which the `i`-th item in each tuple is offset is given by
+ the `i`-th item in *offsets*.
+
+ >>> list(stagger([0, 1, 2, 3]))
+ [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
+ >>> list(stagger(range(8), offsets=(0, 2, 4)))
+ [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]
+
+ By default, the sequence will end when the final element of a tuple is the
+ last item in the iterable. 
To continue until the first element of a tuple + is the last item in the iterable, set *longest* to ``True``:: + + >>> list(stagger([0, 1, 2, 3], longest=True)) + [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)] + + By default, ``None`` will be used to replace offsets beyond the end of the + sequence. Specify *fillvalue* to use some other value. + + """ + children = tee(iterable, len(offsets)) + + return zip_offset( + *children, offsets=offsets, longest=longest, fillvalue=fillvalue + ) + + +class UnequalIterablesError(ValueError): + def __init__(self, details=None): + msg = 'Iterables have different lengths' + if details is not None: + msg += (': index 0 has length {}; index {} has length {}').format( + *details + ) + + super().__init__(msg) + + +def _zip_equal_generator(iterables): + for combo in zip_longest(*iterables, fillvalue=_marker): + for val in combo: + if val is _marker: + raise UnequalIterablesError() + yield combo + + +def zip_equal(*iterables): + """``zip`` the input *iterables* together, but raise + ``UnequalIterablesError`` if they aren't all the same length. + + >>> it_1 = range(3) + >>> it_2 = iter('abc') + >>> list(zip_equal(it_1, it_2)) + [(0, 'a'), (1, 'b'), (2, 'c')] + + >>> it_1 = range(3) + >>> it_2 = iter('abcd') + >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + more_itertools.more.UnequalIterablesError: Iterables have different + lengths + + """ + if hexversion >= 0x30A00A6: + warnings.warn( + ( + 'zip_equal will be removed in a future version of ' + 'more-itertools. Use the builtin zip function with ' + 'strict=True instead.' + ), + DeprecationWarning, + ) + # Check whether the iterables are all the same size. + try: + first_size = len(iterables[0]) + for i, it in enumerate(iterables[1:], 1): + size = len(it) + if size != first_size: + break + else: + # If we didn't break out, we can use the built-in zip. + return zip(*iterables) + + # If we did break out, there was a mismatch. + raise UnequalIterablesError(details=(first_size, i, size)) + # If any one of the iterables didn't have a length, start reading + # them until one runs out. + except TypeError: + return _zip_equal_generator(iterables) + + +def zip_offset(*iterables, offsets, longest=False, fillvalue=None): + """``zip`` the input *iterables* together, but offset the `i`-th iterable + by the `i`-th item in *offsets*. + + >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1))) + [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')] + + This can be used as a lightweight alternative to SciPy or pandas to analyze + data sets in which some series have a lead or lag relationship. + + By default, the sequence will end when the shortest iterable is exhausted. + To continue until the longest iterable is exhausted, set *longest* to + ``True``. + + >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True)) + [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')] + + By default, ``None`` will be used to replace offsets beyond the end of the + sequence. Specify *fillvalue* to use some other value. 
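+
+ For instance (an illustrative sketch combining *longest* and *fillvalue*):
+
+ >>> list(zip_offset('012', 'abcdef', offsets=(0, 1), longest=True,
+ ... fillvalue='?'))
+ [('0', 'b'), ('1', 'c'), ('2', 'd'), ('?', 'e'), ('?', 'f')]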
+
+ """
+ if len(iterables) != len(offsets):
+ raise ValueError("Number of iterables and offsets didn't match")
+
+ staggered = []
+ for it, n in zip(iterables, offsets):
+ if n < 0:
+ staggered.append(chain(repeat(fillvalue, -n), it))
+ elif n > 0:
+ staggered.append(islice(it, n, None))
+ else:
+ staggered.append(it)
+
+ if longest:
+ return zip_longest(*staggered, fillvalue=fillvalue)
+
+ return zip(*staggered)
+
+
+def sort_together(iterables, key_list=(0,), key=None, reverse=False):
+ """Return the input iterables sorted together, with *key_list* as the
+ priority for sorting. All iterables are trimmed to the length of the
+ shortest one.
+
+ This can be used like the sorting function in a spreadsheet. If each
+ iterable represents a column of data, the key list determines which
+ columns are used for sorting.
+
+ By default, all iterables are sorted using the ``0``-th iterable::
+
+ >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')]
+ >>> sort_together(iterables)
+ [(1, 2, 3, 4), ('d', 'c', 'b', 'a')]
+
+ Set a different key list to sort according to another iterable.
+ Specifying multiple keys dictates how ties are broken::
+
+ >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')]
+ >>> sort_together(iterables, key_list=(1, 2))
+ [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')]
+
+ To sort by a function of the elements of the iterable, pass a *key*
+ function. Its arguments are the elements of the iterables corresponding to
+ the key list::
+
+ >>> names = ('a', 'b', 'c')
+ >>> lengths = (1, 2, 3)
+ >>> widths = (5, 2, 1)
+ >>> def area(length, width):
+ ... return length * width
+ >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area)
+ [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)]
+
+ Set *reverse* to ``True`` to sort in descending order.
+
+ >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True)
+ [(3, 2, 1), ('a', 'b', 'c')]
+
+ """
+ if key is None:
+ # if there is no key function, the key argument to sorted is an
+ # itemgetter
+ key_argument = itemgetter(*key_list)
+ else:
+ # if there is a key function, call it with the items at the offsets
+ # specified by the key list as arguments
+ key_list = list(key_list)
+ if len(key_list) == 1:
+ # if key_list contains a single item, pass the item at that offset
+ # as the only argument to the key function
+ key_offset = key_list[0]
+ key_argument = lambda zipped_items: key(zipped_items[key_offset])
+ else:
+ # if key_list contains multiple items, use itemgetter to return a
+ # tuple of items, which we pass as *args to the key function
+ get_key_items = itemgetter(*key_list)
+ key_argument = lambda zipped_items: key(
+ *get_key_items(zipped_items)
+ )
+
+ return list(
+ zip(*sorted(zip(*iterables), key=key_argument, reverse=reverse))
+ )
+
+
+def unzip(iterable):
+ """The inverse of :func:`zip`, this function disaggregates the elements
+ of the zipped *iterable*.
+
+ The ``i``-th iterable contains the ``i``-th element from each element
+ of the zipped iterable. The first element is used to determine the
+ length of the remaining elements.
+
+ >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
+ >>> letters, numbers = unzip(iterable)
+ >>> list(letters)
+ ['a', 'b', 'c', 'd']
+ >>> list(numbers)
+ [1, 2, 3, 4]
+
+ This is similar to using ``zip(*iterable)``, but it avoids reading
+ *iterable* into memory. Note, however, that this function uses
+ :func:`itertools.tee` and thus may require significant storage.
+
+ """
+ head, iterable = spy(iter(iterable))
+ if not head:
+ # empty iterable, e.g. 
zip([], [], []) + return () + # spy returns a one-length iterable as head + head = head[0] + iterables = tee(iterable, len(head)) + + def itemgetter(i): + def getter(obj): + try: + return obj[i] + except IndexError: + # basically if we have an iterable like + # iter([(1, 2, 3), (4, 5), (6,)]) + # the second unzipped iterable would fail at the third tuple + # since it would try to access tup[1] + # same with the third unzipped iterable and the second tuple + # to support these "improperly zipped" iterables, + # we create a custom itemgetter + # which just stops the unzipped iterables + # at first length mismatch + raise StopIteration + + return getter + + return tuple(map(itemgetter(i), it) for i, it in enumerate(iterables)) + + +def divide(n, iterable): + """Divide the elements from *iterable* into *n* parts, maintaining + order. + + >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6]) + >>> list(group_1) + [1, 2, 3] + >>> list(group_2) + [4, 5, 6] + + If the length of *iterable* is not evenly divisible by *n*, then the + length of the returned iterables will not be identical: + + >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7]) + >>> [list(c) for c in children] + [[1, 2, 3], [4, 5], [6, 7]] + + If the length of the iterable is smaller than n, then the last returned + iterables will be empty: + + >>> children = divide(5, [1, 2, 3]) + >>> [list(c) for c in children] + [[1], [2], [3], [], []] + + This function will exhaust the iterable before returning and may require + significant storage. If order is not important, see :func:`distribute`, + which does not first pull the iterable into memory. + + """ + if n < 1: + raise ValueError('n must be at least 1') + + try: + iterable[:0] + except TypeError: + seq = tuple(iterable) + else: + seq = iterable + + q, r = divmod(len(seq), n) + + ret = [] + stop = 0 + for i in range(1, n + 1): + start = stop + stop += q + 1 if i <= r else q + ret.append(iter(seq[start:stop])) + + return ret + + +def always_iterable(obj, base_type=(str, bytes)): + """If *obj* is iterable, return an iterator over its items:: + + >>> obj = (1, 2, 3) + >>> list(always_iterable(obj)) + [1, 2, 3] + + If *obj* is not iterable, return a one-item iterable containing *obj*:: + + >>> obj = 1 + >>> list(always_iterable(obj)) + [1] + + If *obj* is ``None``, return an empty iterable: + + >>> obj = None + >>> list(always_iterable(None)) + [] + + By default, binary and text strings are not considered iterable:: + + >>> obj = 'foo' + >>> list(always_iterable(obj)) + ['foo'] + + If *base_type* is set, objects for which ``isinstance(obj, base_type)`` + returns ``True`` won't be considered iterable. + + >>> obj = {'a': 1} + >>> list(always_iterable(obj)) # Iterate over the dict's keys + ['a'] + >>> list(always_iterable(obj, base_type=dict)) # Treat dicts as a unit + [{'a': 1}] + + Set *base_type* to ``None`` to avoid any special handling and treat objects + Python considers iterable as iterable: + + >>> obj = 'foo' + >>> list(always_iterable(obj, base_type=None)) + ['f', 'o', 'o'] + """ + if obj is None: + return iter(()) + + if (base_type is not None) and isinstance(obj, base_type): + return iter((obj,)) + + try: + return iter(obj) + except TypeError: + return iter((obj,)) + + +def adjacent(predicate, iterable, distance=1): + """Return an iterable over `(bool, item)` tuples where the `item` is + drawn from *iterable* and the `bool` indicates whether + that item satisfies the *predicate* or is adjacent to an item that does. 
+ + For example, to find whether items are adjacent to a ``3``:: + + >>> list(adjacent(lambda x: x == 3, range(6))) + [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)] + + Set *distance* to change what counts as adjacent. For example, to find + whether items are two places away from a ``3``: + + >>> list(adjacent(lambda x: x == 3, range(6), distance=2)) + [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)] + + This is useful for contextualizing the results of a search function. + For example, a code comparison tool might want to identify lines that + have changed, but also surrounding lines to give the viewer of the diff + context. + + The predicate function will only be called once for each item in the + iterable. + + See also :func:`groupby_transform`, which can be used with this function + to group ranges of items with the same `bool` value. + + """ + # Allow distance=0 mainly for testing that it reproduces results with map() + if distance < 0: + raise ValueError('distance must be at least 0') + + i1, i2 = tee(iterable) + padding = [False] * distance + selected = chain(padding, map(predicate, i1), padding) + adjacent_to_selected = map(any, windowed(selected, 2 * distance + 1)) + return zip(adjacent_to_selected, i2) + + +def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None): + """An extension of :func:`itertools.groupby` that can apply transformations + to the grouped data. + + * *keyfunc* is a function computing a key value for each item in *iterable* + * *valuefunc* is a function that transforms the individual items from + *iterable* after grouping + * *reducefunc* is a function that transforms each group of items + + >>> iterable = 'aAAbBBcCC' + >>> keyfunc = lambda k: k.upper() + >>> valuefunc = lambda v: v.lower() + >>> reducefunc = lambda g: ''.join(g) + >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc)) + [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')] + + Each optional argument defaults to an identity function if not specified. + + :func:`groupby_transform` is useful when grouping elements of an iterable + using a separate iterable as the key. To do this, :func:`zip` the iterables + and pass a *keyfunc* that extracts the first element and a *valuefunc* + that extracts the second element:: + + >>> from operator import itemgetter + >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3] + >>> values = 'abcdefghi' + >>> iterable = zip(keys, values) + >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1)) + >>> [(k, ''.join(g)) for k, g in grouper] + [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')] + + Note that the order of items in the iterable is significant. + Only adjacent items are grouped together, so if you don't want any + duplicate groups, you should sort the iterable by the key function. + + """ + ret = groupby(iterable, keyfunc) + if valuefunc: + ret = ((k, map(valuefunc, g)) for k, g in ret) + if reducefunc: + ret = ((k, reducefunc(g)) for k, g in ret) + + return ret + + +class numeric_range(abc.Sequence, abc.Hashable): + """An extension of the built-in ``range()`` function whose arguments can + be any orderable numeric type. + + With only *stop* specified, *start* defaults to ``0`` and *step* + defaults to ``1``. The output items will match the type of *stop*: + + >>> list(numeric_range(3.5)) + [0.0, 1.0, 2.0, 3.0] + + With only *start* and *stop* specified, *step* defaults to ``1``. 
The + output items will match the type of *start*: + + >>> from decimal import Decimal + >>> start = Decimal('2.1') + >>> stop = Decimal('5.1') + >>> list(numeric_range(start, stop)) + [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')] + + With *start*, *stop*, and *step* specified the output items will match + the type of ``start + step``: + + >>> from fractions import Fraction + >>> start = Fraction(1, 2) # Start at 1/2 + >>> stop = Fraction(5, 2) # End at 5/2 + >>> step = Fraction(1, 2) # Count by 1/2 + >>> list(numeric_range(start, stop, step)) + [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)] + + If *step* is zero, ``ValueError`` is raised. Negative steps are supported: + + >>> list(numeric_range(3, -1, -1.0)) + [3.0, 2.0, 1.0, 0.0] + + Be aware of the limitations of floating point numbers; the representation + of the yielded numbers may be surprising. + + ``datetime.datetime`` objects can be used for *start* and *stop*, if *step* + is a ``datetime.timedelta`` object: + + >>> import datetime + >>> start = datetime.datetime(2019, 1, 1) + >>> stop = datetime.datetime(2019, 1, 3) + >>> step = datetime.timedelta(days=1) + >>> items = iter(numeric_range(start, stop, step)) + >>> next(items) + datetime.datetime(2019, 1, 1, 0, 0) + >>> next(items) + datetime.datetime(2019, 1, 2, 0, 0) + + """ + + _EMPTY_HASH = hash(range(0, 0)) + + def __init__(self, *args): + argc = len(args) + if argc == 1: + (self._stop,) = args + self._start = type(self._stop)(0) + self._step = type(self._stop - self._start)(1) + elif argc == 2: + self._start, self._stop = args + self._step = type(self._stop - self._start)(1) + elif argc == 3: + self._start, self._stop, self._step = args + elif argc == 0: + raise TypeError( + 'numeric_range expected at least ' + '1 argument, got {}'.format(argc) + ) + else: + raise TypeError( + 'numeric_range expected at most ' + '3 arguments, got {}'.format(argc) + ) + + self._zero = type(self._step)(0) + if self._step == self._zero: + raise ValueError('numeric_range() arg 3 must not be zero') + self._growing = self._step > self._zero + self._init_len() + + def __bool__(self): + if self._growing: + return self._start < self._stop + else: + return self._start > self._stop + + def __contains__(self, elem): + if self._growing: + if self._start <= elem < self._stop: + return (elem - self._start) % self._step == self._zero + else: + if self._start >= elem > self._stop: + return (self._start - elem) % (-self._step) == self._zero + + return False + + def __eq__(self, other): + if isinstance(other, numeric_range): + empty_self = not bool(self) + empty_other = not bool(other) + if empty_self or empty_other: + return empty_self and empty_other # True if both empty + else: + return ( + self._start == other._start + and self._step == other._step + and self._get_by_index(-1) == other._get_by_index(-1) + ) + else: + return False + + def __getitem__(self, key): + if isinstance(key, int): + return self._get_by_index(key) + elif isinstance(key, slice): + step = self._step if key.step is None else key.step * self._step + + if key.start is None or key.start <= -self._len: + start = self._start + elif key.start >= self._len: + start = self._stop + else: # -self._len < key.start < self._len + start = self._get_by_index(key.start) + + if key.stop is None or key.stop >= self._len: + stop = self._stop + elif key.stop <= -self._len: + stop = self._start + else: # -self._len < key.stop < self._len + stop = self._get_by_index(key.stop) + + return numeric_range(start, stop, step) + else: + raise 
TypeError( + 'numeric range indices must be ' + 'integers or slices, not {}'.format(type(key).__name__) + ) + + def __hash__(self): + if self: + return hash((self._start, self._get_by_index(-1), self._step)) + else: + return self._EMPTY_HASH + + def __iter__(self): + values = (self._start + (n * self._step) for n in count()) + if self._growing: + return takewhile(partial(gt, self._stop), values) + else: + return takewhile(partial(lt, self._stop), values) + + def __len__(self): + return self._len + + def _init_len(self): + if self._growing: + start = self._start + stop = self._stop + step = self._step + else: + start = self._stop + stop = self._start + step = -self._step + distance = stop - start + if distance <= self._zero: + self._len = 0 + else: # distance > 0 and step > 0: regular euclidean division + q, r = divmod(distance, step) + self._len = int(q) + int(r != self._zero) + + def __reduce__(self): + return numeric_range, (self._start, self._stop, self._step) + + def __repr__(self): + if self._step == 1: + return "numeric_range({}, {})".format( + repr(self._start), repr(self._stop) + ) + else: + return "numeric_range({}, {}, {})".format( + repr(self._start), repr(self._stop), repr(self._step) + ) + + def __reversed__(self): + return iter( + numeric_range( + self._get_by_index(-1), self._start - self._step, -self._step + ) + ) + + def count(self, value): + return int(value in self) + + def index(self, value): + if self._growing: + if self._start <= value < self._stop: + q, r = divmod(value - self._start, self._step) + if r == self._zero: + return int(q) + else: + if self._start >= value > self._stop: + q, r = divmod(self._start - value, -self._step) + if r == self._zero: + return int(q) + + raise ValueError("{} is not in numeric range".format(value)) + + def _get_by_index(self, i): + if i < 0: + i += self._len + if i < 0 or i >= self._len: + raise IndexError("numeric range object index out of range") + return self._start + i * self._step + + +def count_cycle(iterable, n=None): + """Cycle through the items from *iterable* up to *n* times, yielding + the number of completed cycles along with each item. If *n* is omitted the + process repeats indefinitely. + + >>> list(count_cycle('AB', 3)) + [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')] + + """ + iterable = tuple(iterable) + if not iterable: + return iter(()) + counter = count() if n is None else range(n) + return ((i, item) for i in counter for item in iterable) + + +def mark_ends(iterable): + """Yield 3-tuples of the form ``(is_first, is_last, item)``. + + >>> list(mark_ends('ABC')) + [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')] + + Use this when looping over an iterable to take special action on its first + and/or last items: + + >>> iterable = ['Header', 100, 200, 'Footer'] + >>> total = 0 + >>> for is_first, is_last, item in mark_ends(iterable): + ... if is_first: + ... continue # Skip the header + ... if is_last: + ... continue # Skip the footer + ... total += item + >>> print(total) + 300 + """ + it = iter(iterable) + + try: + b = next(it) + except StopIteration: + return + + try: + for i in count(): + a = b + b = next(it) + yield i == 0, False, a + + except StopIteration: + yield i == 0, True, a + + +def locate(iterable, pred=bool, window_size=None): + """Yield the index of each item in *iterable* for which *pred* returns + ``True``. 
+
+ *pred* defaults to :func:`bool`, which will select truthy items:
+
+ >>> list(locate([0, 1, 1, 0, 1, 0, 0]))
+ [1, 2, 4]
+
+ Set *pred* to a custom function to, e.g., find the indexes for a particular
+ item.
+
+ >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
+ [1, 3]
+
+ If *window_size* is given, then the *pred* function will be called with
+ that many items. This enables searching for sub-sequences:
+
+ >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
+ >>> pred = lambda *args: args == (1, 2, 3)
+ >>> list(locate(iterable, pred=pred, window_size=3))
+ [1, 5, 9]
+
+ Use with :func:`seekable` to find indexes and then retrieve the associated
+ items:
+
+ >>> from itertools import count
+ >>> from more_itertools import seekable
+ >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count())
+ >>> it = seekable(source)
+ >>> pred = lambda x: x > 100
+ >>> indexes = locate(it, pred=pred)
+ >>> i = next(indexes)
+ >>> it.seek(i)
+ >>> next(it)
+ 106
+
+ """
+ if window_size is None:
+ return compress(count(), map(pred, iterable))
+
+ if window_size < 1:
+ raise ValueError('window size must be at least 1')
+
+ it = windowed(iterable, window_size, fillvalue=_marker)
+ return compress(count(), starmap(pred, it))
+
+
+def lstrip(iterable, pred):
+ """Yield the items from *iterable*, but strip any from the beginning
+ for which *pred* returns ``True``.
+
+ For example, to remove a set of items from the start of an iterable:
+
+ >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
+ >>> pred = lambda x: x in {None, False, ''}
+ >>> list(lstrip(iterable, pred))
+ [1, 2, None, 3, False, None]
+
+ This function is analogous to :func:`str.lstrip`, and is essentially
+ a wrapper for :func:`itertools.dropwhile`.
+
+ """
+ return dropwhile(pred, iterable)
+
+
+def rstrip(iterable, pred):
+ """Yield the items from *iterable*, but strip any from the end
+ for which *pred* returns ``True``.
+
+ For example, to remove a set of items from the end of an iterable:
+
+ >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
+ >>> pred = lambda x: x in {None, False, ''}
+ >>> list(rstrip(iterable, pred))
+ [None, False, None, 1, 2, None, 3]
+
+ This function is analogous to :func:`str.rstrip`.
+
+ """
+ cache = []
+ cache_append = cache.append
+ cache_clear = cache.clear
+ for x in iterable:
+ if pred(x):
+ cache_append(x)
+ else:
+ yield from cache
+ cache_clear()
+ yield x
+
+
+def strip(iterable, pred):
+ """Yield the items from *iterable*, but strip any from the
+ beginning and end for which *pred* returns ``True``.
+
+ For example, to remove a set of items from both ends of an iterable:
+
+ >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
+ >>> pred = lambda x: x in {None, False, ''}
+ >>> list(strip(iterable, pred))
+ [1, 2, None, 3]
+
+ This function is analogous to :func:`str.strip`.
+
+ """
+ return rstrip(lstrip(iterable, pred), pred)
+
+
+class islice_extended:
+ """An extension of :func:`itertools.islice` that supports negative values
+ for *stop*, *start*, and *step*.
+
+ >>> iterable = iter('abcdefgh')
+ >>> list(islice_extended(iterable, -4, -1))
+ ['e', 'f', 'g']
+
+ Slices with negative values require some caching of *iterable*, but this
+ function takes care to minimize the amount of memory required. 
+ + For example, you can use a negative step with an infinite iterator: + + >>> from itertools import count + >>> list(islice_extended(count(), 110, 99, -2)) + [110, 108, 106, 104, 102, 100] + + You can also use slice notation directly: + + >>> iterable = map(str, count()) + >>> it = islice_extended(iterable)[10:20:2] + >>> list(it) + ['10', '12', '14', '16', '18'] + + """ + + def __init__(self, iterable, *args): + it = iter(iterable) + if args: + self._iterable = _islice_helper(it, slice(*args)) + else: + self._iterable = it + + def __iter__(self): + return self + + def __next__(self): + return next(self._iterable) + + def __getitem__(self, key): + if isinstance(key, slice): + return islice_extended(_islice_helper(self._iterable, key)) + + raise TypeError('islice_extended.__getitem__ argument must be a slice') + + +def _islice_helper(it, s): + start = s.start + stop = s.stop + if s.step == 0: + raise ValueError('step argument must be a non-zero integer or None.') + step = s.step or 1 + + if step > 0: + start = 0 if (start is None) else start + + if start < 0: + # Consume all but the last -start items + cache = deque(enumerate(it, 1), maxlen=-start) + len_iter = cache[-1][0] if cache else 0 + + # Adjust start to be positive + i = max(len_iter + start, 0) + + # Adjust stop to be positive + if stop is None: + j = len_iter + elif stop >= 0: + j = min(stop, len_iter) + else: + j = max(len_iter + stop, 0) + + # Slice the cache + n = j - i + if n <= 0: + return + + for index, item in islice(cache, 0, n, step): + yield item + elif (stop is not None) and (stop < 0): + # Advance to the start position + next(islice(it, start, start), None) + + # When stop is negative, we have to carry -stop items while + # iterating + cache = deque(islice(it, -stop), maxlen=-stop) + + for index, item in enumerate(it): + cached_item = cache.popleft() + if index % step == 0: + yield cached_item + cache.append(item) + else: + # When both start and stop are positive we have the normal case + yield from islice(it, start, stop, step) + else: + start = -1 if (start is None) else start + + if (stop is not None) and (stop < 0): + # Consume all but the last items + n = -stop - 1 + cache = deque(enumerate(it, 1), maxlen=n) + len_iter = cache[-1][0] if cache else 0 + + # If start and stop are both negative they are comparable and + # we can just slice. Otherwise we can adjust start to be negative + # and then slice. + if start < 0: + i, j = start, stop + else: + i, j = min(start - len_iter, -1), None + + for index, item in list(cache)[i:j:step]: + yield item + else: + # Advance to the stop position + if stop is not None: + m = stop + 1 + next(islice(it, m, m), None) + + # stop is positive, so if start is negative they are not comparable + # and we need the rest of the items. + if start < 0: + i = start + n = None + # stop is None and start is positive, so we just need items up to + # the start index. + elif stop is None: + i = None + n = start + 1 + # Both stop and start are positive, so they are comparable. + else: + i = None + n = start - stop + if n <= 0: + return + + cache = list(islice(it, n)) + + yield from cache[i::step] + + +def always_reversible(iterable): + """An extension of :func:`reversed` that supports all iterables, not + just those which implement the ``Reversible`` or ``Sequence`` protocols. + + >>> print(*always_reversible(x for x in range(3))) + 2 1 0 + + If the iterable is already reversible, this function returns the + result of :func:`reversed()`. 
If the iterable is not reversible,
+ this function will cache the remaining items in the iterable and
+ yield them in reverse order, which may require significant storage.
+ """
+ try:
+ return reversed(iterable)
+ except TypeError:
+ return reversed(list(iterable))
+
+
+def consecutive_groups(iterable, ordering=lambda x: x):
+ """Yield groups of consecutive items using :func:`itertools.groupby`.
+ The *ordering* function determines whether two items are adjacent by
+ returning their position.
+
+ By default, the ordering function is the identity function. This is
+ suitable for finding runs of numbers:
+
+ >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40]
+ >>> for group in consecutive_groups(iterable):
+ ... print(list(group))
+ [1]
+ [10, 11, 12]
+ [20]
+ [30, 31, 32, 33]
+ [40]
+
+ For finding runs of adjacent letters, try using the :meth:`index` method
+ of a string of letters:
+
+ >>> from string import ascii_lowercase
+ >>> iterable = 'abcdfgilmnop'
+ >>> ordering = ascii_lowercase.index
+ >>> for group in consecutive_groups(iterable, ordering):
+ ... print(list(group))
+ ['a', 'b', 'c', 'd']
+ ['f', 'g']
+ ['i']
+ ['l', 'm', 'n', 'o', 'p']
+
+ Each group of consecutive items is an iterator that shares its source with
+ *iterable*. When an output group is advanced, the previous group is
+ no longer available unless its elements are copied (e.g., into a ``list``).
+
+ >>> iterable = [1, 2, 11, 12, 21, 22]
+ >>> saved_groups = []
+ >>> for group in consecutive_groups(iterable):
+ ... saved_groups.append(list(group)) # Copy group elements
+ >>> saved_groups
+ [[1, 2], [11, 12], [21, 22]]
+
+ """
+ for k, g in groupby(
+ enumerate(iterable), key=lambda x: x[0] - ordering(x[1])
+ ):
+ yield map(itemgetter(1), g)
+
+
+def difference(iterable, func=sub, *, initial=None):
+ """This function is the inverse of :func:`itertools.accumulate`. By default
+ it will compute the first difference of *iterable* using
+ :func:`operator.sub`:
+
+ >>> from itertools import accumulate
+ >>> iterable = accumulate([0, 1, 2, 3, 4]) # produces 0, 1, 3, 6, 10
+ >>> list(difference(iterable))
+ [0, 1, 2, 3, 4]
+
+ *func* defaults to :func:`operator.sub`, but other functions can be
+ specified. They will be applied as follows::
+
+ A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ...
+
+ For example, to do progressive division:
+
+ >>> iterable = [1, 2, 6, 24, 120]
+ >>> func = lambda x, y: x // y
+ >>> list(difference(iterable, func))
+ [1, 2, 3, 4, 5]
+
+ If the *initial* keyword is set, the first element will be skipped when
+ computing successive differences.
+
+ >>> it = [10, 11, 13, 16] # from accumulate([1, 2, 3], initial=10)
+ >>> list(difference(it, initial=10))
+ [1, 2, 3]
+
+ """
+ a, b = tee(iterable)
+ try:
+ first = [next(b)]
+ except StopIteration:
+ return iter([])
+
+ if initial is not None:
+ first = []
+
+ return chain(first, starmap(func, zip(b, a)))
+
+
+class SequenceView(Sequence):
+ """Return a read-only view of the sequence object *target*.
+
+ :class:`SequenceView` objects are analogous to Python's built-in
+ "dictionary view" types. They provide a dynamic view of a sequence's items,
+ meaning that when the sequence updates, so does the view.
+
+ >>> seq = ['0', '1', '2']
+ >>> view = SequenceView(seq)
+ >>> view
+ SequenceView(['0', '1', '2'])
+ >>> seq.append('3')
+ >>> view
+ SequenceView(['0', '1', '2', '3'])
+
+ Sequence views support indexing, slicing, and length queries. 
They act + like the underlying sequence, except they don't allow assignment: + + >>> view[1] + '1' + >>> view[1:-1] + ['1', '2'] + >>> len(view) + 4 + + Sequence views are useful as an alternative to copying, as they don't + require (much) extra storage. + + """ + + def __init__(self, target): + if not isinstance(target, Sequence): + raise TypeError + self._target = target + + def __getitem__(self, index): + return self._target[index] + + def __len__(self): + return len(self._target) + + def __repr__(self): + return '{}({})'.format(self.__class__.__name__, repr(self._target)) + + +class seekable: + """Wrap an iterator to allow for seeking backward and forward. This + progressively caches the items in the source iterable so they can be + re-visited. + + Call :meth:`seek` with an index to seek to that position in the source + iterable. + + To "reset" an iterator, seek to ``0``: + + >>> from itertools import count + >>> it = seekable((str(n) for n in count())) + >>> next(it), next(it), next(it) + ('0', '1', '2') + >>> it.seek(0) + >>> next(it), next(it), next(it) + ('0', '1', '2') + >>> next(it) + '3' + + You can also seek forward: + + >>> it = seekable((str(n) for n in range(20))) + >>> it.seek(10) + >>> next(it) + '10' + >>> it.seek(20) # Seeking past the end of the source isn't a problem + >>> list(it) + [] + >>> it.seek(0) # Resetting works even after hitting the end + >>> next(it), next(it), next(it) + ('0', '1', '2') + + Call :meth:`peek` to look ahead one item without advancing the iterator: + + >>> it = seekable('1234') + >>> it.peek() + '1' + >>> list(it) + ['1', '2', '3', '4'] + >>> it.peek(default='empty') + 'empty' + + Before the iterator is at its end, calling :func:`bool` on it will return + ``True``. After it will return ``False``: + + >>> it = seekable('5678') + >>> bool(it) + True + >>> list(it) + ['5', '6', '7', '8'] + >>> bool(it) + False + + You may view the contents of the cache with the :meth:`elements` method. + That returns a :class:`SequenceView`, a view that updates automatically: + + >>> it = seekable((str(n) for n in range(10))) + >>> next(it), next(it), next(it) + ('0', '1', '2') + >>> elements = it.elements() + >>> elements + SequenceView(['0', '1', '2']) + >>> next(it) + '3' + >>> elements + SequenceView(['0', '1', '2', '3']) + + By default, the cache grows as the source iterable progresses, so beware of + wrapping very large or infinite iterables. Supply *maxlen* to limit the + size of the cache (this of course limits how far back you can seek). 
+ + >>> from itertools import count + >>> it = seekable((str(n) for n in count()), maxlen=2) + >>> next(it), next(it), next(it), next(it) + ('0', '1', '2', '3') + >>> list(it.elements()) + ['2', '3'] + >>> it.seek(0) + >>> next(it), next(it), next(it), next(it) + ('2', '3', '4', '5') + >>> next(it) + '6' + + """ + + def __init__(self, iterable, maxlen=None): + self._source = iter(iterable) + if maxlen is None: + self._cache = [] + else: + self._cache = deque([], maxlen) + self._index = None + + def __iter__(self): + return self + + def __next__(self): + if self._index is not None: + try: + item = self._cache[self._index] + except IndexError: + self._index = None + else: + self._index += 1 + return item + + item = next(self._source) + self._cache.append(item) + return item + + def __bool__(self): + try: + self.peek() + except StopIteration: + return False + return True + + def peek(self, default=_marker): + try: + peeked = next(self) + except StopIteration: + if default is _marker: + raise + return default + if self._index is None: + self._index = len(self._cache) + self._index -= 1 + return peeked + + def elements(self): + return SequenceView(self._cache) + + def seek(self, index): + self._index = index + remainder = index - len(self._cache) + if remainder > 0: + consume(self, remainder) + + +class run_length: + """ + :func:`run_length.encode` compresses an iterable with run-length encoding. + It yields groups of repeated items with the count of how many times they + were repeated: + + >>> uncompressed = 'abbcccdddd' + >>> list(run_length.encode(uncompressed)) + [('a', 1), ('b', 2), ('c', 3), ('d', 4)] + + :func:`run_length.decode` decompresses an iterable that was previously + compressed with run-length encoding. It yields the items of the + decompressed iterable: + + >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)] + >>> list(run_length.decode(compressed)) + ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd'] + + """ + + @staticmethod + def encode(iterable): + return ((k, ilen(g)) for k, g in groupby(iterable)) + + @staticmethod + def decode(iterable): + return chain.from_iterable(repeat(k, n) for k, n in iterable) + + +def exactly_n(iterable, n, predicate=bool): + """Return ``True`` if exactly ``n`` items in the iterable are ``True`` + according to the *predicate* function. + + >>> exactly_n([True, True, False], 2) + True + >>> exactly_n([True, True, False], 1) + False + >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3) + True + + The iterable will be advanced until ``n + 1`` truthy items are encountered, + so avoid calling it on infinite iterables. + + """ + return len(take(n + 1, filter(predicate, iterable))) == n + + +def circular_shifts(iterable): + """Return a list of circular shifts of *iterable*. + + >>> circular_shifts(range(4)) + [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)] + """ + lst = list(iterable) + return take(len(lst), windowed(cycle(lst), len(lst))) + + +def make_decorator(wrapping_func, result_index=0): + """Return a decorator version of *wrapping_func*, which is a function that + modifies an iterable. *result_index* is the position in that function's + signature where the iterable goes. + + This lets you use itertools on the "production end," i.e. at function + definition. This can augment what the function returns without changing the + function's code. + + For example, to produce a decorator version of :func:`chunked`: + + >>> from more_itertools import chunked + >>> chunker = make_decorator(chunked, result_index=0) + >>> @chunker(3) + ... 
def iter_range(n): + ... return iter(range(n)) + ... + >>> list(iter_range(9)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + + To only allow truthy items to be returned: + + >>> truth_serum = make_decorator(filter, result_index=1) + >>> @truth_serum(bool) + ... def boolean_test(): + ... return [0, 1, '', ' ', False, True] + ... + >>> list(boolean_test()) + [1, ' ', True] + + The :func:`peekable` and :func:`seekable` wrappers make for practical + decorators: + + >>> from more_itertools import peekable + >>> peekable_function = make_decorator(peekable) + >>> @peekable_function() + ... def str_range(*args): + ... return (str(x) for x in range(*args)) + ... + >>> it = str_range(1, 20, 2) + >>> next(it), next(it), next(it) + ('1', '3', '5') + >>> it.peek() + '7' + >>> next(it) + '7' + + """ + # See https://sites.google.com/site/bbayles/index/decorator_factory for + # notes on how this works. + def decorator(*wrapping_args, **wrapping_kwargs): + def outer_wrapper(f): + def inner_wrapper(*args, **kwargs): + result = f(*args, **kwargs) + wrapping_args_ = list(wrapping_args) + wrapping_args_.insert(result_index, result) + return wrapping_func(*wrapping_args_, **wrapping_kwargs) + + return inner_wrapper + + return outer_wrapper + + return decorator + + +def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None): + """Return a dictionary that maps the items in *iterable* to categories + defined by *keyfunc*, transforms them with *valuefunc*, and + then summarizes them by category with *reducefunc*. + + *valuefunc* defaults to the identity function if it is unspecified. + If *reducefunc* is unspecified, no summarization takes place: + + >>> keyfunc = lambda x: x.upper() + >>> result = map_reduce('abbccc', keyfunc) + >>> sorted(result.items()) + [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])] + + Specifying *valuefunc* transforms the categorized items: + + >>> keyfunc = lambda x: x.upper() + >>> valuefunc = lambda x: 1 + >>> result = map_reduce('abbccc', keyfunc, valuefunc) + >>> sorted(result.items()) + [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])] + + Specifying *reducefunc* summarizes the categorized items: + + >>> keyfunc = lambda x: x.upper() + >>> valuefunc = lambda x: 1 + >>> reducefunc = sum + >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc) + >>> sorted(result.items()) + [('A', 1), ('B', 2), ('C', 3)] + + You may want to filter the input iterable before applying the map/reduce + procedure: + + >>> all_items = range(30) + >>> items = [x for x in all_items if 10 <= x <= 20] # Filter + >>> keyfunc = lambda x: x % 2 # Evens map to 0; odds to 1 + >>> categories = map_reduce(items, keyfunc=keyfunc) + >>> sorted(categories.items()) + [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])] + >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum) + >>> sorted(summaries.items()) + [(0, 90), (1, 75)] + + Note that all items in the iterable are gathered into a list before the + summarization step, which may require significant storage. + + The returned object is a :obj:`collections.defaultdict` with the + ``default_factory`` set to ``None``, such that it behaves like a normal + dictionary. 
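+
+    The same pattern applies to richer records too; for instance (an
+    illustrative example, not from the upstream docstring), summing grouped
+    tuple values:
+
+        >>> grades = [('math', 90), ('math', 70), ('art', 80)]
+        >>> result = map_reduce(
+        ...     grades,
+        ...     keyfunc=lambda g: g[0],
+        ...     valuefunc=lambda g: g[1],
+        ...     reducefunc=sum,
+        ... )
+        >>> sorted(result.items())
+        [('art', 80), ('math', 160)]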
+
+    """
+    valuefunc = (lambda x: x) if (valuefunc is None) else valuefunc
+
+    ret = defaultdict(list)
+    for item in iterable:
+        key = keyfunc(item)
+        value = valuefunc(item)
+        ret[key].append(value)
+
+    if reducefunc is not None:
+        for key, value_list in ret.items():
+            ret[key] = reducefunc(value_list)
+
+    ret.default_factory = None
+    return ret
+
+
+def rlocate(iterable, pred=bool, window_size=None):
+    """Yield the index of each item in *iterable* for which *pred* returns
+    ``True``, starting from the right and moving left.
+
+    *pred* defaults to :func:`bool`, which will select truthy items:
+
+        >>> list(rlocate([0, 1, 1, 0, 1, 0, 0]))  # Truthy at 1, 2, and 4
+        [4, 2, 1]
+
+    Set *pred* to a custom function to, e.g., find the indexes for a particular
+    item:
+
+        >>> iterable = iter('abcb')
+        >>> pred = lambda x: x == 'b'
+        >>> list(rlocate(iterable, pred))
+        [3, 1]
+
+    If *window_size* is given, then the *pred* function will be called with
+    that many items. This enables searching for sub-sequences:
+
+        >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
+        >>> pred = lambda *args: args == (1, 2, 3)
+        >>> list(rlocate(iterable, pred=pred, window_size=3))
+        [9, 5, 1]
+
+    Beware, this function won't return anything for infinite iterables.
+    If *iterable* is reversible, ``rlocate`` will reverse it and search from
+    the right. Otherwise, it will search from the left and return the results
+    in reverse order.
+
+    See :func:`locate` for other example applications.
+
+    """
+    if window_size is None:
+        try:
+            len_iter = len(iterable)
+            return (len_iter - i - 1 for i in locate(reversed(iterable), pred))
+        except TypeError:
+            pass
+
+    return reversed(list(locate(iterable, pred, window_size)))
+
+
+def replace(iterable, pred, substitutes, count=None, window_size=1):
+    """Yield the items from *iterable*, replacing the items for which *pred*
+    returns ``True`` with the items from the iterable *substitutes*.
+
+        >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
+        >>> pred = lambda x: x == 0
+        >>> substitutes = (2, 3)
+        >>> list(replace(iterable, pred, substitutes))
+        [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]
+
+    If *count* is given, the number of replacements will be limited:
+
+        >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
+        >>> pred = lambda x: x == 0
+        >>> substitutes = [None]
+        >>> list(replace(iterable, pred, substitutes, count=2))
+        [1, 1, None, 1, 1, None, 1, 1, 0]
+
+    Use *window_size* to control the number of items passed as arguments to
+    *pred*. This allows for locating and replacing subsequences.
+
+        >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
+        >>> window_size = 3
+        >>> pred = lambda *args: args == (0, 1, 2)  # 3 items passed to pred
+        >>> substitutes = [3, 4]  # Splice in these items
+        >>> list(replace(iterable, pred, substitutes, window_size=window_size))
+        [3, 4, 5, 3, 4, 5]
+
+    """
+    if window_size < 1:
+        raise ValueError('window_size must be at least 1')
+
+    # Save the substitutes iterable, since it's used more than once
+    substitutes = tuple(substitutes)
+
+    # Add padding such that the number of windows matches the length of the
+    # iterable
+    it = chain(iterable, [_marker] * (window_size - 1))
+    windows = windowed(it, window_size)
+
+    n = 0
+    for w in windows:
+        # If the current window matches our predicate (and we haven't hit
+        # our maximum number of replacements), splice in the substitutes
+        # and then consume the following windows that overlap with this one.
+        # For example, if the iterable is (0, 1, 2, 3, 4...)
+        # and the window size is 2, we have (0, 1), (1, 2), (2, 3)...
+        # If the predicate matches on (0, 1), we need to zap (0, 1) and (1, 2)
+        if pred(*w):
+            if (count is None) or (n < count):
+                n += 1
+                yield from substitutes
+                consume(windows, window_size - 1)
+                continue
+
+        # If there was no match (or we've reached the replacement limit),
+        # yield the first item from the window.
+        if w and (w[0] is not _marker):
+            yield w[0]
+
+
+def partitions(iterable):
+    """Yield all possible order-preserving partitions of *iterable*.
+
+        >>> iterable = 'abc'
+        >>> for part in partitions(iterable):
+        ...     print([''.join(p) for p in part])
+        ['abc']
+        ['a', 'bc']
+        ['ab', 'c']
+        ['a', 'b', 'c']
+
+    This is unrelated to :func:`partition`.
+
+    """
+    sequence = list(iterable)
+    n = len(sequence)
+    for i in powerset(range(1, n)):
+        yield [sequence[i:j] for i, j in zip((0,) + i, i + (n,))]
+
+
+def set_partitions(iterable, k=None):
+    """
+    Yield the set partitions of *iterable* into *k* parts. Set partitions are
+    not order-preserving.
+
+        >>> iterable = 'abc'
+        >>> for part in set_partitions(iterable, 2):
+        ...     print([''.join(p) for p in part])
+        ['a', 'bc']
+        ['ab', 'c']
+        ['b', 'ac']
+
+
+    If *k* is not given, every set partition is generated.
+
+        >>> iterable = 'abc'
+        >>> for part in set_partitions(iterable):
+        ...     print([''.join(p) for p in part])
+        ['abc']
+        ['a', 'bc']
+        ['ab', 'c']
+        ['b', 'ac']
+        ['a', 'b', 'c']
+
+    """
+    L = list(iterable)
+    n = len(L)
+    if k is not None:
+        if k < 1:
+            raise ValueError(
+                "Can't partition into a negative or zero number of groups"
+            )
+        elif k > n:
+            return
+
+    def set_partitions_helper(L, k):
+        n = len(L)
+        if k == 1:
+            yield [L]
+        elif n == k:
+            yield [[s] for s in L]
+        else:
+            e, *M = L
+            for p in set_partitions_helper(M, k - 1):
+                yield [[e], *p]
+            for p in set_partitions_helper(M, k):
+                for i in range(len(p)):
+                    yield p[:i] + [[e] + p[i]] + p[i + 1 :]
+
+    if k is None:
+        for k in range(1, n + 1):
+            yield from set_partitions_helper(L, k)
+    else:
+        yield from set_partitions_helper(L, k)
+
+
+class time_limited:
+    """
+    Yield items from *iterable* until *limit_seconds* have passed.
+    If the time limit expires before all items have been yielded, the
+    ``timed_out`` attribute will be set to ``True``.
+
+        >>> from time import sleep
+        >>> def generator():
+        ...     yield 1
+        ...     yield 2
+        ...     sleep(0.2)
+        ...     yield 3
+        >>> iterable = time_limited(0.1, generator())
+        >>> list(iterable)
+        [1, 2]
+        >>> iterable.timed_out
+        True
+
+    Note that the time is checked before each item is yielded, and iteration
+    stops if the time elapsed is greater than *limit_seconds*. If your time
+    limit is 1 second, but it takes 2 seconds to generate the first item from
+    the iterable, the function will run for 2 seconds and not yield anything.
+
+    """
+
+    def __init__(self, limit_seconds, iterable):
+        if limit_seconds < 0:
+            raise ValueError('limit_seconds must be non-negative')
+        self.limit_seconds = limit_seconds
+        self._iterable = iter(iterable)
+        self._start_time = monotonic()
+        self.timed_out = False
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        item = next(self._iterable)
+        if monotonic() - self._start_time > self.limit_seconds:
+            self.timed_out = True
+            raise StopIteration
+
+        return item
+
+
+def only(iterable, default=None, too_long=None):
+    """If *iterable* has only one item, return it.
+    If it has zero items, return *default*.
+    If it has more than one item, raise the exception given by *too_long*,
+    which is ``ValueError`` by default.
+
+        >>> only([], default='missing')
+        'missing'
+        >>> only([1])
+        1
+        >>> only([1, 2])  # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+        ...
+        ValueError: Expected exactly one item in iterable, but got 1, 2,
+        and perhaps more.
+        >>> only([1, 2], too_long=TypeError)  # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+        ...
+        TypeError
+
+    Note that :func:`only` attempts to advance *iterable* twice to ensure there
+    is only one item. See :func:`spy` or :func:`peekable` to check
+    iterable contents less destructively.
+    """
+    it = iter(iterable)
+    first_value = next(it, default)
+
+    try:
+        second_value = next(it)
+    except StopIteration:
+        pass
+    else:
+        msg = (
+            'Expected exactly one item in iterable, but got {!r}, {!r}, '
+            'and perhaps more.'.format(first_value, second_value)
+        )
+        raise too_long or ValueError(msg)
+
+    return first_value
+
+
+def ichunked(iterable, n):
+    """Break *iterable* into sub-iterables with *n* elements each.
+    :func:`ichunked` is like :func:`chunked`, but it yields iterables
+    instead of lists.
+
+    If the sub-iterables are read in order, the elements of *iterable*
+    won't be stored in memory.
+    If they are read out of order, :func:`itertools.tee` is used to cache
+    elements as necessary.
+
+        >>> from itertools import count
+        >>> all_chunks = ichunked(count(), 4)
+        >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
+        >>> list(c_2)  # c_1's elements have been cached; c_3's haven't been
+        [4, 5, 6, 7]
+        >>> list(c_1)
+        [0, 1, 2, 3]
+        >>> list(c_3)
+        [8, 9, 10, 11]
+
+    """
+    source = iter(iterable)
+
+    while True:
+        # Check to see whether we're at the end of the source iterable
+        item = next(source, _marker)
+        if item is _marker:
+            return
+
+        # Clone the source and yield an n-length slice
+        source, it = tee(chain([item], source))
+        yield islice(it, n)
+
+        # Advance the source iterable
+        consume(source, n)
+
+
+def distinct_combinations(iterable, r):
+    """Yield the distinct combinations of *r* items taken from *iterable*.
+
+        >>> list(distinct_combinations([0, 0, 1], 2))
+        [(0, 0), (0, 1)]
+
+    Equivalent to ``set(combinations(iterable, r))``, except duplicates are
+    not generated and thrown away. For larger input sequences this is much
+    more efficient.
+
+    """
+    if r < 0:
+        raise ValueError('r must be non-negative')
+    elif r == 0:
+        yield ()
+        return
+    pool = tuple(iterable)
+    generators = [unique_everseen(enumerate(pool), key=itemgetter(1))]
+    current_combo = [None] * r
+    level = 0
+    while generators:
+        try:
+            cur_idx, p = next(generators[-1])
+        except StopIteration:
+            generators.pop()
+            level -= 1
+            continue
+        current_combo[level] = p
+        if level + 1 == r:
+            yield tuple(current_combo)
+        else:
+            generators.append(
+                unique_everseen(
+                    enumerate(pool[cur_idx + 1 :], cur_idx + 1),
+                    key=itemgetter(1),
+                )
+            )
+            level += 1
+
+
+def filter_except(validator, iterable, *exceptions):
+    """Yield the items from *iterable* for which the *validator* function does
+    not raise one of the specified *exceptions*.
+
+    *validator* is called for each item in *iterable*.
+    It should be a function that accepts one argument and raises an exception
+    if that item is not valid.
+
+        >>> iterable = ['1', '2', 'three', '4', None]
+        >>> list(filter_except(int, iterable, ValueError, TypeError))
+        ['1', '2', '4']
+
+    If an exception other than one given by *exceptions* is raised by
+    *validator*, it is raised like normal.
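+
+    For example, ``TypeError`` is not named below, so ``int(None)`` fails
+    loudly rather than being filtered out:
+
+        >>> list(filter_except(int, ['1', None], ValueError))  # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+        ...
+        TypeError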
+    """
+    for item in iterable:
+        try:
+            validator(item)
+        except exceptions:
+            pass
+        else:
+            yield item
+
+
+def map_except(function, iterable, *exceptions):
+    """Transform each item from *iterable* with *function* and yield the
+    result, unless *function* raises one of the specified *exceptions*.
+
+    *function* is called to transform each item in *iterable*.
+    It should accept one argument.
+
+        >>> iterable = ['1', '2', 'three', '4', None]
+        >>> list(map_except(int, iterable, ValueError, TypeError))
+        [1, 2, 4]
+
+    If an exception other than one given by *exceptions* is raised by
+    *function*, it is raised like normal.
+    """
+    for item in iterable:
+        try:
+            yield function(item)
+        except exceptions:
+            pass
+
+
+def _sample_unweighted(iterable, k):
+    # Implementation of "Algorithm L" from the 1994 paper by Kim-Hung Li:
+    # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".
+
+    # Fill up the reservoir (collection of samples) with the first `k` samples
+    reservoir = take(k, iterable)
+
+    # Generate random number that's the largest in a sample of k U(0,1) numbers
+    # Largest order statistic: https://en.wikipedia.org/wiki/Order_statistic
+    W = exp(log(random()) / k)
+
+    # The number of elements to skip before changing the reservoir is a random
+    # number with a geometric distribution. Sample it using random() and logs.
+    next_index = k + floor(log(random()) / log(1 - W))
+
+    for index, element in enumerate(iterable, k):
+
+        if index == next_index:
+            reservoir[randrange(k)] = element
+            # The new W is the largest in a sample of k U(0, `old_W`) numbers
+            W *= exp(log(random()) / k)
+            next_index += floor(log(random()) / log(1 - W)) + 1
+
+    return reservoir
+
+
+def _sample_weighted(iterable, k, weights):
+    # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al.:
+    # "Weighted random sampling with a reservoir".
+
+    # Log-transform for numerical stability for weights that are small/large
+    weight_keys = (log(random()) / weight for weight in weights)
+
+    # Fill up the reservoir (collection of samples) with the first `k`
+    # weight-keys and elements, then heapify the list.
+    reservoir = take(k, zip(weight_keys, iterable))
+    heapify(reservoir)
+
+    # The number of jumps before changing the reservoir is a random variable
+    # with an exponential distribution. Sample it using random() and logs.
+    smallest_weight_key, _ = reservoir[0]
+    weights_to_skip = log(random()) / smallest_weight_key
+
+    for weight, element in zip(weights, iterable):
+        if weight >= weights_to_skip:
+            # The notation here is consistent with the paper, but we store
+            # the weight-keys in log-space for better numerical stability.
+            smallest_weight_key, _ = reservoir[0]
+            t_w = exp(weight * smallest_weight_key)
+            r_2 = uniform(t_w, 1)  # generate U(t_w, 1)
+            weight_key = log(r_2) / weight
+            heapreplace(reservoir, (weight_key, element))
+            smallest_weight_key, _ = reservoir[0]
+            weights_to_skip = log(random()) / smallest_weight_key
+        else:
+            weights_to_skip -= weight
+
+    # Equivalent to [element for weight_key, element in sorted(reservoir)]
+    return [heappop(reservoir)[1] for _ in range(k)]
+
+
+def sample(iterable, k, weights=None):
+    """Return a *k*-length list of elements chosen (without replacement)
+    from the *iterable*. Like :func:`random.sample`, but works on iterables
+    of unknown length.
+
+        >>> iterable = range(100)
+        >>> sample(iterable, 5)  # doctest: +SKIP
+        [81, 60, 96, 16, 4]
+
+    An iterable with *weights* may also be given:
+
+        >>> iterable = range(100)
+        >>> weights = (i * i + 1 for i in range(100))
+        >>> sample(iterable, 5, weights=weights)  # doctest: +SKIP
+        [79, 67, 74, 66, 78]
+
+    The algorithm can also be used to generate weighted random permutations.
+    The relative weight of each item determines the probability that it
+    appears late in the permutation.
+
+        >>> data = "abcdefgh"
+        >>> weights = range(1, len(data) + 1)
+        >>> sample(data, k=len(data), weights=weights)  # doctest: +SKIP
+        ['c', 'a', 'b', 'e', 'g', 'd', 'h', 'f']
+    """
+    if k == 0:
+        return []
+
+    iterable = iter(iterable)
+    if weights is None:
+        return _sample_unweighted(iterable, k)
+    else:
+        weights = iter(weights)
+        return _sample_weighted(iterable, k, weights)
+
+
+def is_sorted(iterable, key=None, reverse=False):
+    """Returns ``True`` if the items of iterable are in sorted order, and
+    ``False`` otherwise. *key* and *reverse* have the same meaning that they do
+    in the built-in :func:`sorted` function.
+
+        >>> is_sorted(['1', '2', '3', '4', '5'], key=int)
+        True
+        >>> is_sorted([5, 4, 3, 1, 2], reverse=True)
+        False
+
+    The function returns ``False`` after encountering the first out-of-order
+    item. If there are no out-of-order items, the iterable is exhausted.
+    """
+
+    compare = lt if reverse else gt
+    it = iterable if (key is None) else map(key, iterable)
+    return not any(starmap(compare, pairwise(it)))
+
+
+class AbortThread(BaseException):
+    pass
+
+
+class callback_iter:
+    """Convert a function that uses callbacks to an iterator.
+
+    Let *func* be a function that takes a ``callback`` keyword argument.
+    For example:
+
+        >>> def func(callback=None):
+        ...     for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]:
+        ...         if callback:
+        ...             callback(i, c)
+        ...     return 4
+
+
+    Use ``with callback_iter(func)`` to get an iterator over the parameters
+    that are delivered to the callback.
+
+        >>> with callback_iter(func) as it:
+        ...     for args, kwargs in it:
+        ...         print(args)
+        (1, 'a')
+        (2, 'b')
+        (3, 'c')
+
+    The function will be called in a background thread. The ``done`` property
+    indicates whether it has completed execution.
+
+        >>> it.done
+        True
+
+    If it completes successfully, its return value will be available
+    in the ``result`` property.
+
+        >>> it.result
+        4
+
+    Notes:
+
+    * If the function uses some keyword argument besides ``callback``, supply
+      *callback_kwd*.
+    * If it finished executing but raised an exception, accessing the
+      ``result`` property will raise the same exception.
+    * If it hasn't finished executing, accessing the ``result``
+      property from within the ``with`` block will raise ``RuntimeError``.
+    * If it hasn't finished executing, accessing the ``result`` property from
+      outside the ``with`` block will raise a
+      ``more_itertools.AbortThread`` exception.
+    * Provide *wait_seconds* to adjust how frequently it is polled for
+      output.
+
+    """
+
+    def __init__(self, func, callback_kwd='callback', wait_seconds=0.1):
+        self._func = func
+        self._callback_kwd = callback_kwd
+        self._aborted = False
+        self._future = None
+        self._wait_seconds = wait_seconds
+        self._executor = ThreadPoolExecutor(max_workers=1)
+        self._iterator = self._reader()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self._aborted = True
+        self._executor.shutdown()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return next(self._iterator)
+
+    @property
+    def done(self):
+        if self._future is None:
+            return False
+        return self._future.done()
+
+    @property
+    def result(self):
+        if not self.done:
+            raise RuntimeError('Function has not yet completed')
+
+        return self._future.result()
+
+    def _reader(self):
+        q = Queue()
+
+        def callback(*args, **kwargs):
+            if self._aborted:
+                raise AbortThread('canceled by user')
+
+            q.put((args, kwargs))
+
+        self._future = self._executor.submit(
+            self._func, **{self._callback_kwd: callback}
+        )
+
+        while True:
+            try:
+                item = q.get(timeout=self._wait_seconds)
+            except Empty:
+                pass
+            else:
+                q.task_done()
+                yield item
+
+            if self._future.done():
+                break
+
+        remaining = []
+        while True:
+            try:
+                item = q.get_nowait()
+            except Empty:
+                break
+            else:
+                q.task_done()
+                remaining.append(item)
+        q.join()
+        yield from remaining
+
+
+def windowed_complete(iterable, n):
+    """
+    Yield ``(beginning, middle, end)`` tuples, where:
+
+    * Each ``middle`` has *n* items from *iterable*
+    * Each ``beginning`` has the items before the ones in ``middle``
+    * Each ``end`` has the items after the ones in ``middle``
+
+        >>> iterable = range(7)
+        >>> n = 3
+        >>> for beginning, middle, end in windowed_complete(iterable, n):
+        ...     print(beginning, middle, end)
+        () (0, 1, 2) (3, 4, 5, 6)
+        (0,) (1, 2, 3) (4, 5, 6)
+        (0, 1) (2, 3, 4) (5, 6)
+        (0, 1, 2) (3, 4, 5) (6,)
+        (0, 1, 2, 3) (4, 5, 6) ()
+
+    Note that *n* must be at least 0 and at most equal to the length of
+    *iterable*.
+
+    This function will exhaust the iterable and may require significant
+    storage.
+    """
+    if n < 0:
+        raise ValueError('n must be >= 0')
+
+    seq = tuple(iterable)
+    size = len(seq)
+
+    if n > size:
+        raise ValueError('n must be <= len(seq)')
+
+    for i in range(size - n + 1):
+        beginning = seq[:i]
+        middle = seq[i : i + n]
+        end = seq[i + n :]
+        yield beginning, middle, end
+
+
+def all_unique(iterable, key=None):
+    """
+    Returns ``True`` if all the elements of *iterable* are unique (no two
+    elements are equal).
+
+        >>> all_unique('ABCB')
+        False
+
+    If a *key* function is specified, it will be used to make comparisons.
+
+        >>> all_unique('ABCb')
+        True
+        >>> all_unique('ABCb', str.lower)
+        False
+
+    The function returns as soon as the first non-unique element is
+    encountered. Iterables with a mix of hashable and unhashable items can
+    be used, but the function will be slower for unhashable items.
+    """
+    seenset = set()
+    seenset_add = seenset.add
+    seenlist = []
+    seenlist_add = seenlist.append
+    for element in map(key, iterable) if key else iterable:
+        try:
+            if element in seenset:
+                return False
+            seenset_add(element)
+        except TypeError:
+            if element in seenlist:
+                return False
+            seenlist_add(element)
+    return True
+
+
+def nth_product(index, *args):
+    """Equivalent to ``list(product(*args))[index]``.
+
+    The products of *args* can be ordered lexicographically.
+    :func:`nth_product` computes the product at sort position *index* without
+    computing the previous products.
+
+        >>> nth_product(8, range(2), range(2), range(2), range(2))
+        (1, 0, 0, 0)
+
+    ``IndexError`` will be raised if the given *index* is invalid.
+    """
+    pools = list(map(tuple, reversed(args)))
+    ns = list(map(len, pools))
+
+    c = reduce(mul, ns)
+
+    if index < 0:
+        index += c
+
+    if not 0 <= index < c:
+        raise IndexError
+
+    result = []
+    for pool, n in zip(pools, ns):
+        result.append(pool[index % n])
+        index //= n
+
+    return tuple(reversed(result))
+
+
+def nth_permutation(iterable, r, index):
+    """Equivalent to ``list(permutations(iterable, r))[index]``
+
+    The subsequences of *iterable* that are of length *r* where order is
+    important can be ordered lexicographically. :func:`nth_permutation`
+    computes the subsequence at sort position *index* directly, without
+    computing the previous subsequences.
+
+        >>> nth_permutation('ghijk', 2, 5)
+        ('h', 'i')
+
+    ``ValueError`` will be raised if *r* is negative or greater than the length
+    of *iterable*.
+    ``IndexError`` will be raised if the given *index* is invalid.
+    """
+    pool = list(iterable)
+    n = len(pool)
+
+    if r is None or r == n:
+        r, c = n, factorial(n)
+    elif not 0 <= r < n:
+        raise ValueError
+    else:
+        c = factorial(n) // factorial(n - r)
+
+    if index < 0:
+        index += c
+
+    if not 0 <= index < c:
+        raise IndexError
+
+    if c == 0:
+        return tuple()
+
+    result = [0] * r
+    q = index * factorial(n) // c if r < n else index
+    for d in range(1, n + 1):
+        q, i = divmod(q, d)
+        if 0 <= n - d < r:
+            result[n - d] = i
+        if q == 0:
+            break
+
+    return tuple(map(pool.pop, result))
+
+
+def value_chain(*args):
+    """Yield all arguments passed to the function in the same order in which
+    they were passed. If an argument itself is iterable then iterate over its
+    values.
+
+        >>> list(value_chain(1, 2, 3, [4, 5, 6]))
+        [1, 2, 3, 4, 5, 6]
+
+    Binary and text strings are not considered iterable and are emitted
+    as-is:
+
+        >>> list(value_chain('12', '34', ['56', '78']))
+        ['12', '34', '56', '78']
+
+
+    Multiple levels of nesting are not flattened.
+
+    """
+    for value in args:
+        if isinstance(value, (str, bytes)):
+            yield value
+            continue
+        try:
+            yield from value
+        except TypeError:
+            yield value
+
+
+def product_index(element, *args):
+    """Equivalent to ``list(product(*args)).index(element)``
+
+    The products of *args* can be ordered lexicographically.
+    :func:`product_index` computes the first index of *element* without
+    computing the previous products.
+
+        >>> product_index([8, 2], range(10), range(5))
+        42
+
+    ``ValueError`` will be raised if the given *element* isn't in the product
+    of *args*.
+    """
+    index = 0
+
+    for x, pool in zip_longest(element, args, fillvalue=_marker):
+        if x is _marker or pool is _marker:
+            raise ValueError('element is not a product of args')
+
+        pool = tuple(pool)
+        index = index * len(pool) + pool.index(x)
+
+    return index
+
+
+def combination_index(element, iterable):
+    """Equivalent to ``list(combinations(iterable, r)).index(element)``
+
+    The subsequences of *iterable* that are of length *r* can be ordered
+    lexicographically. :func:`combination_index` computes the index of the
+    first *element*, without computing the previous combinations.
+
+        >>> combination_index('adf', 'abcdefg')
+        10
+
+    ``ValueError`` will be raised if the given *element* isn't one of the
+    combinations of *iterable*.
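+
+    For example, 'x' below never occurs in the iterable, so no index can
+    be computed:
+
+        >>> combination_index('ax', 'abcdefg')  # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+        ...
+        ValueError: element is not a combination of iterable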
+    """
+    element = enumerate(element)
+    k, y = next(element, (None, None))
+    if k is None:
+        return 0
+
+    indexes = []
+    pool = enumerate(iterable)
+    for n, x in pool:
+        if x == y:
+            indexes.append(n)
+            tmp, y = next(element, (None, None))
+            if tmp is None:
+                break
+            else:
+                k = tmp
+    else:
+        raise ValueError('element is not a combination of iterable')
+
+    n, _ = last(pool, default=(n, None))
+
+    # Python versions below 3.8 don't have math.comb
+    index = 1
+    for i, j in enumerate(reversed(indexes), start=1):
+        j = n - j
+        if i <= j:
+            index += factorial(j) // (factorial(i) * factorial(j - i))
+
+    return factorial(n + 1) // (factorial(k + 1) * factorial(n - k)) - index
+
+
+def permutation_index(element, iterable):
+    """Equivalent to ``list(permutations(iterable, r)).index(element)``
+
+    The subsequences of *iterable* that are of length *r* where order is
+    important can be ordered lexicographically. :func:`permutation_index`
+    computes the index of the first *element* directly, without computing
+    the previous permutations.
+
+        >>> permutation_index([1, 3, 2], range(5))
+        19
+
+    ``ValueError`` will be raised if the given *element* isn't one of the
+    permutations of *iterable*.
+    """
+    index = 0
+    pool = list(iterable)
+    for i, x in zip(range(len(pool), -1, -1), element):
+        r = pool.index(x)
+        index = index * i + r
+        del pool[r]
+
+    return index
+
+
+class countable:
+    """Wrap *iterable* and keep a count of how many items have been consumed.
+
+    The ``items_seen`` attribute starts at ``0`` and increments as the iterable
+    is consumed:
+
+        >>> iterable = map(str, range(10))
+        >>> it = countable(iterable)
+        >>> it.items_seen
+        0
+        >>> next(it), next(it)
+        ('0', '1')
+        >>> list(it)
+        ['2', '3', '4', '5', '6', '7', '8', '9']
+        >>> it.items_seen
+        10
+    """
+
+    def __init__(self, iterable):
+        self._it = iter(iterable)
+        self.items_seen = 0
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        item = next(self._it)
+        self.items_seen += 1
+
+        return item
diff --git a/setuptools/_vendor/more_itertools/more.pyi b/setuptools/_vendor/more_itertools/more.pyi
new file mode 100644
index 00000000..2fba9cb3
--- /dev/null
+++ b/setuptools/_vendor/more_itertools/more.pyi
@@ -0,0 +1,480 @@
+"""Stubs for more_itertools.more"""
+
+from typing import (
+    Any,
+    Callable,
+    Container,
+    Dict,
+    Generic,
+    Hashable,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Reversible,
+    Sequence,
+    Sized,
+    Tuple,
+    Union,
+    TypeVar,
+    type_check_only,
+)
+from types import TracebackType
+from typing_extensions import ContextManager, Protocol, Type, overload
+
+# Type and type variable definitions
+_T = TypeVar('_T')
+_U = TypeVar('_U')
+_V = TypeVar('_V')
+_W = TypeVar('_W')
+_T_co = TypeVar('_T_co', covariant=True)
+_GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[object]])
+_Raisable = Union[BaseException, 'Type[BaseException]']
+
+@type_check_only
+class _SizedIterable(Protocol[_T_co], Sized, Iterable[_T_co]): ...
+
+@type_check_only
+class _SizedReversible(Protocol[_T_co], Sized, Reversible[_T_co]): ...
+
+def chunked(
+    iterable: Iterable[_T], n: int, strict: bool = ...
+) -> Iterator[List[_T]]: ...
+@overload
+def first(iterable: Iterable[_T]) -> _T: ...
+@overload
+def first(iterable: Iterable[_T], default: _U) -> Union[_T, _U]: ...
+@overload
+def last(iterable: Iterable[_T]) -> _T: ...
+@overload
+def last(iterable: Iterable[_T], default: _U) -> Union[_T, _U]: ...
+@overload
+def nth_or_last(iterable: Iterable[_T], n: int) -> _T: ...
+@overload +def nth_or_last( + iterable: Iterable[_T], n: int, default: _U +) -> Union[_T, _U]: ... + +class peekable(Generic[_T], Iterator[_T]): + def __init__(self, iterable: Iterable[_T]) -> None: ... + def __iter__(self) -> peekable[_T]: ... + def __bool__(self) -> bool: ... + @overload + def peek(self) -> _T: ... + @overload + def peek(self, default: _U) -> Union[_T, _U]: ... + def prepend(self, *items: _T) -> None: ... + def __next__(self) -> _T: ... + @overload + def __getitem__(self, index: int) -> _T: ... + @overload + def __getitem__(self, index: slice) -> List[_T]: ... + +def collate(*iterables: Iterable[_T], **kwargs: Any) -> Iterable[_T]: ... +def consumer(func: _GenFn) -> _GenFn: ... +def ilen(iterable: Iterable[object]) -> int: ... +def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ... +def with_iter( + context_manager: ContextManager[Iterable[_T]], +) -> Iterator[_T]: ... +def one( + iterable: Iterable[_T], + too_short: Optional[_Raisable] = ..., + too_long: Optional[_Raisable] = ..., +) -> _T: ... +def distinct_permutations( + iterable: Iterable[_T], r: Optional[int] = ... +) -> Iterator[Tuple[_T, ...]]: ... +def intersperse( + e: _U, iterable: Iterable[_T], n: int = ... +) -> Iterator[Union[_T, _U]]: ... +def unique_to_each(*iterables: Iterable[_T]) -> List[List[_T]]: ... +@overload +def windowed( + seq: Iterable[_T], n: int, *, step: int = ... +) -> Iterator[Tuple[Optional[_T], ...]]: ... +@overload +def windowed( + seq: Iterable[_T], n: int, fillvalue: _U, step: int = ... +) -> Iterator[Tuple[Union[_T, _U], ...]]: ... +def substrings(iterable: Iterable[_T]) -> Iterator[Tuple[_T, ...]]: ... +def substrings_indexes( + seq: Sequence[_T], reverse: bool = ... +) -> Iterator[Tuple[Sequence[_T], int, int]]: ... + +class bucket(Generic[_T, _U], Container[_U]): + def __init__( + self, + iterable: Iterable[_T], + key: Callable[[_T], _U], + validator: Optional[Callable[[object], object]] = ..., + ) -> None: ... + def __contains__(self, value: object) -> bool: ... + def __iter__(self) -> Iterator[_U]: ... + def __getitem__(self, value: object) -> Iterator[_T]: ... + +def spy( + iterable: Iterable[_T], n: int = ... +) -> Tuple[List[_T], Iterator[_T]]: ... +def interleave(*iterables: Iterable[_T]) -> Iterator[_T]: ... +def interleave_longest(*iterables: Iterable[_T]) -> Iterator[_T]: ... +def collapse( + iterable: Iterable[Any], + base_type: Optional[type] = ..., + levels: Optional[int] = ..., +) -> Iterator[Any]: ... +@overload +def side_effect( + func: Callable[[_T], object], + iterable: Iterable[_T], + chunk_size: None = ..., + before: Optional[Callable[[], object]] = ..., + after: Optional[Callable[[], object]] = ..., +) -> Iterator[_T]: ... +@overload +def side_effect( + func: Callable[[List[_T]], object], + iterable: Iterable[_T], + chunk_size: int, + before: Optional[Callable[[], object]] = ..., + after: Optional[Callable[[], object]] = ..., +) -> Iterator[_T]: ... +def sliced( + seq: Sequence[_T], n: int, strict: bool = ... +) -> Iterator[Sequence[_T]]: ... +def split_at( + iterable: Iterable[_T], + pred: Callable[[_T], object], + maxsplit: int = ..., + keep_separator: bool = ..., +) -> Iterator[List[_T]]: ... +def split_before( + iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ... +) -> Iterator[List[_T]]: ... +def split_after( + iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ... +) -> Iterator[List[_T]]: ... 
+def split_when( + iterable: Iterable[_T], + pred: Callable[[_T, _T], object], + maxsplit: int = ..., +) -> Iterator[List[_T]]: ... +def split_into( + iterable: Iterable[_T], sizes: Iterable[Optional[int]] +) -> Iterator[List[_T]]: ... +@overload +def padded( + iterable: Iterable[_T], + *, + n: Optional[int] = ..., + next_multiple: bool = ... +) -> Iterator[Optional[_T]]: ... +@overload +def padded( + iterable: Iterable[_T], + fillvalue: _U, + n: Optional[int] = ..., + next_multiple: bool = ..., +) -> Iterator[Union[_T, _U]]: ... +@overload +def repeat_last(iterable: Iterable[_T]) -> Iterator[_T]: ... +@overload +def repeat_last( + iterable: Iterable[_T], default: _U +) -> Iterator[Union[_T, _U]]: ... +def distribute(n: int, iterable: Iterable[_T]) -> List[Iterator[_T]]: ... +@overload +def stagger( + iterable: Iterable[_T], + offsets: _SizedIterable[int] = ..., + longest: bool = ..., +) -> Iterator[Tuple[Optional[_T], ...]]: ... +@overload +def stagger( + iterable: Iterable[_T], + offsets: _SizedIterable[int] = ..., + longest: bool = ..., + fillvalue: _U = ..., +) -> Iterator[Tuple[Union[_T, _U], ...]]: ... + +class UnequalIterablesError(ValueError): + def __init__( + self, details: Optional[Tuple[int, int, int]] = ... + ) -> None: ... + +def zip_equal(*iterables: Iterable[_T]) -> Iterator[Tuple[_T, ...]]: ... +@overload +def zip_offset( + *iterables: Iterable[_T], offsets: _SizedIterable[int], longest: bool = ... +) -> Iterator[Tuple[Optional[_T], ...]]: ... +@overload +def zip_offset( + *iterables: Iterable[_T], + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: _U +) -> Iterator[Tuple[Union[_T, _U], ...]]: ... +def sort_together( + iterables: Iterable[Iterable[_T]], + key_list: Iterable[int] = ..., + key: Optional[Callable[..., Any]] = ..., + reverse: bool = ..., +) -> List[Tuple[_T, ...]]: ... +def unzip(iterable: Iterable[Sequence[_T]]) -> Tuple[Iterator[_T], ...]: ... +def divide(n: int, iterable: Iterable[_T]) -> List[Iterator[_T]]: ... +def always_iterable( + obj: object, + base_type: Union[ + type, Tuple[Union[type, Tuple[Any, ...]], ...], None + ] = ..., +) -> Iterator[Any]: ... +def adjacent( + predicate: Callable[[_T], bool], + iterable: Iterable[_T], + distance: int = ..., +) -> Iterator[Tuple[bool, _T]]: ... +def groupby_transform( + iterable: Iterable[_T], + keyfunc: Optional[Callable[[_T], _U]] = ..., + valuefunc: Optional[Callable[[_T], _V]] = ..., + reducefunc: Optional[Callable[..., _W]] = ..., +) -> Iterator[Tuple[_T, _W]]: ... + +class numeric_range(Generic[_T, _U], Sequence[_T], Hashable, Reversible[_T]): + @overload + def __init__(self, __stop: _T) -> None: ... + @overload + def __init__(self, __start: _T, __stop: _T) -> None: ... + @overload + def __init__(self, __start: _T, __stop: _T, __step: _U) -> None: ... + def __bool__(self) -> bool: ... + def __contains__(self, elem: object) -> bool: ... + def __eq__(self, other: object) -> bool: ... + @overload + def __getitem__(self, key: int) -> _T: ... + @overload + def __getitem__(self, key: slice) -> numeric_range[_T, _U]: ... + def __hash__(self) -> int: ... + def __iter__(self) -> Iterator[_T]: ... + def __len__(self) -> int: ... + def __reduce__( + self, + ) -> Tuple[Type[numeric_range[_T, _U]], Tuple[_T, _T, _U]]: ... + def __repr__(self) -> str: ... + def __reversed__(self) -> Iterator[_T]: ... + def count(self, value: _T) -> int: ... + def index(self, value: _T) -> int: ... # type: ignore + +def count_cycle( + iterable: Iterable[_T], n: Optional[int] = ... +) -> Iterable[Tuple[int, _T]]: ... 
+def mark_ends( + iterable: Iterable[_T], +) -> Iterable[Tuple[bool, bool, _T]]: ... +def locate( + iterable: Iterable[object], + pred: Callable[..., Any] = ..., + window_size: Optional[int] = ..., +) -> Iterator[int]: ... +def lstrip( + iterable: Iterable[_T], pred: Callable[[_T], object] +) -> Iterator[_T]: ... +def rstrip( + iterable: Iterable[_T], pred: Callable[[_T], object] +) -> Iterator[_T]: ... +def strip( + iterable: Iterable[_T], pred: Callable[[_T], object] +) -> Iterator[_T]: ... + +class islice_extended(Generic[_T], Iterator[_T]): + def __init__( + self, iterable: Iterable[_T], *args: Optional[int] + ) -> None: ... + def __iter__(self) -> islice_extended[_T]: ... + def __next__(self) -> _T: ... + def __getitem__(self, index: slice) -> islice_extended[_T]: ... + +def always_reversible(iterable: Iterable[_T]) -> Iterator[_T]: ... +def consecutive_groups( + iterable: Iterable[_T], ordering: Callable[[_T], int] = ... +) -> Iterator[Iterator[_T]]: ... +@overload +def difference( + iterable: Iterable[_T], + func: Callable[[_T, _T], _U] = ..., + *, + initial: None = ... +) -> Iterator[Union[_T, _U]]: ... +@overload +def difference( + iterable: Iterable[_T], func: Callable[[_T, _T], _U] = ..., *, initial: _U +) -> Iterator[_U]: ... + +class SequenceView(Generic[_T], Sequence[_T]): + def __init__(self, target: Sequence[_T]) -> None: ... + @overload + def __getitem__(self, index: int) -> _T: ... + @overload + def __getitem__(self, index: slice) -> Sequence[_T]: ... + def __len__(self) -> int: ... + +class seekable(Generic[_T], Iterator[_T]): + def __init__( + self, iterable: Iterable[_T], maxlen: Optional[int] = ... + ) -> None: ... + def __iter__(self) -> seekable[_T]: ... + def __next__(self) -> _T: ... + def __bool__(self) -> bool: ... + @overload + def peek(self) -> _T: ... + @overload + def peek(self, default: _U) -> Union[_T, _U]: ... + def elements(self) -> SequenceView[_T]: ... + def seek(self, index: int) -> None: ... + +class run_length: + @staticmethod + def encode(iterable: Iterable[_T]) -> Iterator[Tuple[_T, int]]: ... + @staticmethod + def decode(iterable: Iterable[Tuple[_T, int]]) -> Iterator[_T]: ... + +def exactly_n( + iterable: Iterable[_T], n: int, predicate: Callable[[_T], object] = ... +) -> bool: ... +def circular_shifts(iterable: Iterable[_T]) -> List[Tuple[_T, ...]]: ... +def make_decorator( + wrapping_func: Callable[..., _U], result_index: int = ... +) -> Callable[..., Callable[[Callable[..., Any]], Callable[..., _U]]]: ... +@overload +def map_reduce( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: None = ..., + reducefunc: None = ..., +) -> Dict[_U, List[_T]]: ... +@overload +def map_reduce( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: Callable[[_T], _V], + reducefunc: None = ..., +) -> Dict[_U, List[_V]]: ... +@overload +def map_reduce( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: None = ..., + reducefunc: Callable[[List[_T]], _W] = ..., +) -> Dict[_U, _W]: ... +@overload +def map_reduce( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: Callable[[_T], _V], + reducefunc: Callable[[List[_V]], _W], +) -> Dict[_U, _W]: ... +def rlocate( + iterable: Iterable[_T], + pred: Callable[..., object] = ..., + window_size: Optional[int] = ..., +) -> Iterator[int]: ... +def replace( + iterable: Iterable[_T], + pred: Callable[..., object], + substitutes: Iterable[_U], + count: Optional[int] = ..., + window_size: int = ..., +) -> Iterator[Union[_T, _U]]: ... 
+def partitions(iterable: Iterable[_T]) -> Iterator[List[List[_T]]]: ...
+def set_partitions(
+    iterable: Iterable[_T], k: Optional[int] = ...
+) -> Iterator[List[List[_T]]]: ...
+
+class time_limited(Generic[_T], Iterator[_T]):
+    def __init__(
+        self, limit_seconds: float, iterable: Iterable[_T]
+    ) -> None: ...
+    def __iter__(self) -> time_limited[_T]: ...
+    def __next__(self) -> _T: ...
+
+@overload
+def only(
+    iterable: Iterable[_T], *, too_long: Optional[_Raisable] = ...
+) -> Optional[_T]: ...
+@overload
+def only(
+    iterable: Iterable[_T], default: _U, too_long: Optional[_Raisable] = ...
+) -> Union[_T, _U]: ...
+def ichunked(iterable: Iterable[_T], n: int) -> Iterator[Iterator[_T]]: ...
+def distinct_combinations(
+    iterable: Iterable[_T], r: int
+) -> Iterator[Tuple[_T, ...]]: ...
+def filter_except(
+    validator: Callable[[Any], object],
+    iterable: Iterable[_T],
+    *exceptions: Type[BaseException]
+) -> Iterator[_T]: ...
+def map_except(
+    function: Callable[[Any], _U],
+    iterable: Iterable[_T],
+    *exceptions: Type[BaseException]
+) -> Iterator[_U]: ...
+def sample(
+    iterable: Iterable[_T],
+    k: int,
+    weights: Optional[Iterable[float]] = ...,
+) -> List[_T]: ...
+def is_sorted(
+    iterable: Iterable[_T],
+    key: Optional[Callable[[_T], _U]] = ...,
+    reverse: bool = False,
+) -> bool: ...
+
+class AbortThread(BaseException):
+    pass
+
+class callback_iter(Generic[_T], Iterator[_T]):
+    def __init__(
+        self,
+        func: Callable[..., Any],
+        callback_kwd: str = ...,
+        wait_seconds: float = ...,
+    ) -> None: ...
+    def __enter__(self) -> callback_iter[_T]: ...
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> Optional[bool]: ...
+    def __iter__(self) -> callback_iter[_T]: ...
+    def __next__(self) -> _T: ...
+    def _reader(self) -> Iterator[_T]: ...
+    @property
+    def done(self) -> bool: ...
+    @property
+    def result(self) -> Any: ...
+
+def windowed_complete(
+    iterable: Iterable[_T], n: int
+) -> Iterator[Tuple[_T, ...]]: ...
+def all_unique(
+    iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = ...
+) -> bool: ...
+def nth_product(index: int, *args: Iterable[_T]) -> Tuple[_T, ...]: ...
+def nth_permutation(
+    iterable: Iterable[_T], r: int, index: int
+) -> Tuple[_T, ...]: ...
+def value_chain(*args: Union[_T, Iterable[_T]]) -> Iterable[_T]: ...
+def product_index(element: Iterable[_T], *args: Iterable[_T]) -> int: ...
+def combination_index(
+    element: Iterable[_T], iterable: Iterable[_T]
+) -> int: ...
+def permutation_index(
+    element: Iterable[_T], iterable: Iterable[_T]
+) -> int: ...
+
+class countable(Generic[_T], Iterator[_T]):
+    def __init__(self, iterable: Iterable[_T]) -> None: ...
+    def __iter__(self) -> countable[_T]: ...
+    def __next__(self) -> _T: ...
diff --git a/setuptools/_vendor/more_itertools/py.typed b/setuptools/_vendor/more_itertools/py.typed
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/setuptools/_vendor/more_itertools/py.typed
diff --git a/setuptools/_vendor/more_itertools/recipes.py b/setuptools/_vendor/more_itertools/recipes.py
new file mode 100644
index 00000000..521abd7c
--- /dev/null
+++ b/setuptools/_vendor/more_itertools/recipes.py
@@ -0,0 +1,620 @@
+"""Imported from the recipes section of the itertools documentation.
+
+All functions taken from the recipes section of the itertools library docs
+[1]_.
+Some backward-compatible usability improvements have been made.
+
+.. [1] http://docs.python.org/library/itertools.html#recipes
+
+"""
+import warnings
+from collections import deque
+from itertools import (
+    chain,
+    combinations,
+    count,
+    cycle,
+    groupby,
+    islice,
+    repeat,
+    starmap,
+    tee,
+    zip_longest,
+)
+import operator
+from random import randrange, sample, choice
+
+__all__ = [
+    'all_equal',
+    'consume',
+    'convolve',
+    'dotproduct',
+    'first_true',
+    'flatten',
+    'grouper',
+    'iter_except',
+    'ncycles',
+    'nth',
+    'nth_combination',
+    'padnone',
+    'pad_none',
+    'pairwise',
+    'partition',
+    'powerset',
+    'prepend',
+    'quantify',
+    'random_combination_with_replacement',
+    'random_combination',
+    'random_permutation',
+    'random_product',
+    'repeatfunc',
+    'roundrobin',
+    'tabulate',
+    'tail',
+    'take',
+    'unique_everseen',
+    'unique_justseen',
+]
+
+
+def take(n, iterable):
+    """Return first *n* items of the iterable as a list.
+
+        >>> take(3, range(10))
+        [0, 1, 2]
+
+    If there are fewer than *n* items in the iterable, all of them are
+    returned.
+
+        >>> take(10, range(3))
+        [0, 1, 2]
+
+    """
+    return list(islice(iterable, n))
+
+
+def tabulate(function, start=0):
+    """Return an iterator over the results of ``function(start)``,
+    ``function(start + 1)``, ``function(start + 2)``...
+
+    *function* should be a function that accepts one integer argument.
+
+    If *start* is not specified, it defaults to 0. It will be incremented each
+    time the iterator is advanced.
+
+        >>> square = lambda x: x ** 2
+        >>> iterator = tabulate(square, -3)
+        >>> take(4, iterator)
+        [9, 4, 1, 0]
+
+    """
+    return map(function, count(start))
+
+
+def tail(n, iterable):
+    """Return an iterator over the last *n* items of *iterable*.
+
+        >>> t = tail(3, 'ABCDEFG')
+        >>> list(t)
+        ['E', 'F', 'G']
+
+    """
+    return iter(deque(iterable, maxlen=n))
+
+
+def consume(iterator, n=None):
+    """Advance *iterator* by *n* steps. If *n* is ``None``, consume it
+    entirely.
+
+    Efficiently exhausts an iterator without returning values. Defaults to
+    consuming the whole iterator, but an optional second argument may be
+    provided to limit consumption.
+
+        >>> i = (x for x in range(10))
+        >>> next(i)
+        0
+        >>> consume(i, 3)
+        >>> next(i)
+        4
+        >>> consume(i)
+        >>> next(i)
+        Traceback (most recent call last):
+          File "<stdin>", line 1, in <module>
+        StopIteration
+
+    If the iterator has fewer items remaining than the provided limit, the
+    whole iterator will be consumed.
+
+        >>> i = (x for x in range(3))
+        >>> consume(i, 5)
+        >>> next(i)
+        Traceback (most recent call last):
+          File "<stdin>", line 1, in <module>
+        StopIteration
+
+    """
+    # Use functions that consume iterators at C speed.
+    if n is None:
+        # feed the entire iterator into a zero-length deque
+        deque(iterator, maxlen=0)
+    else:
+        # advance to the empty slice starting at position n
+        next(islice(iterator, n, n), None)
+
+
+def nth(iterable, n, default=None):
+    """Returns the nth item or a default value.
+
+        >>> l = range(10)
+        >>> nth(l, 3)
+        3
+        >>> nth(l, 20, "zebra")
+        'zebra'
+
+    """
+    return next(islice(iterable, n, None), default)
+
+
+def all_equal(iterable):
+    """
+    Returns ``True`` if all the elements are equal to each other.
+
+        >>> all_equal('aaaa')
+        True
+        >>> all_equal('aaab')
+        False
+
+    """
+    g = groupby(iterable)
+    return next(g, True) and not next(g, False)
+
+
+def quantify(iterable, pred=bool):
+    """Return how many times the predicate is true.
+ + >>> quantify([True, False, True]) + 2 + + """ + return sum(map(pred, iterable)) + + +def pad_none(iterable): + """Returns the sequence of elements and then returns ``None`` indefinitely. + + >>> take(5, pad_none(range(3))) + [0, 1, 2, None, None] + + Useful for emulating the behavior of the built-in :func:`map` function. + + See also :func:`padded`. + + """ + return chain(iterable, repeat(None)) + + +padnone = pad_none + + +def ncycles(iterable, n): + """Returns the sequence elements *n* times + + >>> list(ncycles(["a", "b"], 3)) + ['a', 'b', 'a', 'b', 'a', 'b'] + + """ + return chain.from_iterable(repeat(tuple(iterable), n)) + + +def dotproduct(vec1, vec2): + """Returns the dot product of the two iterables. + + >>> dotproduct([10, 10], [20, 20]) + 400 + + """ + return sum(map(operator.mul, vec1, vec2)) + + +def flatten(listOfLists): + """Return an iterator flattening one level of nesting in a list of lists. + + >>> list(flatten([[0, 1], [2, 3]])) + [0, 1, 2, 3] + + See also :func:`collapse`, which can flatten multiple levels of nesting. + + """ + return chain.from_iterable(listOfLists) + + +def repeatfunc(func, times=None, *args): + """Call *func* with *args* repeatedly, returning an iterable over the + results. + + If *times* is specified, the iterable will terminate after that many + repetitions: + + >>> from operator import add + >>> times = 4 + >>> args = 3, 5 + >>> list(repeatfunc(add, times, *args)) + [8, 8, 8, 8] + + If *times* is ``None`` the iterable will not terminate: + + >>> from random import randrange + >>> times = None + >>> args = 1, 11 + >>> take(6, repeatfunc(randrange, times, *args)) # doctest:+SKIP + [2, 4, 8, 1, 8, 4] + + """ + if times is None: + return starmap(func, repeat(args)) + return starmap(func, repeat(args, times)) + + +def _pairwise(iterable): + """Returns an iterator of paired items, overlapping, from the original + + >>> take(4, pairwise(count())) + [(0, 1), (1, 2), (2, 3), (3, 4)] + + On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`. + + """ + a, b = tee(iterable) + next(b, None) + yield from zip(a, b) + + +try: + from itertools import pairwise as itertools_pairwise +except ImportError: + pairwise = _pairwise +else: + + def pairwise(iterable): + yield from itertools_pairwise(iterable) + + pairwise.__doc__ = _pairwise.__doc__ + + +def grouper(iterable, n, fillvalue=None): + """Collect data into fixed-length chunks or blocks. + + >>> list(grouper('ABCDEFG', 3, 'x')) + [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')] + + """ + if isinstance(iterable, int): + warnings.warn( + "grouper expects iterable as first parameter", DeprecationWarning + ) + n, iterable = iterable, n + args = [iter(iterable)] * n + return zip_longest(fillvalue=fillvalue, *args) + + +def roundrobin(*iterables): + """Yields an item from each iterable, alternating between them. + + >>> list(roundrobin('ABC', 'D', 'EF')) + ['A', 'D', 'E', 'B', 'F', 'C'] + + This function produces the same output as :func:`interleave_longest`, but + may perform better for some inputs (in particular when the number of + iterables is small). + + """ + # Recipe credited to George Sakkis + pending = len(iterables) + nexts = cycle(iter(it).__next__ for it in iterables) + while pending: + try: + for next in nexts: + yield next() + except StopIteration: + pending -= 1 + nexts = cycle(islice(nexts, pending)) + + +def partition(pred, iterable): + """ + Returns a 2-tuple of iterables derived from the input iterable. + The first yields the items that have ``pred(item) == False``. 
+    The second yields the items that have ``pred(item) == True``.
+
+        >>> is_odd = lambda x: x % 2 != 0
+        >>> iterable = range(10)
+        >>> even_items, odd_items = partition(is_odd, iterable)
+        >>> list(even_items), list(odd_items)
+        ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])
+
+    If *pred* is None, :func:`bool` is used.
+
+        >>> iterable = [0, 1, False, True, '', ' ']
+        >>> false_items, true_items = partition(None, iterable)
+        >>> list(false_items), list(true_items)
+        ([0, False, ''], [1, True, ' '])
+
+    """
+    if pred is None:
+        pred = bool
+
+    evaluations = ((pred(x), x) for x in iterable)
+    t1, t2 = tee(evaluations)
+    return (
+        (x for (cond, x) in t1 if not cond),
+        (x for (cond, x) in t2 if cond),
+    )
+
+
+def powerset(iterable):
+    """Yields all possible subsets of the iterable.
+
+        >>> list(powerset([1, 2, 3]))
+        [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
+
+    :func:`powerset` will operate on iterables that aren't :class:`set`
+    instances, so repeated elements in the input will produce repeated elements
+    in the output. Use :func:`unique_everseen` on the input to avoid generating
+    duplicates:
+
+        >>> seq = [1, 1, 0]
+        >>> list(powerset(seq))
+        [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]
+        >>> from more_itertools import unique_everseen
+        >>> list(powerset(unique_everseen(seq)))
+        [(), (1,), (0,), (1, 0)]
+
+    """
+    s = list(iterable)
+    return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
+
+
+def unique_everseen(iterable, key=None):
+    """
+    Yield unique elements, preserving order.
+
+        >>> list(unique_everseen('AAAABBBCCDAABBB'))
+        ['A', 'B', 'C', 'D']
+        >>> list(unique_everseen('ABBCcAD', str.lower))
+        ['A', 'B', 'C', 'D']
+
+    Sequences with a mix of hashable and unhashable items can be used.
+    The function will be slower (i.e., ``O(n^2)``) for unhashable items.
+
+    Remember that ``list`` objects are unhashable - you can use the *key*
+    parameter to transform the list to a tuple (which is hashable) to
+    avoid a slowdown.
+
+        >>> iterable = ([1, 2], [2, 3], [1, 2])
+        >>> list(unique_everseen(iterable))  # Slow
+        [[1, 2], [2, 3]]
+        >>> list(unique_everseen(iterable, key=tuple))  # Faster
+        [[1, 2], [2, 3]]
+
+    Similarly, you may want to convert unhashable ``set`` objects with
+    ``key=frozenset``. For ``dict`` objects,
+    ``key=lambda x: frozenset(x.items())`` can be used.
+
+    """
+    seenset = set()
+    seenset_add = seenset.add
+    seenlist = []
+    seenlist_add = seenlist.append
+    use_key = key is not None
+
+    for element in iterable:
+        k = key(element) if use_key else element
+        try:
+            if k not in seenset:
+                seenset_add(k)
+                yield element
+        except TypeError:
+            if k not in seenlist:
+                seenlist_add(k)
+                yield element
+
+
+def unique_justseen(iterable, key=None):
+    """Yields elements in order, ignoring serial duplicates
+
+        >>> list(unique_justseen('AAAABBBCCDAABBB'))
+        ['A', 'B', 'C', 'D', 'A', 'B']
+        >>> list(unique_justseen('ABBCcAD', str.lower))
+        ['A', 'B', 'C', 'A', 'D']
+
+    """
+    return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
+
+
+def iter_except(func, exception, first=None):
+    """Yields results from a function repeatedly until an exception is raised.
+
+    Converts a call-until-exception interface to an iterator interface.
+    Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel
+    to end the loop.
+
+        >>> l = [0, 1, 2]
+        >>> list(iter_except(l.pop, IndexError))
+        [2, 1, 0]
+
+    """
+    try:
+        if first is not None:
+            yield first()
+        while 1:
+            yield func()
+    except exception:
+        pass
+
+
+def first_true(iterable, default=None, pred=None):
+    """
+    Returns the first true value in the iterable.
+
+    If no true value is found, returns *default*.
+
+    If *pred* is not None, returns the first item for which
+    ``pred(item) == True``.
+
+        >>> first_true(range(10))
+        1
+        >>> first_true(range(10), pred=lambda x: x > 5)
+        6
+        >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
+        'missing'
+
+    """
+    return next(filter(pred, iterable), default)
+
+
+def random_product(*args, repeat=1):
+    """Draw an item at random from each of the input iterables.
+
+        >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
+        ('c', 3, 'Z')
+
+    If *repeat* is provided as a keyword argument, that many items will be
+    drawn from each iterable.
+
+        >>> random_product('abcd', range(4), repeat=2)  # doctest:+SKIP
+        ('a', 2, 'd', 3)
+
+    This is equivalent to taking a random selection from
+    ``itertools.product(*args, repeat=repeat)``.
+
+    """
+    pools = [tuple(pool) for pool in args] * repeat
+    return tuple(choice(pool) for pool in pools)
+
+
+def random_permutation(iterable, r=None):
+    """Return a random *r* length permutation of the elements in *iterable*.
+
+    If *r* is not specified or is ``None``, then *r* defaults to the length of
+    *iterable*.
+
+        >>> random_permutation(range(5))  # doctest:+SKIP
+        (3, 4, 0, 1, 2)
+
+    This is equivalent to taking a random selection from
+    ``itertools.permutations(iterable, r)``.
+
+    """
+    pool = tuple(iterable)
+    r = len(pool) if r is None else r
+    return tuple(sample(pool, r))
+
+
+def random_combination(iterable, r):
+    """Return a random *r* length subsequence of the elements in *iterable*.
+
+        >>> random_combination(range(5), 3)  # doctest:+SKIP
+        (2, 3, 4)
+
+    This is equivalent to taking a random selection from
+    ``itertools.combinations(iterable, r)``.
+
+    """
+    pool = tuple(iterable)
+    n = len(pool)
+    indices = sorted(sample(range(n), r))
+    return tuple(pool[i] for i in indices)
+
+
+def random_combination_with_replacement(iterable, r):
+    """Return a random *r* length subsequence of elements in *iterable*,
+    allowing individual elements to be repeated.
+
+        >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP
+        (0, 0, 1, 2, 2)
+
+    This is equivalent to taking a random selection from
+    ``itertools.combinations_with_replacement(iterable, r)``.
+
+    """
+    pool = tuple(iterable)
+    n = len(pool)
+    indices = sorted(randrange(n) for i in range(r))
+    return tuple(pool[i] for i in indices)
+
+
+def nth_combination(iterable, r, index):
+    """Equivalent to ``list(combinations(iterable, r))[index]``.
+
+    The subsequences of *iterable* that are of length *r* can be ordered
+    lexicographically. :func:`nth_combination` computes the subsequence at
+    sort position *index* directly, without computing the previous
+    subsequences.
+
+        >>> nth_combination(range(5), 3, 5)
+        (0, 3, 4)
+
+    ``ValueError`` will be raised if *r* is negative or greater than the length
+    of *iterable*.
+    ``IndexError`` will be raised if the given *index* is invalid.
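+
+    For instance, ``range(5)`` has only ten 3-combinations, so position 10
+    is out of range:
+
+        >>> nth_combination(range(5), 3, 10)  # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+        ...
+        IndexError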
+ """ + pool = tuple(iterable) + n = len(pool) + if (r < 0) or (r > n): + raise ValueError + + c = 1 + k = min(r, n - r) + for i in range(1, k + 1): + c = c * (n - k + i) // i + + if index < 0: + index += c + + if (index < 0) or (index >= c): + raise IndexError + + result = [] + while r: + c, n, r = c * r // n, n - 1, r - 1 + while index >= c: + index -= c + c, n = c * (n - r) // n, n - 1 + result.append(pool[-1 - n]) + + return tuple(result) + + +def prepend(value, iterator): + """Yield *value*, followed by the elements in *iterator*. + + >>> value = '0' + >>> iterator = ['1', '2', '3'] + >>> list(prepend(value, iterator)) + ['0', '1', '2', '3'] + + To prepend multiple values, see :func:`itertools.chain` + or :func:`value_chain`. + + """ + return chain([value], iterator) + + +def convolve(signal, kernel): + """Convolve the iterable *signal* with the iterable *kernel*. + + >>> signal = (1, 2, 3, 4, 5) + >>> kernel = [3, 2, 1] + >>> list(convolve(signal, kernel)) + [3, 8, 14, 20, 26, 14, 5] + + Note: the input arguments are not interchangeable, as the *kernel* + is immediately consumed and stored. + + """ + kernel = tuple(kernel)[::-1] + n = len(kernel) + window = deque([0], maxlen=n) * n + for x in chain(signal, repeat(0, n - 1)): + window.append(x) + yield sum(map(operator.mul, kernel, window)) diff --git a/setuptools/_vendor/more_itertools/recipes.pyi b/setuptools/_vendor/more_itertools/recipes.pyi new file mode 100644 index 00000000..5e39d963 --- /dev/null +++ b/setuptools/_vendor/more_itertools/recipes.pyi @@ -0,0 +1,103 @@ +"""Stubs for more_itertools.recipes""" +from typing import ( + Any, + Callable, + Iterable, + Iterator, + List, + Optional, + Tuple, + TypeVar, + Union, +) +from typing_extensions import overload, Type + +# Type and type variable definitions +_T = TypeVar('_T') +_U = TypeVar('_U') + +def take(n: int, iterable: Iterable[_T]) -> List[_T]: ... +def tabulate( + function: Callable[[int], _T], start: int = ... +) -> Iterator[_T]: ... +def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]: ... +def consume(iterator: Iterable[object], n: Optional[int] = ...) -> None: ... +@overload +def nth(iterable: Iterable[_T], n: int) -> Optional[_T]: ... +@overload +def nth(iterable: Iterable[_T], n: int, default: _U) -> Union[_T, _U]: ... +def all_equal(iterable: Iterable[object]) -> bool: ... +def quantify( + iterable: Iterable[_T], pred: Callable[[_T], bool] = ... +) -> int: ... +def pad_none(iterable: Iterable[_T]) -> Iterator[Optional[_T]]: ... +def padnone(iterable: Iterable[_T]) -> Iterator[Optional[_T]]: ... +def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: ... +def dotproduct(vec1: Iterable[object], vec2: Iterable[object]) -> object: ... +def flatten(listOfLists: Iterable[Iterable[_T]]) -> Iterator[_T]: ... +def repeatfunc( + func: Callable[..., _U], times: Optional[int] = ..., *args: Any +) -> Iterator[_U]: ... +def pairwise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T]]: ... +@overload +def grouper( + iterable: Iterable[_T], n: int +) -> Iterator[Tuple[Optional[_T], ...]]: ... +@overload +def grouper( + iterable: Iterable[_T], n: int, fillvalue: _U +) -> Iterator[Tuple[Union[_T, _U], ...]]: ... +@overload +def grouper( # Deprecated interface + iterable: int, n: Iterable[_T] +) -> Iterator[Tuple[Optional[_T], ...]]: ... +@overload +def grouper( # Deprecated interface + iterable: int, n: Iterable[_T], fillvalue: _U +) -> Iterator[Tuple[Union[_T, _U], ...]]: ... +def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]: ... 
+def partition( + pred: Optional[Callable[[_T], object]], iterable: Iterable[_T] +) -> Tuple[Iterator[_T], Iterator[_T]]: ... +def powerset(iterable: Iterable[_T]) -> Iterator[Tuple[_T, ...]]: ... +def unique_everseen( + iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = ... +) -> Iterator[_T]: ... +def unique_justseen( + iterable: Iterable[_T], key: Optional[Callable[[_T], object]] = ... +) -> Iterator[_T]: ... +@overload +def iter_except( + func: Callable[[], _T], exception: Type[BaseException], first: None = ... +) -> Iterator[_T]: ... +@overload +def iter_except( + func: Callable[[], _T], + exception: Type[BaseException], + first: Callable[[], _U], +) -> Iterator[Union[_T, _U]]: ... +@overload +def first_true( + iterable: Iterable[_T], *, pred: Optional[Callable[[_T], object]] = ... +) -> Optional[_T]: ... +@overload +def first_true( + iterable: Iterable[_T], + default: _U, + pred: Optional[Callable[[_T], object]] = ..., +) -> Union[_T, _U]: ... +def random_product( + *args: Iterable[_T], repeat: int = ... +) -> Tuple[_T, ...]: ... +def random_permutation( + iterable: Iterable[_T], r: Optional[int] = ... +) -> Tuple[_T, ...]: ... +def random_combination(iterable: Iterable[_T], r: int) -> Tuple[_T, ...]: ... +def random_combination_with_replacement( + iterable: Iterable[_T], r: int +) -> Tuple[_T, ...]: ... +def nth_combination( + iterable: Iterable[_T], r: int, index: int +) -> Tuple[_T, ...]: ... +def prepend(value: _T, iterator: Iterable[_U]) -> Iterator[Union[_T, _U]]: ... +def convolve(signal: Iterable[_T], kernel: Iterable[_T]) -> Iterator[_T]: ... diff --git a/setuptools/_vendor/pyparsing.py b/setuptools/_vendor/pyparsing.py index cf75e1e5..4cae7883 100644 --- a/setuptools/_vendor/pyparsing.py +++ b/setuptools/_vendor/pyparsing.py @@ -1625,7 +1625,7 @@ class ParserElement(object): (see L{I{parseWithTabs}<parseWithTabs>})
- define your parse action using the full C{(s,loc,toks)} signature, and
reference the input string using the parse action's C{s} argument
- - explictly expand the tabs in your input string before calling
+ - explicitly expand the tabs in your input string before calling
C{parseString}
Example::
diff --git a/setuptools/_vendor/vendored.txt b/setuptools/_vendor/vendored.txt index b1190436..d9804741 100644 --- a/setuptools/_vendor/vendored.txt +++ b/setuptools/_vendor/vendored.txt @@ -1,3 +1,4 @@ packaging==20.4 pyparsing==2.2.1 ordered-set==3.1.1 +more_itertools==8.8.0 diff --git a/setuptools/archive_util.py b/setuptools/archive_util.py index 0ce190b8..0f702848 100644 --- a/setuptools/archive_util.py +++ b/setuptools/archive_util.py @@ -125,6 +125,56 @@ def unpack_zipfile(filename, extract_dir, progress_filter=default_filter): os.chmod(target, unix_attributes) +def _resolve_tar_file_or_dir(tar_obj, tar_member_obj): + """Resolve any links and extract link targets as normal files.""" + while tar_member_obj is not None and ( + tar_member_obj.islnk() or tar_member_obj.issym()): + linkpath = tar_member_obj.linkname + if tar_member_obj.issym(): + base = posixpath.dirname(tar_member_obj.name) + linkpath = posixpath.join(base, linkpath) + linkpath = posixpath.normpath(linkpath) + tar_member_obj = tar_obj._getmember(linkpath) + + is_file_or_dir = ( + tar_member_obj is not None and + (tar_member_obj.isfile() or tar_member_obj.isdir()) + ) + if is_file_or_dir: + return tar_member_obj + + raise LookupError('Got unknown file type') + + +def _iter_open_tar(tar_obj, extract_dir, progress_filter): + """Emit member-destination pairs from a tar archive.""" + # don't do any chowning! + tar_obj.chown = lambda *args: None + + with contextlib.closing(tar_obj): + for member in tar_obj: + name = member.name + # don't extract absolute paths or ones with .. in them + if name.startswith('/') or '..' in name.split('/'): + continue + + prelim_dst = os.path.join(extract_dir, *name.split('/')) + + try: + member = _resolve_tar_file_or_dir(tar_obj, member) + except LookupError: + continue + + final_dst = progress_filter(name, prelim_dst) + if not final_dst: + continue + + if final_dst.endswith(os.sep): + final_dst = final_dst[:-1] + + yield member, final_dst + + def unpack_tarfile(filename, extract_dir, progress_filter=default_filter): """Unpack tar/tar.gz/tar.bz2 `filename` to `extract_dir` @@ -138,38 +188,18 @@ def unpack_tarfile(filename, extract_dir, progress_filter=default_filter): raise UnrecognizedFormat( "%s is not a compressed or uncompressed tar file" % (filename,) ) from e - with contextlib.closing(tarobj): - # don't do any chowning! - tarobj.chown = lambda *args: None - for member in tarobj: - name = member.name - # don't extract absolute paths or ones with .. in them - if not name.startswith('/') and '..' 
not in name.split('/'): - prelim_dst = os.path.join(extract_dir, *name.split('/')) - - # resolve any links and to extract the link targets as normal - # files - while member is not None and ( - member.islnk() or member.issym()): - linkpath = member.linkname - if member.issym(): - base = posixpath.dirname(member.name) - linkpath = posixpath.join(base, linkpath) - linkpath = posixpath.normpath(linkpath) - member = tarobj._getmember(linkpath) - - if member is not None and (member.isfile() or member.isdir()): - final_dst = progress_filter(name, prelim_dst) - if final_dst: - if final_dst.endswith(os.sep): - final_dst = final_dst[:-1] - try: - # XXX Ugh - tarobj._extract_member(member, final_dst) - except tarfile.ExtractError: - # chown/chmod/mkfifo/mknode/makedev failed - pass - return True + + for member, final_dst in _iter_open_tar( + tarobj, extract_dir, progress_filter, + ): + try: + # XXX Ugh + tarobj._extract_member(member, final_dst) + except tarfile.ExtractError: + # chown/chmod/mkfifo/mknode/makedev failed + pass + + return True extraction_drivers = unpack_directory, unpack_zipfile, unpack_tarfile diff --git a/setuptools/build_meta.py b/setuptools/build_meta.py index b9e8a2b3..9dfb2f24 100644 --- a/setuptools/build_meta.py +++ b/setuptools/build_meta.py @@ -101,7 +101,12 @@ def _file_with_extension(directory, extension): f for f in os.listdir(directory) if f.endswith(extension) ) - file, = matching + try: + file, = matching + except ValueError: + raise ValueError( + 'No distribution was found. Ensure that `setup.py` ' + 'is not empty and that it calls `setup()`.') return file diff --git a/setuptools/cli-arm64.exe b/setuptools/cli-arm64.exe Binary files differnew file mode 100644 index 00000000..7a87ce48 --- /dev/null +++ b/setuptools/cli-arm64.exe diff --git a/setuptools/command/__init__.py b/setuptools/command/__init__.py index 743f5588..b966dcea 100644 --- a/setuptools/command/__init__.py +++ b/setuptools/command/__init__.py @@ -1,15 +1,6 @@ -__all__ = [ - 'alias', 'bdist_egg', 'bdist_rpm', 'build_ext', 'build_py', 'develop', - 'easy_install', 'egg_info', 'install', 'install_lib', 'rotate', 'saveopts', - 'sdist', 'setopt', 'test', 'install_egg_info', 'install_scripts', - 'bdist_wininst', 'upload_docs', 'build_clib', 'dist_info', -] - from distutils.command.bdist import bdist import sys -from setuptools.command import install_scripts - if 'egg' not in bdist.format_commands: bdist.format_command['egg'] = ('bdist_egg', "Python .egg file") bdist.format_commands.append('egg') diff --git a/setuptools/command/bdist_egg.py b/setuptools/command/bdist_egg.py index a88efb45..e6b1609f 100644 --- a/setuptools/command/bdist_egg.py +++ b/setuptools/command/bdist_egg.py @@ -2,7 +2,6 @@ Build .egg distributions""" -from distutils.errors import DistutilsSetupError from distutils.dir_util import remove_tree, mkpath from distutils import log from types import CodeType @@ -11,12 +10,10 @@ import os import re import textwrap import marshal -import warnings from pkg_resources import get_build_platform, Distribution, ensure_directory -from pkg_resources import EntryPoint from setuptools.extension import Library -from setuptools import Command, SetuptoolsDeprecationWarning +from setuptools import Command from sysconfig import get_path, get_python_version @@ -153,7 +150,7 @@ class bdist_egg(Command): self.run_command(cmdname) return cmd - def run(self): + def run(self): # noqa: C901 # is too complex (14) # FIXME # Generate metadata first self.run_command("egg_info") # We run install_lib before install_data, 
because some data hacks @@ -268,49 +265,7 @@ class bdist_egg(Command): return analyze_egg(self.bdist_dir, self.stubs) def gen_header(self): - epm = EntryPoint.parse_map(self.distribution.entry_points or '') - ep = epm.get('setuptools.installation', {}).get('eggsecutable') - if ep is None: - return 'w' # not an eggsecutable, do it the usual way. - - warnings.warn( - "Eggsecutables are deprecated and will be removed in a future " - "version.", - SetuptoolsDeprecationWarning - ) - - if not ep.attrs or ep.extras: - raise DistutilsSetupError( - "eggsecutable entry point (%r) cannot have 'extras' " - "or refer to a module" % (ep,) - ) - - pyver = '{}.{}'.format(*sys.version_info) - pkg = ep.module_name - full = '.'.join(ep.attrs) - base = ep.attrs[0] - basename = os.path.basename(self.egg_output) - - header = ( - "#!/bin/sh\n" - 'if [ `basename $0` = "%(basename)s" ]\n' - 'then exec python%(pyver)s -c "' - "import sys, os; sys.path.insert(0, os.path.abspath('$0')); " - "from %(pkg)s import %(base)s; sys.exit(%(full)s())" - '" "$@"\n' - 'else\n' - ' echo $0 is not the correct name for this egg file.\n' - ' echo Please rename it back to %(basename)s and try again.\n' - ' exec false\n' - 'fi\n' - ) % locals() - - if not self.dry_run: - mkpath(os.path.dirname(self.egg_output), dry_run=self.dry_run) - f = open(self.egg_output, 'w') - f.write(header) - f.close() - return 'a' + return 'w' def copy_metadata_to(self, target_dir): "Copy metadata (egg info) to the target_dir" diff --git a/setuptools/command/bdist_rpm.py b/setuptools/command/bdist_rpm.py index 0eb1b9c2..98bf5dea 100644 --- a/setuptools/command/bdist_rpm.py +++ b/setuptools/command/bdist_rpm.py @@ -1,4 +1,7 @@ import distutils.command.bdist_rpm as orig +import warnings + +from setuptools import SetuptoolsDeprecationWarning class bdist_rpm(orig.bdist_rpm): @@ -11,6 +14,12 @@ class bdist_rpm(orig.bdist_rpm): """ def run(self): + warnings.warn( + "bdist_rpm is deprecated and will be removed in a future " + "version. Use bdist_wheel (wheel packages) instead.", + SetuptoolsDeprecationWarning, + ) + # ensure distro name is up-to-date self.run_command('egg_info') diff --git a/setuptools/command/bdist_wininst.py b/setuptools/command/bdist_wininst.py deleted file mode 100644 index ff4b6345..00000000 --- a/setuptools/command/bdist_wininst.py +++ /dev/null @@ -1,30 +0,0 @@ -import distutils.command.bdist_wininst as orig -import warnings - -from setuptools import SetuptoolsDeprecationWarning - - -class bdist_wininst(orig.bdist_wininst): - def reinitialize_command(self, command, reinit_subcommands=0): - """ - Supplement reinitialize_command to work around - http://bugs.python.org/issue20819 - """ - cmd = self.distribution.reinitialize_command( - command, reinit_subcommands) - if command in ('install', 'install_lib'): - cmd.install_lib = None - return cmd - - def run(self): - warnings.warn( - "bdist_wininst is deprecated and will be removed in a future " - "version. 
Use bdist_wheel (wheel packages) instead.", - SetuptoolsDeprecationWarning - ) - - self._is_running = True - try: - orig.bdist_wininst.run(self) - finally: - self._is_running = False diff --git a/setuptools/command/build_ext.py b/setuptools/command/build_ext.py index 03a72b4f..c59eff8b 100644 --- a/setuptools/command/build_ext.py +++ b/setuptools/command/build_ext.py @@ -104,14 +104,20 @@ class build_ext(_build_ext): self.write_stub(package_dir or os.curdir, ext, True) def get_ext_filename(self, fullname): - filename = _build_ext.get_ext_filename(self, fullname) + so_ext = os.getenv('SETUPTOOLS_EXT_SUFFIX') + if so_ext: + filename = os.path.join(*fullname.split('.')) + so_ext + else: + filename = _build_ext.get_ext_filename(self, fullname) + so_ext = get_config_var('EXT_SUFFIX') + if fullname in self.ext_map: ext = self.ext_map[fullname] use_abi3 = getattr(ext, 'py_limited_api') and get_abi3_suffix() if use_abi3: - so_ext = get_config_var('EXT_SUFFIX') filename = filename[:-len(so_ext)] - filename = filename + get_abi3_suffix() + so_ext = get_abi3_suffix() + filename = filename + so_ext if isinstance(ext, Library): fn, ext = os.path.splitext(filename) return self.shlib_compiler.library_filename(fn, libtype) diff --git a/setuptools/command/build_py.py b/setuptools/command/build_py.py index b30aa129..6a615433 100644 --- a/setuptools/command/build_py.py +++ b/setuptools/command/build_py.py @@ -8,21 +8,14 @@ import io import distutils.errors import itertools import stat - -try: - from setuptools.lib2to3_ex import Mixin2to3 -except Exception: - - class Mixin2to3: - def run_2to3(self, files, doctests=True): - "do nothing" +from setuptools.extern.more_itertools import unique_everseen def make_writable(target): os.chmod(target, os.stat(target).st_mode | stat.S_IWRITE) -class build_py(orig.build_py, Mixin2to3): +class build_py(orig.build_py): """Enhanced 'build_py' command that includes data files with packages The data files are specified via a 'package_data' argument to 'setup()'. @@ -35,12 +28,10 @@ class build_py(orig.build_py, Mixin2to3): def finalize_options(self): orig.build_py.finalize_options(self) self.package_data = self.distribution.package_data - self.exclude_package_data = (self.distribution.exclude_package_data or - {}) + self.exclude_package_data = self.distribution.exclude_package_data or {} if 'data_files' in self.__dict__: del self.__dict__['data_files'] self.__updated_files = [] - self.__doctests_2to3 = [] def run(self): """Build modules, packages, and copy data files to build directory""" @@ -54,10 +45,6 @@ class build_py(orig.build_py, Mixin2to3): self.build_packages() self.build_package_data() - self.run_2to3(self.__updated_files, False) - self.run_2to3(self.__updated_files, True) - self.run_2to3(self.__doctests_2to3, True) - # Only compile actual .py files, using our base class' idea of what our # output files are. 
self.byte_compile(orig.build_py.get_outputs(self, include_bytecode=0)) @@ -70,8 +57,7 @@ class build_py(orig.build_py, Mixin2to3): return orig.build_py.__getattr__(self, attr) def build_module(self, module, module_file, package): - outfile, copied = orig.build_py.build_module(self, module, module_file, - package) + outfile, copied = orig.build_py.build_module(self, module, module_file, package) if copied: self.__updated_files.append(outfile) return outfile, copied @@ -122,9 +108,6 @@ class build_py(orig.build_py, Mixin2to3): outf, copied = self.copy_file(srcfile, target) make_writable(target) srcfile = os.path.abspath(srcfile) - if (copied and - srcfile in self.distribution.convert_2to3_doctests): - self.__doctests_2to3.append(outf) def analyze_manifest(self): self.manifest_files = mf = {} @@ -201,20 +184,13 @@ class build_py(orig.build_py, Mixin2to3): package, src_dir, ) - match_groups = ( - fnmatch.filter(files, pattern) - for pattern in patterns - ) + match_groups = (fnmatch.filter(files, pattern) for pattern in patterns) # flatten the groups of matches into an iterable of matches matches = itertools.chain.from_iterable(match_groups) bad = set(matches) - keepers = ( - fn - for fn in files - if fn not in bad - ) + keepers = (fn for fn in files if fn not in bad) # ditch dupes - return list(_unique_everseen(keepers)) + return list(unique_everseen(keepers)) @staticmethod def _get_platform_patterns(spec, package, src_dir): @@ -235,36 +211,22 @@ class build_py(orig.build_py, Mixin2to3): ) -# from Python docs -def _unique_everseen(iterable, key=None): - "List unique elements, preserving order. Remember all elements ever seen." - # unique_everseen('AAAABBBCCDAABBB') --> A B C D - # unique_everseen('ABBCcAD', str.lower) --> A B C D - seen = set() - seen_add = seen.add - if key is None: - for element in itertools.filterfalse(seen.__contains__, iterable): - seen_add(element) - yield element - else: - for element in iterable: - k = key(element) - if k not in seen: - seen_add(k) - yield element - - def assert_relative(path): if not os.path.isabs(path): return path from distutils.errors import DistutilsSetupError - msg = textwrap.dedent(""" + msg = ( + textwrap.dedent( + """ Error: setup script specifies an absolute path: %s setup() arguments must *always* be /-separated paths relative to the setup.py directory, *never* absolute paths. 
- """).lstrip() % path + """ + ).lstrip() + % path + ) raise DistutilsSetupError(msg) diff --git a/setuptools/command/develop.py b/setuptools/command/develop.py index faf8c988..24fb0a7c 100644 --- a/setuptools/command/develop.py +++ b/setuptools/command/develop.py @@ -63,7 +63,8 @@ class develop(namespaces.DevelopInstaller, easy_install): target = pkg_resources.normalize_path(self.egg_base) egg_path = pkg_resources.normalize_path( - os.path.join(self.install_dir, self.egg_path)) + os.path.join(self.install_dir, self.egg_path) + ) if egg_path != target: raise DistutilsOptionError( "--egg-path must be a relative path from the install" @@ -74,7 +75,7 @@ class develop(namespaces.DevelopInstaller, easy_install): self.dist = pkg_resources.Distribution( target, pkg_resources.PathMetadata(target, os.path.abspath(ei.egg_info)), - project_name=ei.egg_name + project_name=ei.egg_name, ) self.setup_path = self._resolve_setup_path( @@ -99,41 +100,18 @@ class develop(namespaces.DevelopInstaller, easy_install): if resolved != pkg_resources.normalize_path(os.curdir): raise DistutilsOptionError( "Can't get a consistent path to setup script from" - " installation directory", resolved, - pkg_resources.normalize_path(os.curdir)) + " installation directory", + resolved, + pkg_resources.normalize_path(os.curdir), + ) return path_to_setup def install_for_development(self): - if getattr(self.distribution, 'use_2to3', False): - # If we run 2to3 we can not do this inplace: - - # Ensure metadata is up-to-date - self.reinitialize_command('build_py', inplace=0) - self.run_command('build_py') - bpy_cmd = self.get_finalized_command("build_py") - build_path = pkg_resources.normalize_path(bpy_cmd.build_lib) - - # Build extensions - self.reinitialize_command('egg_info', egg_base=build_path) - self.run_command('egg_info') - - self.reinitialize_command('build_ext', inplace=0) - self.run_command('build_ext') - - # Fixup egg-link and easy-install.pth - ei_cmd = self.get_finalized_command("egg_info") - self.egg_path = build_path - self.dist.location = build_path - # XXX - self.dist._provider = pkg_resources.PathMetadata( - build_path, ei_cmd.egg_info) - else: - # Without 2to3 inplace works fine: - self.run_command('egg_info') + self.run_command('egg_info') - # Build extensions in-place - self.reinitialize_command('build_ext', inplace=1) - self.run_command('build_ext') + # Build extensions in-place + self.reinitialize_command('build_ext', inplace=1) + self.run_command('build_ext') if setuptools.bootstrap_install_from: self.easy_install(setuptools.bootstrap_install_from) @@ -156,8 +134,7 @@ class develop(namespaces.DevelopInstaller, easy_install): egg_link_file = open(self.egg_link) contents = [line.rstrip() for line in egg_link_file] egg_link_file.close() - if contents not in ([self.egg_path], - [self.egg_path, self.setup_path]): + if contents not in ([self.egg_path], [self.egg_path, self.setup_path]): log.warn("Link points to %s: uninstall aborted", contents) return if not self.dry_run: diff --git a/setuptools/command/easy_install.py b/setuptools/command/easy_install.py index 9ec83b7d..b88c3e9a 100644 --- a/setuptools/command/easy_install.py +++ b/setuptools/command/easy_install.py @@ -6,7 +6,7 @@ A tool for doing automatic download/extract/build of distutils-based Python packages. For detailed documentation, see the accompanying EasyInstall.txt file, or visit the `EasyInstall home page`__. 
-__ https://setuptools.readthedocs.io/en/latest/easy_install.html +__ https://setuptools.readthedocs.io/en/latest/deprecated/easy_install.html """ @@ -67,7 +67,7 @@ warnings.filterwarnings("default", category=pkg_resources.PEP440Warning) __all__ = [ 'samefile', 'easy_install', 'PthDistributions', 'extract_wininst_cfg', - 'main', 'get_exe_prefixes', + 'get_exe_prefixes', ] @@ -226,7 +226,7 @@ class easy_install(Command): print(tmpl.format(**locals())) raise SystemExit() - def finalize_options(self): + def finalize_options(self): # noqa: C901 # is too complex (25) # FIXME self.version and self._render_version() py_version = sys.version.split()[0] @@ -437,7 +437,7 @@ class easy_install(Command): def warn_deprecated_options(self): pass - def check_site_dir(self): + def check_site_dir(self): # noqa: C901 # is too complex (12) # FIXME """Verify that self.install_dir is .pth-capable dir, if needed""" instdir = normalize_path(self.install_dir) @@ -513,7 +513,7 @@ class easy_install(Command): For information on other options, you may wish to consult the documentation at: - https://setuptools.readthedocs.io/en/latest/easy_install.html + https://setuptools.readthedocs.io/en/latest/deprecated/easy_install.html Please make the appropriate changes for your system and try again. """).lstrip() # noqa @@ -713,7 +713,10 @@ class easy_install(Command): if getattr(self, attrname) is None: setattr(self, attrname, scheme[key]) - def process_distribution(self, requirement, dist, deps=True, *info): + # FIXME: 'easy_install.process_distribution' is too complex (12) + def process_distribution( # noqa: C901 + self, requirement, dist, deps=True, *info, + ): self.update_pth(dist) self.package_index.add(dist) if dist in self.local_index[dist.key]: @@ -837,12 +840,19 @@ class easy_install(Command): def install_eggs(self, spec, dist_filename, tmpdir): # .egg dirs or files are already built, so just return them - if dist_filename.lower().endswith('.egg'): - return [self.install_egg(dist_filename, tmpdir)] - elif dist_filename.lower().endswith('.exe'): - return [self.install_exe(dist_filename, tmpdir)] - elif dist_filename.lower().endswith('.whl'): - return [self.install_wheel(dist_filename, tmpdir)] + installer_map = { + '.egg': self.install_egg, + '.exe': self.install_exe, + '.whl': self.install_wheel, + } + try: + install_dist = installer_map[ + dist_filename.lower()[-4:] + ] + except KeyError: + pass + else: + return [install_dist(dist_filename, tmpdir)] # Anything else, try to extract and build setup_base = tmpdir @@ -887,7 +897,8 @@ class easy_install(Command): metadata = EggMetadata(zipimport.zipimporter(egg_path)) return Distribution.from_filename(egg_path, metadata=metadata) - def install_egg(self, egg_path, tmpdir): + # FIXME: 'easy_install.install_egg' is too complex (11) + def install_egg(self, egg_path, tmpdir): # noqa: C901 destination = os.path.join( self.install_dir, os.path.basename(egg_path), @@ -986,7 +997,8 @@ class easy_install(Command): # install the .egg return self.install_egg(egg_path, tmpdir) - def exe_to_egg(self, dist_filename, egg_tmp): + # FIXME: 'easy_install.exe_to_egg' is too complex (12) + def exe_to_egg(self, dist_filename, egg_tmp): # noqa: C901 """Extract a bdist_wininst to the directories an egg would use""" # Check for .pth file and set up prefix translations prefixes = get_exe_prefixes(dist_filename) @@ -1178,22 +1190,24 @@ class easy_install(Command): for key, val in ei_opts.items(): if key not in fetch_directives: continue - fetch_options[key.replace('_', '-')] = val[1] + 
fetch_options[key] = val[1] # create a settings dictionary suitable for `edit_config` settings = dict(easy_install=fetch_options) cfg_filename = os.path.join(base, 'setup.cfg') setopt.edit_config(cfg_filename, settings) - def update_pth(self, dist): + def update_pth(self, dist): # noqa: C901 # is too complex (11) # FIXME if self.pth_file is None: return for d in self.pth_file[dist.key]: # drop old entries - if self.multi_version or d.location != dist.location: - log.info("Removing %s from easy-install.pth file", d) - self.pth_file.remove(d) - if d.location in self.shadow_path: - self.shadow_path.remove(d.location) + if not self.multi_version and d.location == dist.location: + continue + + log.info("Removing %s from easy-install.pth file", d) + self.pth_file.remove(d) + if d.location in self.shadow_path: + self.shadow_path.remove(d.location) if not self.multi_version: if dist.location in self.pth_file.paths: @@ -1207,19 +1221,21 @@ class easy_install(Command): if dist.location not in self.shadow_path: self.shadow_path.append(dist.location) - if not self.dry_run: + if self.dry_run: + return - self.pth_file.save() + self.pth_file.save() - if dist.key == 'setuptools': - # Ensure that setuptools itself never becomes unavailable! - # XXX should this check for latest version? - filename = os.path.join(self.install_dir, 'setuptools.pth') - if os.path.islink(filename): - os.unlink(filename) - f = open(filename, 'wt') - f.write(self.pth_file.make_relative(dist.location) + '\n') - f.close() + if dist.key != 'setuptools': + return + + # Ensure that setuptools itself never becomes unavailable! + # XXX should this check for latest version? + filename = os.path.join(self.install_dir, 'setuptools.pth') + if os.path.islink(filename): + os.unlink(filename) + with open(filename, 'wt') as f: + f.write(self.pth_file.make_relative(dist.location) + '\n') def unpack_progress(self, src, dst): # Progress filter for unpacking @@ -1290,7 +1306,7 @@ class easy_install(Command): * You can set up the installation directory to support ".pth" files by using one of the approaches described here: - https://setuptools.readthedocs.io/en/latest/easy_install.html#custom-installation-locations + https://setuptools.readthedocs.io/en/latest/deprecated/easy_install.html#custom-installation-locations Please make the appropriate changes for your system and try again. @@ -1360,58 +1376,63 @@ def get_site_dirs(): if sys.exec_prefix != sys.prefix: prefixes.append(sys.exec_prefix) for prefix in prefixes: - if prefix: - if sys.platform in ('os2emx', 'riscos'): - sitedirs.append(os.path.join(prefix, "Lib", "site-packages")) - elif os.sep == '/': - sitedirs.extend([ - os.path.join( - prefix, - "lib", - "python{}.{}".format(*sys.version_info), - "site-packages", - ), - os.path.join(prefix, "lib", "site-python"), - ]) - else: - sitedirs.extend([ + if not prefix: + continue + + if sys.platform in ('os2emx', 'riscos'): + sitedirs.append(os.path.join(prefix, "Lib", "site-packages")) + elif os.sep == '/': + sitedirs.extend([ + os.path.join( prefix, - os.path.join(prefix, "lib", "site-packages"), - ]) - if sys.platform == 'darwin': - # for framework builds *only* we add the standard Apple - # locations. 
Currently only per-user, but /Library and - # /Network/Library could be added too - if 'Python.framework' in prefix: - home = os.environ.get('HOME') - if home: - home_sp = os.path.join( - home, - 'Library', - 'Python', - '{}.{}'.format(*sys.version_info), - 'site-packages', - ) - sitedirs.append(home_sp) + "lib", + "python{}.{}".format(*sys.version_info), + "site-packages", + ), + os.path.join(prefix, "lib", "site-python"), + ]) + else: + sitedirs.extend([ + prefix, + os.path.join(prefix, "lib", "site-packages"), + ]) + if sys.platform != 'darwin': + continue + + # for framework builds *only* we add the standard Apple + # locations. Currently only per-user, but /Library and + # /Network/Library could be added too + if 'Python.framework' not in prefix: + continue + + home = os.environ.get('HOME') + if not home: + continue + + home_sp = os.path.join( + home, + 'Library', + 'Python', + '{}.{}'.format(*sys.version_info), + 'site-packages', + ) + sitedirs.append(home_sp) lib_paths = get_path('purelib'), get_path('platlib') - for site_lib in lib_paths: - if site_lib not in sitedirs: - sitedirs.append(site_lib) + + sitedirs.extend(s for s in lib_paths if s not in sitedirs) if site.ENABLE_USER_SITE: sitedirs.append(site.USER_SITE) - try: + with contextlib.suppress(AttributeError): sitedirs.extend(site.getsitepackages()) - except AttributeError: - pass sitedirs = list(map(normalize_path, sitedirs)) return sitedirs -def expand_paths(inputs): +def expand_paths(inputs): # noqa: C901 # is too complex (11) # FIXME """Yield sys.path directories that might contain "old-style" packages""" seen = {} @@ -1443,13 +1464,18 @@ def expand_paths(inputs): # Yield existing non-dupe, non-import directory lines from it for line in lines: - if not line.startswith("import"): - line = normalize_path(line.rstrip()) - if line not in seen: - seen[line] = 1 - if not os.path.isdir(line): - continue - yield line, os.listdir(line) + if line.startswith("import"): + continue + + line = normalize_path(line.rstrip()) + if line in seen: + continue + + seen[line] = 1 + if not os.path.isdir(line): + continue + + yield line, os.listdir(line) def extract_wininst_cfg(dist_filename): @@ -1482,7 +1508,7 @@ def extract_wininst_cfg(dist_filename): # Now the config is in bytes, but for RawConfigParser, it should # be text, so decode it. config = config.decode(sys.getfilesystemencoding()) - cfg.readfp(io.StringIO(config)) + cfg.read_file(io.StringIO(config)) except configparser.Error: return None if not cfg.has_section('metadata') or not cfg.has_section('Setup'): @@ -2167,7 +2193,7 @@ class WindowsScriptWriter(ScriptWriter): @classmethod def _adjust_header(cls, type_, orig_header): """ - Make sure 'pythonw' is used for gui and and 'python' is used for + Make sure 'pythonw' is used for gui and 'python' is used for console (regardless of what sys.executable is). 
""" pattern = 'pythonw.exe' @@ -2237,7 +2263,10 @@ def get_win_launcher(type): """ launcher_fn = '%s.exe' % type if is_64bit(): - launcher_fn = launcher_fn.replace(".", "-64.") + if get_platform() == "win-arm64": + launcher_fn = launcher_fn.replace(".", "-arm64.") + else: + launcher_fn = launcher_fn.replace(".", "-64.") else: launcher_fn = launcher_fn.replace(".", "-32.") return resource_string('setuptools', launcher_fn) @@ -2258,60 +2287,6 @@ def current_umask(): return tmp -def bootstrap(): - # This function is called when setuptools*.egg is run using /bin/sh - import setuptools - - argv0 = os.path.dirname(setuptools.__path__[0]) - sys.argv[0] = argv0 - sys.argv.append(argv0) - main() - - -def main(argv=None, **kw): - from setuptools import setup - from setuptools.dist import Distribution - - class DistributionWithoutHelpCommands(Distribution): - common_usage = "" - - def _show_help(self, *args, **kw): - with _patch_usage(): - Distribution._show_help(self, *args, **kw) - - if argv is None: - argv = sys.argv[1:] - - with _patch_usage(): - setup( - script_args=['-q', 'easy_install', '-v'] + argv, - script_name=sys.argv[0] or 'easy_install', - distclass=DistributionWithoutHelpCommands, - **kw - ) - - -@contextlib.contextmanager -def _patch_usage(): - import distutils.core - USAGE = textwrap.dedent(""" - usage: %(script)s [options] requirement_or_url ... - or: %(script)s --help - """).lstrip() - - def gen_usage(script_name): - return USAGE % dict( - script=os.path.basename(script_name), - ) - - saved = distutils.core.gen_usage - distutils.core.gen_usage = gen_usage - try: - yield - finally: - distutils.core.gen_usage = saved - - class EasyInstallDeprecationWarning(SetuptoolsDeprecationWarning): """ Warning for EasyInstall deprecations, bypassing suppression. diff --git a/setuptools/command/egg_info.py b/setuptools/command/egg_info.py index 97e10d99..57bc7982 100644 --- a/setuptools/command/egg_info.py +++ b/setuptools/command/egg_info.py @@ -8,6 +8,7 @@ from distutils.util import convert_path from distutils import log import distutils.errors import distutils.filelist +import functools import os import re import sys @@ -31,7 +32,7 @@ from setuptools.extern import packaging from setuptools import SetuptoolsDeprecationWarning -def translate_pattern(glob): +def translate_pattern(glob): # noqa: C901 # is too complex (14) # FIXME """ Translate a file path glob like '*.txt' in to a regular expression. This differs from fnmatch.translate which allows wildcards to match @@ -332,70 +333,74 @@ class FileList(_FileList): # patterns, (dir and patterns), or (dir_pattern). 
(action, patterns, dir, dir_pattern) = self._parse_template_line(line) + action_map = { + 'include': self.include, + 'exclude': self.exclude, + 'global-include': self.global_include, + 'global-exclude': self.global_exclude, + 'recursive-include': functools.partial( + self.recursive_include, dir, + ), + 'recursive-exclude': functools.partial( + self.recursive_exclude, dir, + ), + 'graft': self.graft, + 'prune': self.prune, + } + log_map = { + 'include': "warning: no files found matching '%s'", + 'exclude': ( + "warning: no previously-included files found " + "matching '%s'" + ), + 'global-include': ( + "warning: no files found matching '%s' " + "anywhere in distribution" + ), + 'global-exclude': ( + "warning: no previously-included files matching " + "'%s' found anywhere in distribution" + ), + 'recursive-include': ( + "warning: no files found matching '%s' " + "under directory '%s'" + ), + 'recursive-exclude': ( + "warning: no previously-included files matching " + "'%s' found under directory '%s'" + ), + 'graft': "warning: no directories found matching '%s'", + 'prune': "no previously-included directories found matching '%s'", + } + + try: + process_action = action_map[action] + except KeyError: + raise DistutilsInternalError( + "this cannot happen: invalid action '{action!s}'". + format(action=action), + ) + # OK, now we know that the action is valid and we have the # right number of words on the line for that action -- so we # can proceed with minimal error-checking. - if action == 'include': - self.debug_print("include " + ' '.join(patterns)) - for pattern in patterns: - if not self.include(pattern): - log.warn("warning: no files found matching '%s'", pattern) - - elif action == 'exclude': - self.debug_print("exclude " + ' '.join(patterns)) - for pattern in patterns: - if not self.exclude(pattern): - log.warn(("warning: no previously-included files " - "found matching '%s'"), pattern) - - elif action == 'global-include': - self.debug_print("global-include " + ' '.join(patterns)) - for pattern in patterns: - if not self.global_include(pattern): - log.warn(("warning: no files found matching '%s' " - "anywhere in distribution"), pattern) - - elif action == 'global-exclude': - self.debug_print("global-exclude " + ' '.join(patterns)) - for pattern in patterns: - if not self.global_exclude(pattern): - log.warn(("warning: no previously-included files matching " - "'%s' found anywhere in distribution"), - pattern) - - elif action == 'recursive-include': - self.debug_print("recursive-include %s %s" % - (dir, ' '.join(patterns))) - for pattern in patterns: - if not self.recursive_include(dir, pattern): - log.warn(("warning: no files found matching '%s' " - "under directory '%s'"), - pattern, dir) - - elif action == 'recursive-exclude': - self.debug_print("recursive-exclude %s %s" % - (dir, ' '.join(patterns))) - for pattern in patterns: - if not self.recursive_exclude(dir, pattern): - log.warn(("warning: no previously-included files matching " - "'%s' found under directory '%s'"), - pattern, dir) - - elif action == 'graft': - self.debug_print("graft " + dir_pattern) - if not self.graft(dir_pattern): - log.warn("warning: no directories found matching '%s'", - dir_pattern) - - elif action == 'prune': - self.debug_print("prune " + dir_pattern) - if not self.prune(dir_pattern): - log.warn(("no previously-included directories found " - "matching '%s'"), dir_pattern) - - else: - raise DistutilsInternalError( - "this cannot happen: invalid action '%s'" % action) + + action_is_recursive = 
action.startswith('recursive-') + if action in {'graft', 'prune'}: + patterns = [dir_pattern] + extra_log_args = (dir, ) if action_is_recursive else () + log_tmpl = log_map[action] + + self.debug_print( + ' '.join( + [action] + + ([dir] if action_is_recursive else []) + + patterns, + ) + ) + for pattern in patterns: + if not process_action(pattern): + log.warn(log_tmpl, pattern, *extra_log_args) def _remove_files(self, predicate): """ @@ -536,6 +541,7 @@ class manifest_maker(sdist): self.add_defaults() if os.path.exists(self.template): self.read_template() + self.add_license_files() self.prune_file_list() self.filelist.sort() self.filelist.remove_duplicates() @@ -570,7 +576,6 @@ class manifest_maker(sdist): def add_defaults(self): sdist.add_defaults(self) - self.check_license() self.filelist.append(self.template) self.filelist.append(self.manifest) rcfiles = list(walk_revctrl()) @@ -587,6 +592,13 @@ class manifest_maker(sdist): ei_cmd = self.get_finalized_command('egg_info') self.filelist.graft(ei_cmd.egg_info) + def add_license_files(self): + license_files = self.distribution.metadata.license_files or [] + for lf in license_files: + log.info("adding license file '%s'", lf) + pass + self.filelist.extend(license_files) + def prune_file_list(self): build = self.get_finalized_command('build') base_dir = self.distribution.get_fullname() diff --git a/setuptools/command/install_scripts.py b/setuptools/command/install_scripts.py index 8c9a15e2..9cd8eb06 100644 --- a/setuptools/command/install_scripts.py +++ b/setuptools/command/install_scripts.py @@ -1,5 +1,6 @@ from distutils import log import distutils.command.install_scripts as orig +from distutils.errors import DistutilsModuleError import os import sys @@ -35,7 +36,7 @@ class install_scripts(orig.install_scripts): try: bw_cmd = self.get_finalized_command("bdist_wininst") is_wininst = getattr(bw_cmd, '_is_running', False) - except ImportError: + except (ImportError, DistutilsModuleError): is_wininst = False writer = ei.ScriptWriter if is_wininst: diff --git a/setuptools/command/sdist.py b/setuptools/command/sdist.py index 887b7efa..e8062f2e 100644 --- a/setuptools/command/sdist.py +++ b/setuptools/command/sdist.py @@ -5,8 +5,6 @@ import sys import io import contextlib -from setuptools.extern import ordered_set - from .py36compat import sdist_add_defaults import pkg_resources @@ -33,6 +31,10 @@ class sdist(sdist_add_defaults, orig.sdist): ('dist-dir=', 'd', "directory to put the source distribution archive(s) in " "[default: dist]"), + ('owner=', 'u', + "Owner name used when creating a tar file [default: current user]"), + ('group=', 'g', + "Group name used when creating a tar file [default: current group]"), ] negative_opt = {} @@ -189,34 +191,3 @@ class sdist(sdist_add_defaults, orig.sdist): continue self.filelist.append(line) manifest.close() - - def check_license(self): - """Checks if license_file' or 'license_files' is configured and adds any - valid paths to 'self.filelist'. 
- """ - - files = ordered_set.OrderedSet() - - opts = self.distribution.get_option_dict('metadata') - - # ignore the source of the value - _, license_file = opts.get('license_file', (None, None)) - - if license_file is None: - log.debug("'license_file' option was not specified") - else: - files.add(license_file) - - try: - files.update(self.distribution.metadata.license_files) - except TypeError: - log.warn("warning: 'license_files' option is malformed") - - for f in files: - if not os.path.exists(f): - log.warn( - "warning: Failed to find the configured license file '%s'", - f) - files.remove(f) - - self.filelist.extend(files) diff --git a/setuptools/command/setopt.py b/setuptools/command/setopt.py index e18057c8..6358c045 100644 --- a/setuptools/command/setopt.py +++ b/setuptools/command/setopt.py @@ -39,6 +39,7 @@ def edit_config(filename, settings, dry_run=False): """ log.debug("Reading configuration from %s", filename) opts = configparser.RawConfigParser() + opts.optionxform = lambda x: x opts.read([filename]) for section, options in settings.items(): if options is None: diff --git a/setuptools/command/test.py b/setuptools/command/test.py index cf71ad01..4a389e4d 100644 --- a/setuptools/command/test.py +++ b/setuptools/command/test.py @@ -8,15 +8,21 @@ from distutils.errors import DistutilsError, DistutilsOptionError from distutils import log from unittest import TestLoader -from pkg_resources import (resource_listdir, resource_exists, normalize_path, - working_set, _namespace_packages, evaluate_marker, - add_activation_listener, require, EntryPoint) +from pkg_resources import ( + resource_listdir, + resource_exists, + normalize_path, + working_set, + evaluate_marker, + add_activation_listener, + require, + EntryPoint, +) from setuptools import Command -from .build_py import _unique_everseen +from setuptools.extern.more_itertools import unique_everseen class ScanningLoader(TestLoader): - def __init__(self): TestLoader.__init__(self) self._visited = set() @@ -73,8 +79,11 @@ class test(Command): user_options = [ ('test-module=', 'm', "Run 'test_suite' in specified module"), - ('test-suite=', 's', - "Run single test, case or suite (e.g. 'module.test_suite')"), + ( + 'test-suite=', + 's', + "Run single test, case or suite (e.g. 
'module.test_suite')",
+ ),
('test-runner=', 'r', "Test runner to use"),
]
@@ -124,30 +133,11 @@ class test(Command):
@contextlib.contextmanager
def project_on_sys_path(self, include_dists=[]):
- with_2to3 = getattr(self.distribution, 'use_2to3', False)
-
- if with_2to3:
- # If we run 2to3 we can not do this inplace:
+ self.run_command('egg_info')
- # Ensure metadata is up-to-date
- self.reinitialize_command('build_py', inplace=0)
- self.run_command('build_py')
- bpy_cmd = self.get_finalized_command("build_py")
- build_path = normalize_path(bpy_cmd.build_lib)
-
- # Build extensions
- self.reinitialize_command('egg_info', egg_base=build_path)
- self.run_command('egg_info')
-
- self.reinitialize_command('build_ext', inplace=0)
- self.run_command('build_ext')
- else:
- # Without 2to3 inplace works fine:
- self.run_command('egg_info')
-
- # Build extensions in-place
- self.reinitialize_command('build_ext', inplace=1)
- self.run_command('build_ext')
+ # Build extensions in-place
+ self.reinitialize_command('build_ext', inplace=1)
+ self.run_command('build_ext')
ei_cmd = self.get_finalized_command("egg_info")
@@ -182,7 +172,7 @@ class test(Command):
orig_pythonpath = os.environ.get('PYTHONPATH', nothing)
current_pythonpath = os.environ.get('PYTHONPATH', '')
try:
- prefix = os.pathsep.join(_unique_everseen(paths))
+ prefix = os.pathsep.join(unique_everseen(paths))
to_join = filter(None, [prefix, current_pythonpath])
new_path = os.pathsep.join(to_join)
if new_path:
@@ -203,7 +193,8 @@ class test(Command):
ir_d = dist.fetch_build_eggs(dist.install_requires)
tr_d = dist.fetch_build_eggs(dist.tests_require or [])
er_d = dist.fetch_build_eggs(
- v for k, v in dist.extras_require.items()
+ v
+ for k, v in dist.extras_require.items()
if k.startswith(':') and evaluate_marker(k[1:])
)
return itertools.chain(ir_d, tr_d, er_d)
@@ -232,23 +223,10 @@ class test(Command):
self.run_tests()
def run_tests(self):
- # Purge modules under test from sys.modules. The test loader will
- # re-import them from the build location. Required when 2to3 is used
- # with namespace packages.
- if getattr(self.distribution, 'use_2to3', False):
- module = self.test_suite.split('.')[0]
- if module in _namespace_packages:
- del_modules = []
- if module in sys.modules:
- del_modules.append(module)
- module += '.'
- for name in sys.modules:
- if name.startswith(module):
- del_modules.append(name)
- list(map(sys.modules.__delitem__, del_modules))
-
test = unittest.main(
- None, None, self._argv,
+ None,
+ None,
+ self._argv,
testLoader=self._resolve_as_ep(self.test_loader),
testRunner=self._resolve_as_ep(self.test_runner),
exit=False,
diff --git a/setuptools/command/upload_docs.py b/setuptools/command/upload_docs.py
index 2559458a..845bff44 100644
--- a/setuptools/command/upload_docs.py
+++ b/setuptools/command/upload_docs.py
@@ -2,7 +2,7 @@
"""upload_docs
Implements a Distutils 'upload_docs' subcommand (upload documentation to
-PyPI's pythonhosted.org).
+sites other than PyPI such as devpi).
"""
from base64 import standard_b64encode
@@ -31,7 +31,7 @@ class upload_docs(upload):
# supported by Warehouse (and won't be).
DEFAULT_REPOSITORY = 'https://pypi.python.org/pypi/'
- description = 'Upload documentation to PyPI'
+ description = 'Upload documentation to sites other than PyPI such as devpi'
user_options = [
('repository=', 'r',
@@ -59,7 +59,7 @@ class upload_docs(upload):
if self.upload_dir is None:
if self.has_sphinx():
build_sphinx = self.get_finalized_command('build_sphinx')
- self.target_dir = build_sphinx.builder_target_dir
+ self.target_dir = dict(build_sphinx.builder_target_dirs)['html']
else:
build = self.get_finalized_command('build')
self.target_dir = os.path.join(build.build_base, 'docs')
@@ -67,7 +67,7 @@ class upload_docs(upload):
self.ensure_dirname('upload_dir')
self.target_dir = self.upload_dir
if 'pypi.python.org' in self.repository:
- log.warn("Upload_docs command is deprecated. Use RTD instead.")
+ log.warn("Upload_docs command is deprecated for PyPI. Use RTD instead.")
self.announce('Using upload directory %s' % self.target_dir)
def create_zipfile(self, filename):
diff --git a/setuptools/config.py b/setuptools/config.py
index af3a3bcb..e3e44c25 100644
--- a/setuptools/config.py
+++ b/setuptools/config.py
@@ -9,6 +9,7 @@ import importlib
from collections import defaultdict
from functools import partial
from functools import wraps
+from glob import iglob
import contextlib
from distutils.errors import DistutilsOptionError, DistutilsFileError
@@ -20,6 +21,7 @@ class StaticModule:
"""
Attempt to load the module by the name
"""
+
def __init__(self, name):
spec = importlib.util.find_spec(name)
with open(spec.origin) as strm:
@@ -55,8 +57,7 @@ def patch_path(path):
sys.path.remove(path)
-def read_configuration(
- filepath, find_others=False, ignore_option_errors=False):
+def read_configuration(filepath, find_others=False, ignore_option_errors=False):
"""Read given configuration file and returns options from it as a dict.
:param str|unicode filepath: Path to configuration file
@@ -77,8 +78,7 @@ def read_configuration(
filepath = os.path.abspath(filepath)
if not os.path.isfile(filepath):
- raise DistutilsFileError(
- 'Configuration file %s does not exist.' % filepath)
+ raise DistutilsFileError('Configuration file %s does not exist.' % filepath)
current_directory = os.getcwd()
os.chdir(os.path.dirname(filepath))
@@ -93,8 +93,8 @@ def read_configuration(
_Distribution.parse_config_files(dist, filenames=filenames)
handlers = parse_configuration(
- dist, dist.command_options,
- ignore_option_errors=ignore_option_errors)
+ dist, dist.command_options, ignore_option_errors=ignore_option_errors
+ )
finally:
os.chdir(current_directory)
@@ -132,8 +132,7 @@ def configuration_to_dict(handlers):
return config_dict
-def parse_configuration(
- distribution, command_options, ignore_option_errors=False):
+def parse_configuration(distribution, command_options, ignore_option_errors=False):
"""Performs additional parsing of configuration options for a distribution.
@@ -147,13 +146,15 @@
If False exceptions are propagated as expected.
:rtype: list """ - options = ConfigOptionsHandler( - distribution, command_options, ignore_option_errors) + options = ConfigOptionsHandler(distribution, command_options, ignore_option_errors) options.parse() meta = ConfigMetadataHandler( - distribution.metadata, command_options, ignore_option_errors, - distribution.package_dir) + distribution.metadata, + command_options, + ignore_option_errors, + distribution.package_dir, + ) meta.parse() return meta, options @@ -195,7 +196,8 @@ class ConfigHandler: def parsers(self): """Metadata item name to parser function mapping.""" raise NotImplementedError( - '%s must provide .parsers property' % self.__class__.__name__) + '%s must provide .parsers property' % self.__class__.__name__ + ) def __setitem__(self, option_name, value): unknown = tuple() @@ -256,6 +258,34 @@ class ConfigHandler: return [chunk.strip() for chunk in value if chunk.strip()] @classmethod + def _parse_list_glob(cls, value, separator=','): + """Equivalent to _parse_list() but expands any glob patterns using glob(). + + However, unlike with glob() calls, the results remain relative paths. + + :param value: + :param separator: List items separator character. + :rtype: list + """ + glob_characters = ('*', '?', '[', ']', '{', '}') + values = cls._parse_list(value, separator=separator) + expanded_values = [] + for value in values: + + # Has globby characters? + if any(char in value for char in glob_characters): + # then expand the glob pattern while keeping paths *relative*: + expanded_values.extend(sorted( + os.path.relpath(path, os.getcwd()) + for path in iglob(os.path.abspath(value)))) + + else: + # take the value as-is: + expanded_values.append(value) + + return expanded_values + + @classmethod def _parse_dict(cls, value): """Represents value as a dict. @@ -268,7 +298,8 @@ class ConfigHandler: key, sep, val = line.partition(separator) if sep != separator: raise DistutilsOptionError( - 'Unable to parse option value to dict: %s' % value) + 'Unable to parse option value to dict: %s' % value + ) result[key.strip()] = val.strip() return result @@ -294,13 +325,16 @@ class ConfigHandler: :param key: :rtype: callable """ + def parser(value): exclude_directive = 'file:' if value.startswith(exclude_directive): raise ValueError( 'Only strings are accepted for the {0} field, ' - 'files are not accepted'.format(key)) + 'files are not accepted'.format(key) + ) return value + return parser @classmethod @@ -325,20 +359,18 @@ class ConfigHandler: if not value.startswith(include_directive): return value - spec = value[len(include_directive):] + spec = value[len(include_directive) :] filepaths = (os.path.abspath(path.strip()) for path in spec.split(',')) return '\n'.join( cls._read_file(path) for path in filepaths - if (cls._assert_local(path) or True) - and os.path.isfile(path) + if (cls._assert_local(path) or True) and os.path.isfile(path) ) @staticmethod def _assert_local(filepath): if not filepath.startswith(os.getcwd()): - raise DistutilsOptionError( - '`file:` directive can not access %s' % filepath) + raise DistutilsOptionError('`file:` directive can not access %s' % filepath) @staticmethod def _read_file(filepath): @@ -400,6 +432,7 @@ class ConfigHandler: :param parse_methods: :rtype: callable """ + def parse(value): parsed = value @@ -453,22 +486,25 @@ class ConfigHandler: self, # Dots in section names are translated into dunderscores. 
('parse_section%s' % method_postfix).replace('.', '__'), - None) + None, + ) if section_parser_method is None: raise DistutilsOptionError( - 'Unsupported distribution option section: [%s.%s]' % ( - self.section_prefix, section_name)) + 'Unsupported distribution option section: [%s.%s]' + % (self.section_prefix, section_name) + ) section_parser_method(section_options) def _deprecated_config_handler(self, func, msg, warning_class): - """ this function will wrap around parameters that are deprecated + """this function will wrap around parameters that are deprecated :param msg: deprecation message :param warning_class: class of warning exception to be raised :param func: function to be wrapped around """ + @wraps(func) def config_handler(*args, **kwargs): warnings.warn(msg, warning_class) @@ -494,10 +530,12 @@ class ConfigMetadataHandler(ConfigHandler): """ - def __init__(self, target_obj, options, ignore_option_errors=False, - package_dir=None): - super(ConfigMetadataHandler, self).__init__(target_obj, options, - ignore_option_errors) + def __init__( + self, target_obj, options, ignore_option_errors=False, package_dir=None + ): + super(ConfigMetadataHandler, self).__init__( + target_obj, options, ignore_option_errors + ) self.package_dir = package_dir @property @@ -516,10 +554,17 @@ class ConfigMetadataHandler(ConfigHandler): parse_list, "The requires parameter is deprecated, please use " "install_requires for runtime dependencies.", - DeprecationWarning), + DeprecationWarning, + ), 'obsoletes': parse_list, 'classifiers': self._get_parser_compound(parse_file, parse_list), 'license': exclude_files_parser('license'), + 'license_file': self._deprecated_config_handler( + exclude_files_parser('license_file'), + "The license_file parameter is deprecated, " + "use license_files instead.", + DeprecationWarning, + ), 'license_files': parse_list, 'description': parse_file, 'long_description': parse_file, @@ -574,15 +619,12 @@ class ConfigOptionsHandler(ConfigHandler): parse_list_semicolon = partial(self._parse_list, separator=';') parse_bool = self._parse_bool parse_dict = self._parse_dict + parse_cmdclass = self._parse_cmdclass return { 'zip_safe': parse_bool, - 'use_2to3': parse_bool, 'include_package_data': parse_bool, 'package_dir': parse_dict, - 'use_2to3_fixers': parse_list, - 'use_2to3_exclude_fixers': parse_list, - 'convert_2to3_doctests': parse_list, 'scripts': parse_list, 'eager_resources': parse_list, 'dependency_links': parse_list, @@ -594,8 +636,21 @@ class ConfigOptionsHandler(ConfigHandler): 'entry_points': self._parse_file, 'py_modules': parse_list, 'python_requires': SpecifierSet, + 'cmdclass': parse_cmdclass, } + def _parse_cmdclass(self, value): + def resolve_class(qualified_class_name): + idx = qualified_class_name.rfind('.') + class_name = qualified_class_name[idx + 1 :] + pkg_name = qualified_class_name[:idx] + + module = __import__(pkg_name) + + return getattr(module, class_name) + + return {k: resolve_class(v) for k, v in self._parse_dict(value).items()} + def _parse_packages(self, value): """Parses `packages` option value. @@ -612,7 +667,8 @@ class ConfigOptionsHandler(ConfigHandler): # Read function arguments from a dedicated section. 
find_kwargs = self.parse_section_packages__find( - self.sections.get('packages.find', {})) + self.sections.get('packages.find', {}) + ) if findns: from setuptools import find_namespace_packages as find_packages @@ -628,13 +684,13 @@ class ConfigOptionsHandler(ConfigHandler): :param dict section_options: """ - section_data = self._parse_section_to_dict( - section_options, self._parse_list) + section_data = self._parse_section_to_dict(section_options, self._parse_list) valid_keys = ['where', 'include', 'exclude'] find_kwargs = dict( - [(k, v) for k, v in section_data.items() if k in valid_keys and v]) + [(k, v) for k, v in section_data.items() if k in valid_keys and v] + ) where = find_kwargs.get('where') if where is not None: @@ -672,8 +728,7 @@ class ConfigOptionsHandler(ConfigHandler): :param dict section_options: """ - self['exclude_package_data'] = self._parse_package_data( - section_options) + self['exclude_package_data'] = self._parse_package_data(section_options) def parse_section_extras_require(self, section_options): """Parses `extras_require` configuration file section. @@ -682,12 +737,13 @@ class ConfigOptionsHandler(ConfigHandler): """ parse_list = partial(self._parse_list, separator=';') self['extras_require'] = self._parse_section_to_dict( - section_options, parse_list) + section_options, parse_list + ) def parse_section_data_files(self, section_options): """Parses `data_files` configuration file section. :param dict section_options: """ - parsed = self._parse_section_to_dict(section_options, self._parse_list) + parsed = self._parse_section_to_dict(section_options, self._parse_list_glob) self['data_files'] = [(k, v) for k, v in parsed.items()] diff --git a/setuptools/dist.py b/setuptools/dist.py index 2c088ef8..8e2111a5 100644 --- a/setuptools/dist.py +++ b/setuptools/dist.py @@ -11,10 +11,14 @@ import distutils.log import distutils.core import distutils.cmd import distutils.dist +import distutils.command from distutils.util import strtobool from distutils.debug import DEBUG from distutils.fancy_getopt import translate_longopt +from glob import iglob import itertools +import textwrap +from typing import List, Optional, TYPE_CHECKING from collections import defaultdict from email import message_from_file @@ -25,15 +29,20 @@ from distutils.version import StrictVersion from setuptools.extern import packaging from setuptools.extern import ordered_set +from setuptools.extern.more_itertools import unique_everseen from . 
import SetuptoolsDeprecationWarning import setuptools +import setuptools.command from setuptools import windows_support from setuptools.monkey import get_unpatched from setuptools.config import parse_configuration import pkg_resources +if TYPE_CHECKING: + from email.message import Message + __import__('setuptools.extern.packaging.specifiers') __import__('setuptools.extern.packaging.version') @@ -45,83 +54,108 @@ def _get_unpatched(cls): def get_metadata_version(self): mv = getattr(self, 'metadata_version', None) - if mv is None: - if self.long_description_content_type or self.provides_extras: - mv = StrictVersion('2.1') - elif (self.maintainer is not None or - self.maintainer_email is not None or - getattr(self, 'python_requires', None) is not None or - self.project_urls): - mv = StrictVersion('1.2') - elif (self.provides or self.requires or self.obsoletes or - self.classifiers or self.download_url): - mv = StrictVersion('1.1') - else: - mv = StrictVersion('1.0') - + mv = StrictVersion('2.1') self.metadata_version = mv - return mv -def read_pkg_file(self, file): - """Reads the metadata values from a file object.""" - msg = message_from_file(file) +def rfc822_unescape(content: str) -> str: + """Reverse RFC-822 escaping by removing leading whitespaces from content.""" + lines = content.splitlines() + if len(lines) == 1: + return lines[0].lstrip() + return '\n'.join((lines[0].lstrip(), textwrap.dedent('\n'.join(lines[1:])))) + - def _read_field(name): - value = msg[name] - if value == 'UNKNOWN': - return None +def _read_field_from_msg(msg: "Message", field: str) -> Optional[str]: + """Read Message header field.""" + value = msg[field] + if value == 'UNKNOWN': + return None + return value + + +def _read_field_unescaped_from_msg(msg: "Message", field: str) -> Optional[str]: + """Read Message header field and apply rfc822_unescape.""" + value = _read_field_from_msg(msg, field) + if value is None: return value + return rfc822_unescape(value) + - def _read_list(name): - values = msg.get_all(name, None) - if values == []: - return None - return values +def _read_list_from_msg(msg: "Message", field: str) -> Optional[List[str]]: + """Read Message header field and return all results as list.""" + values = msg.get_all(field, None) + if values == []: + return None + return values + + +def _read_payload_from_msg(msg: "Message") -> Optional[str]: + value = msg.get_payload().strip() + if value == 'UNKNOWN': + return None + return value + + +def read_pkg_file(self, file): + """Reads the metadata values from a file object.""" + msg = message_from_file(file) self.metadata_version = StrictVersion(msg['metadata-version']) - self.name = _read_field('name') - self.version = _read_field('version') - self.description = _read_field('summary') + self.name = _read_field_from_msg(msg, 'name') + self.version = _read_field_from_msg(msg, 'version') + self.description = _read_field_from_msg(msg, 'summary') # we are filling author only. 
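# Editor's note: an illustrative round-trip for the rfc822_unescape helper
# added above, assuming distutils' rfc822_escape indents continuation lines
# with eight spaces (which is exactly what textwrap.dedent undoes here).
import textwrap
from distutils.util import rfc822_escape


def rfc822_unescape(content):
    """Reverse RFC-822 escaping by removing leading whitespace from content."""
    lines = content.splitlines()
    if len(lines) == 1:
        return lines[0].lstrip()
    return '\n'.join((lines[0].lstrip(), textwrap.dedent('\n'.join(lines[1:]))))


original = 'GPL-2.0\nwith some exceptions'
assert rfc822_unescape(rfc822_escape(original)) == original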
- self.author = _read_field('author') + self.author = _read_field_from_msg(msg, 'author') self.maintainer = None - self.author_email = _read_field('author-email') + self.author_email = _read_field_from_msg(msg, 'author-email') self.maintainer_email = None - self.url = _read_field('home-page') - self.license = _read_field('license') + self.url = _read_field_from_msg(msg, 'home-page') + self.license = _read_field_unescaped_from_msg(msg, 'license') if 'download-url' in msg: - self.download_url = _read_field('download-url') + self.download_url = _read_field_from_msg(msg, 'download-url') else: self.download_url = None - self.long_description = _read_field('description') - self.description = _read_field('summary') + self.long_description = _read_field_unescaped_from_msg(msg, 'description') + if self.long_description is None and self.metadata_version >= StrictVersion('2.1'): + self.long_description = _read_payload_from_msg(msg) + self.description = _read_field_from_msg(msg, 'summary') if 'keywords' in msg: - self.keywords = _read_field('keywords').split(',') + self.keywords = _read_field_from_msg(msg, 'keywords').split(',') - self.platforms = _read_list('platform') - self.classifiers = _read_list('classifier') + self.platforms = _read_list_from_msg(msg, 'platform') + self.classifiers = _read_list_from_msg(msg, 'classifier') # PEP 314 - these fields only exist in 1.1 if self.metadata_version == StrictVersion('1.1'): - self.requires = _read_list('requires') - self.provides = _read_list('provides') - self.obsoletes = _read_list('obsoletes') + self.requires = _read_list_from_msg(msg, 'requires') + self.provides = _read_list_from_msg(msg, 'provides') + self.obsoletes = _read_list_from_msg(msg, 'obsoletes') else: self.requires = None self.provides = None self.obsoletes = None + self.license_files = _read_list_from_msg(msg, 'license-file') + + +def single_line(val): + # quick and dirty validation for description pypa/setuptools#1390 + if '\n' in val: + # TODO after 2021-07-31: Replace with `raise ValueError("newlines not allowed")` + warnings.warn("newlines not allowed and will break in the future") + val = val.replace('\n', ' ') + return val + # Based on Python 3.5 version -def write_pkg_file(self, file): - """Write the PKG-INFO format data to a file object. 
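# Editor's note: read_pkg_file above now collects the multi-valued
# License-File header; a self-contained demonstration using the same email
# API (the PKG-INFO text is made up for illustration):
from email import message_from_string

pkg_info = (
    "Metadata-Version: 2.1\n"
    "Name: example\n"
    "Version: 1.0\n"
    "License-File: LICENSE\n"
    "License-File: AUTHORS\n"
)
msg = message_from_string(pkg_info)
assert msg.get_all('license-file') == ['LICENSE', 'AUTHORS']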
- """ +def write_pkg_file(self, file): # noqa: C901 # is too complex (14) # FIXME + """Write the PKG-INFO format data to a file object.""" version = self.get_metadata_version() def write_field(key, value): @@ -130,44 +164,34 @@ def write_pkg_file(self, file): write_field('Metadata-Version', str(version)) write_field('Name', self.get_name()) write_field('Version', self.get_version()) - write_field('Summary', self.get_description()) + write_field('Summary', single_line(self.get_description())) write_field('Home-page', self.get_url()) - if version < StrictVersion('1.2'): - write_field('Author', self.get_contact()) - write_field('Author-email', self.get_contact_email()) - else: - optional_fields = ( - ('Author', 'author'), - ('Author-email', 'author_email'), - ('Maintainer', 'maintainer'), - ('Maintainer-email', 'maintainer_email'), - ) - - for field, attr in optional_fields: - attr_val = getattr(self, attr) + optional_fields = ( + ('Author', 'author'), + ('Author-email', 'author_email'), + ('Maintainer', 'maintainer'), + ('Maintainer-email', 'maintainer_email'), + ) - if attr_val is not None: - write_field(field, attr_val) + for field, attr in optional_fields: + attr_val = getattr(self, attr, None) + if attr_val is not None: + write_field(field, attr_val) - write_field('License', self.get_license()) + license = rfc822_escape(self.get_license()) + write_field('License', license) if self.download_url: write_field('Download-URL', self.download_url) for project_url in self.project_urls.items(): write_field('Project-URL', '%s, %s' % project_url) - long_desc = rfc822_escape(self.get_long_description()) - write_field('Description', long_desc) - keywords = ','.join(self.get_keywords()) if keywords: write_field('Keywords', keywords) - if version >= StrictVersion('1.2'): - for platform in self.get_platforms(): - write_field('Platform', platform) - else: - self._write_list(file, 'Platform', self.get_platforms()) + for platform in self.get_platforms(): + write_field('Platform', platform) self._write_list(file, 'Classifier', self.get_classifiers()) @@ -182,14 +206,15 @@ def write_pkg_file(self, file): # PEP 566 if self.long_description_content_type: - write_field( - 'Description-Content-Type', - self.long_description_content_type - ) + write_field('Description-Content-Type', self.long_description_content_type) if self.provides_extras: for extra in self.provides_extras: write_field('Provides-Extra', extra) + self._write_list(file, 'License-File', self.license_files or []) + + file.write("\n%s\n\n" % self.get_long_description()) + sequence = tuple, list @@ -200,8 +225,7 @@ def check_importable(dist, attr, value): assert not ep.extras except (TypeError, ValueError, AttributeError, AssertionError) as e: raise DistutilsSetupError( - "%r must be importable 'module:attrs' string (got %r)" - % (attr, value) + "%r must be importable 'module:attrs' string (got %r)" % (attr, value) ) from e @@ -226,14 +250,16 @@ def check_nsp(dist, attr, value): for nsp in ns_packages: if not dist.has_contents_for(nsp): raise DistutilsSetupError( - "Distribution contains no modules or packages for " + - "namespace package %r" % nsp + "Distribution contains no modules or packages for " + + "namespace package %r" % nsp ) parent, sep, child = nsp.rpartition('.') if parent and parent not in ns_packages: distutils.log.warn( "WARNING: %r is declared as a package namespace, but %r" - " is not: please correct this in setup.py", nsp, parent + " is not: please correct this in setup.py", + nsp, + parent, ) @@ -263,6 +289,13 @@ def 
assert_bool(dist, attr, value): raise DistutilsSetupError(tmpl.format(attr=attr, value=value)) +def invalid_unless_false(dist, attr, value): + if not value: + warnings.warn(f"{attr} is ignored.", DistDeprecationWarning) + return + raise DistutilsSetupError(f"{attr} is invalid.") + + def check_requirements(dist, attr, value): """Verify that install_requires is a valid requirements list""" try: @@ -274,23 +307,18 @@ def check_requirements(dist, attr, value): "{attr!r} must be a string or list of strings " "containing valid project/version requirement specifiers; {error}" ) - raise DistutilsSetupError( - tmpl.format(attr=attr, error=error) - ) from error + raise DistutilsSetupError(tmpl.format(attr=attr, error=error)) from error def check_specifier(dist, attr, value): """Verify that value is a valid version specifier""" try: packaging.specifiers.SpecifierSet(value) - except packaging.specifiers.InvalidSpecifier as error: + except (packaging.specifiers.InvalidSpecifier, AttributeError) as error: tmpl = ( - "{attr!r} must be a string " - "containing valid version specifiers; {error}" + "{attr!r} must be a string " "containing valid version specifiers; {error}" ) - raise DistutilsSetupError( - tmpl.format(attr=attr, error=error) - ) from error + raise DistutilsSetupError(tmpl.format(attr=attr, error=error)) from error def check_entry_points(dist, attr, value): @@ -311,12 +339,12 @@ def check_package_data(dist, attr, value): if not isinstance(value, dict): raise DistutilsSetupError( "{!r} must be a dictionary mapping package names to lists of " - "string wildcard patterns".format(attr)) + "string wildcard patterns".format(attr) + ) for k, v in value.items(): if not isinstance(k, str): raise DistutilsSetupError( - "keys of {!r} dict must be strings (got {!r})" - .format(attr, k) + "keys of {!r} dict must be strings (got {!r})".format(attr, k) ) assert_string_list(dist, 'values of {!r} dict'.format(attr), v) @@ -326,7 +354,8 @@ def check_packages(dist, attr, value): if not re.match(r'\w+(\.\w+)*', pkgname): distutils.log.warn( "WARNING: %r not a valid package name; please use only " - ".-separated package names in setup.py", pkgname + ".-separated package names in setup.py", + pkgname, ) @@ -386,10 +415,11 @@ class Distribution(_Distribution): """ _DISTUTILS_UNSUPPORTED_METADATA = { - 'long_description_content_type': None, + 'long_description_content_type': lambda: None, 'project_urls': dict, 'provides_extras': ordered_set.OrderedSet, - 'license_files': ordered_set.OrderedSet, + 'license_file': lambda: None, + 'license_files': lambda: None, } _patched_dist = None @@ -420,27 +450,32 @@ class Distribution(_Distribution): self.setup_requires = attrs.pop('setup_requires', []) for ep in pkg_resources.iter_entry_points('distutils.setup_keywords'): vars(self).setdefault(ep.name, None) - _Distribution.__init__(self, { - k: v for k, v in attrs.items() - if k not in self._DISTUTILS_UNSUPPORTED_METADATA - }) - - # Fill-in missing metadata fields not supported by distutils. - # Note some fields may have been set by other tools (e.g. 
pbr) - # above; they are taken preferentially to setup() arguments - for option, default in self._DISTUTILS_UNSUPPORTED_METADATA.items(): - for source in self.metadata.__dict__, attrs: - if option in source: - value = source[option] - break - else: - value = default() if default else None - setattr(self.metadata, option, value) + _Distribution.__init__( + self, + { + k: v + for k, v in attrs.items() + if k not in self._DISTUTILS_UNSUPPORTED_METADATA + }, + ) + + self._set_metadata_defaults(attrs) self.metadata.version = self._normalize_version( - self._validate_version(self.metadata.version)) + self._validate_version(self.metadata.version) + ) self._finalize_requires() + def _set_metadata_defaults(self, attrs): + """ + Fill in missing metadata fields not supported by distutils. + Some fields may have been set by other tools (e.g. pbr). + Those fields (vars(self.metadata)) take precedence over + supplied attrs. + """ + for option, default in self._DISTUTILS_UNSUPPORTED_METADATA.items(): + vars(self.metadata).setdefault(option, attrs.get(option, default())) + @staticmethod def _normalize_version(version): if isinstance(version, setuptools.sic) or version is None: @@ -548,7 +583,42 @@ class Distribution(_Distribution): req.marker = None return req - def _parse_config_files(self, filenames=None): + def _finalize_license_files(self): + """Compute names of all license files which should be included.""" + license_files: Optional[List[str]] = self.metadata.license_files + patterns: List[str] = license_files if license_files else [] + + license_file: Optional[str] = self.metadata.license_file + if license_file and license_file not in patterns: + patterns.append(license_file) + + if license_files is None and license_file is None: + # Default patterns match the ones wheel uses + # See https://wheel.readthedocs.io/en/stable/user_guide.html + # -> 'Including license files in the generated wheel file' + patterns = ('LICEN[CS]E*', 'COPYING*', 'NOTICE*', 'AUTHORS*') + + self.metadata.license_files = list( + unique_everseen(self._expand_patterns(patterns)) + ) + + @staticmethod + def _expand_patterns(patterns): + """ + >>> list(Distribution._expand_patterns(['LICENSE'])) + ['LICENSE'] + >>> list(Distribution._expand_patterns(['setup.cfg', 'LIC*'])) + ['setup.cfg', 'LICENSE'] + """ + return ( + path + for pattern in patterns + for path in sorted(iglob(pattern)) + if not path.endswith('~') and os.path.isfile(path) + ) + + # FIXME: 'Distribution._parse_config_files' is too complex (14) + def _parse_config_files(self, filenames=None): # noqa: C901 """ Adapted from distutils.dist.Distribution.parse_config_files, this method provides the same functionality in subtly-improved @@ -557,14 +627,25 @@ class Distribution(_Distribution): from configparser import ConfigParser # Ignore install directory options if we have a venv - if sys.prefix != sys.base_prefix: - ignore_options = [ - 'install-base', 'install-platbase', 'install-lib', - 'install-platlib', 'install-purelib', 'install-headers', - 'install-scripts', 'install-data', 'prefix', 'exec-prefix', - 'home', 'user', 'root'] - else: - ignore_options = [] + ignore_options = ( + [] + if sys.prefix == sys.base_prefix + else [ + 'install-base', + 'install-platbase', + 'install-lib', + 'install-platlib', + 'install-purelib', + 'install-headers', + 'install-scripts', + 'install-data', + 'prefix', + 'exec-prefix', + 'home', + 'user', + 'root', + ] + ) ignore_options = frozenset(ignore_options) @@ -575,6 +656,7 @@ class Distribution(_Distribution): 
self.announce("Distribution.parse_config_files():") parser = ConfigParser() + parser.optionxform = str for filename in filenames: with io.open(filename, encoding='utf-8') as reader: if DEBUG: @@ -585,32 +667,82 @@ class Distribution(_Distribution): opt_dict = self.get_option_dict(section) for opt in options: - if opt != '__name__' and opt not in ignore_options: - val = parser.get(section, opt) - opt = opt.replace('-', '_') - opt_dict[opt] = (filename, val) + if opt == '__name__' or opt in ignore_options: + continue + + val = parser.get(section, opt) + opt = self.warn_dash_deprecation(opt, section) + opt = self.make_option_lowercase(opt, section) + opt_dict[opt] = (filename, val) # Make the ConfigParser forget everything (so we retain # the original filenames that options come from) parser.__init__() + if 'global' not in self.command_options: + return + # If there was a "global" section in the config file, use it # to set Distribution options. - if 'global' in self.command_options: - for (opt, (src, val)) in self.command_options['global'].items(): - alias = self.negative_opt.get(opt) - try: - if alias: - setattr(self, alias, not strtobool(val)) - elif opt in ('verbose', 'dry_run'): # ugh! - setattr(self, opt, strtobool(val)) - else: - setattr(self, opt, val) - except ValueError as e: - raise DistutilsOptionError(e) from e + for (opt, (src, val)) in self.command_options['global'].items(): + alias = self.negative_opt.get(opt) + if alias: + val = not strtobool(val) + elif opt in ('verbose', 'dry_run'): # ugh! + val = strtobool(val) + + try: + setattr(self, alias or opt, val) + except ValueError as e: + raise DistutilsOptionError(e) from e + + def warn_dash_deprecation(self, opt, section): + if section in ( + 'options.extras_require', + 'options.data_files', + ): + return opt + + underscore_opt = opt.replace('-', '_') + commands = distutils.command.__all__ + self._setuptools_commands() + if ( + not section.startswith('options') + and section != 'metadata' + and section not in commands + ): + return underscore_opt + + if '-' in opt: + warnings.warn( + "Usage of dash-separated '%s' will not be supported in future " + "versions. Please use the underscore name '%s' instead" + % (opt, underscore_opt) + ) + return underscore_opt - def _set_command_options(self, command_obj, option_dict=None): + def _setuptools_commands(self): + try: + dist = pkg_resources.get_distribution('setuptools') + return list(dist.get_entry_map('distutils.commands')) + except pkg_resources.DistributionNotFound: + # during bootstrapping, distribution doesn't exist + return [] + + def make_option_lowercase(self, opt, section): + if section != 'metadata' or opt.islower(): + return opt + + lowercase_opt = opt.lower() + warnings.warn( + "Usage of uppercase key '%s' in '%s' will be deprecated in future " + "versions. Please use lowercase '%s' instead" + % (opt, section, lowercase_opt) + ) + return lowercase_opt + + # FIXME: 'Distribution._set_command_options' is too complex (14) + def _set_command_options(self, command_obj, option_dict=None): # noqa: C901 """ Set the options for 'command_obj' from 'option_dict'. 
Basically this means copying elements of a dictionary ('option_dict') to @@ -630,11 +762,9 @@ class Distribution(_Distribution): self.announce(" setting options for '%s' command:" % command_name) for (option, (source, value)) in option_dict.items(): if DEBUG: - self.announce(" %s = %s (from %s)" % (option, value, - source)) + self.announce(" %s = %s (from %s)" % (option, value, source)) try: - bool_opts = [translate_longopt(o) - for o in command_obj.boolean_options] + bool_opts = [translate_longopt(o) for o in command_obj.boolean_options] except AttributeError: bool_opts = [] try: @@ -653,7 +783,8 @@ class Distribution(_Distribution): else: raise DistutilsOptionError( "error in %s: command '%s' has no such option '%s'" - % (source, command_name, option)) + % (source, command_name, option) + ) except ValueError as e: raise DistutilsOptionError(e) from e @@ -664,9 +795,11 @@ class Distribution(_Distribution): """ self._parse_config_files(filenames=filenames) - parse_configuration(self, self.command_options, - ignore_option_errors=ignore_option_errors) + parse_configuration( + self, self.command_options, ignore_option_errors=ignore_option_errors + ) self._finalize_requires() + self._finalize_license_files() def fetch_build_eggs(self, requires): """Resolve pre-setup requirements""" @@ -690,10 +823,27 @@ class Distribution(_Distribution): def by_order(hook): return getattr(hook, 'order', 0) - eps = map(lambda e: e.load(), pkg_resources.iter_entry_points(group)) - for ep in sorted(eps, key=by_order): + + defined = pkg_resources.iter_entry_points(group) + filtered = itertools.filterfalse(self._removed, defined) + loaded = map(lambda e: e.load(), filtered) + for ep in sorted(loaded, key=by_order): ep(self) + @staticmethod + def _removed(ep): + """ + When removing an entry point, if metadata is loaded + from an older version of Setuptools, that removed + entry point will attempt to be loaded and will fail. + See #2765 for more details. 
+ """ + removed = { + # removed 2021-09-05 + '2to3_doctests', + } + return ep.name in removed + def _finalize_setup_keywords(self): for ep in pkg_resources.iter_entry_points('distutils.setup_keywords'): value = getattr(self, ep.name, None) @@ -701,16 +851,6 @@ class Distribution(_Distribution): ep.require(installer=self.fetch_build_egg) ep.load()(self, ep.name, value) - def _finalize_2to3_doctests(self): - if getattr(self, 'convert_2to3_doctests', None): - # XXX may convert to set here when we can rely on set being builtin - self.convert_2to3_doctests = [ - os.path.abspath(p) - for p in self.convert_2to3_doctests - ] - else: - self.convert_2to3_doctests = [] - def get_egg_cache_dir(self): egg_cache_dir = os.path.join(os.curdir, '.eggs') if not os.path.exists(egg_cache_dir): @@ -718,10 +858,14 @@ class Distribution(_Distribution): windows_support.hide_file(egg_cache_dir) readme_txt_filename = os.path.join(egg_cache_dir, 'README.txt') with open(readme_txt_filename, 'w') as f: - f.write('This directory contains eggs that were downloaded ' - 'by setuptools to build, test, and run plug-ins.\n\n') - f.write('This directory caches those eggs to prevent ' - 'repeated downloads.\n\n') + f.write( + 'This directory contains eggs that were downloaded ' + 'by setuptools to build, test, and run plug-ins.\n\n' + ) + f.write( + 'This directory caches those eggs to prevent ' + 'repeated downloads.\n\n' + ) f.write('However, it is safe to delete this directory.\n\n') return egg_cache_dir @@ -729,6 +873,7 @@ class Distribution(_Distribution): def fetch_build_egg(self, req): """Fetch an egg needed for building""" from setuptools.installer import fetch_build_egg + return fetch_build_egg(self, req) def get_command_class(self, command): @@ -788,19 +933,18 @@ class Distribution(_Distribution): pfx = package + '.' if self.packages: self.packages = [ - p for p in self.packages - if p != package and not p.startswith(pfx) + p for p in self.packages if p != package and not p.startswith(pfx) ] if self.py_modules: self.py_modules = [ - p for p in self.py_modules - if p != package and not p.startswith(pfx) + p for p in self.py_modules if p != package and not p.startswith(pfx) ] if self.ext_modules: self.ext_modules = [ - p for p in self.ext_modules + p + for p in self.ext_modules if p.name != package and not p.name.startswith(pfx) ] @@ -822,9 +966,7 @@ class Distribution(_Distribution): try: old = getattr(self, name) except AttributeError as e: - raise DistutilsSetupError( - "%s: No such distribution setting" % name - ) from e + raise DistutilsSetupError("%s: No such distribution setting" % name) from e if old is not None and not isinstance(old, sequence): raise DistutilsSetupError( name + ": this setting cannot be changed via include/exclude" @@ -836,15 +978,11 @@ class Distribution(_Distribution): """Handle 'include()' for list/tuple attrs without a special handler""" if not isinstance(value, sequence): - raise DistutilsSetupError( - "%s: setting must be a list (%r)" % (name, value) - ) + raise DistutilsSetupError("%s: setting must be a list (%r)" % (name, value)) try: old = getattr(self, name) except AttributeError as e: - raise DistutilsSetupError( - "%s: No such distribution setting" % name - ) from e + raise DistutilsSetupError("%s: No such distribution setting" % name) from e if old is None: setattr(self, name, value) elif not isinstance(old, sequence): @@ -897,6 +1035,7 @@ class Distribution(_Distribution): src, alias = aliases[command] del aliases[command] # ensure each alias can expand only once! 
import shlex + args[:1] = shlex.split(alias, True) command = args[0] @@ -996,12 +1135,14 @@ class Distribution(_Distribution): line_buffering = sys.stdout.line_buffering sys.stdout = io.TextIOWrapper( - sys.stdout.detach(), 'utf-8', errors, newline, line_buffering) + sys.stdout.detach(), 'utf-8', errors, newline, line_buffering + ) try: return _Distribution.handle_display_options(self, option_order) finally: sys.stdout = io.TextIOWrapper( - sys.stdout.detach(), encoding, errors, newline, line_buffering) + sys.stdout.detach(), encoding, errors, newline, line_buffering + ) class DistDeprecationWarning(SetuptoolsDeprecationWarning): diff --git a/setuptools/extern/__init__.py b/setuptools/extern/__init__.py index b7f30dc2..baca1afa 100644 --- a/setuptools/extern/__init__.py +++ b/setuptools/extern/__init__.py @@ -1,3 +1,4 @@ +import importlib.util import sys @@ -20,17 +21,10 @@ class VendorImporter: yield self.vendor_pkg + '.' yield '' - def find_module(self, fullname, path=None): - """ - Return self when fullname starts with root_name and the - target module is one vendored through this importer. - """ + def _module_matches_namespace(self, fullname): + """Figure out if the target module is vendored.""" root, base, target = fullname.partition(self.root_name + '.') - if root: - return - if not any(map(target.startswith, self.vendored_names)): - return - return self + return not root and any(map(target.startswith, self.vendored_names)) def load_module(self, fullname): """ @@ -54,6 +48,19 @@ class VendorImporter: "distribution.".format(**locals()) ) + def create_module(self, spec): + return self.load_module(spec.name) + + def exec_module(self, module): + pass + + def find_spec(self, fullname, path=None, target=None): + """Return a module spec for vendored names.""" + return ( + importlib.util.spec_from_loader(fullname, self) + if self._module_matches_namespace(fullname) else None + ) + def install(self): """ Install this importer into sys.meta_path if not already present. @@ -62,5 +69,5 @@ class VendorImporter: sys.meta_path.append(self) -names = 'packaging', 'pyparsing', 'ordered_set', +names = 'packaging', 'pyparsing', 'ordered_set', 'more_itertools', VendorImporter(__name__, names, 'setuptools._vendor').install() diff --git a/setuptools/glob.py b/setuptools/glob.py index 9d7cbc5d..87062b81 100644 --- a/setuptools/glob.py +++ b/setuptools/glob.py @@ -47,6 +47,8 @@ def iglob(pathname, recursive=False): def _iglob(pathname, recursive): dirname, basename = os.path.split(pathname) + glob_in_dir = glob2 if recursive and _isrecursive(basename) else glob1 + if not has_magic(pathname): if basename: if os.path.lexists(pathname): @@ -56,13 +58,9 @@ def _iglob(pathname, recursive): if os.path.isdir(dirname): yield pathname return + if not dirname: - if recursive and _isrecursive(basename): - for x in glob2(dirname, basename): - yield x - else: - for x in glob1(dirname, basename): - yield x + yield from glob_in_dir(dirname, basename) return # `os.path.split()` returns the argument itself as a dirname if it is a # drive or UNC path. 
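# Editor's note: the extern/__init__.py hunk above migrates VendorImporter
# from the legacy find_module/load_module API to PEP 451's
# find_spec/create_module/exec_module. A toy finder in the same style,
# aliasing a made-up name to the stdlib json module (illustrative only):
import importlib
import importlib.util
import sys


class AliasFinder:
    aliases = {'jsonalias': 'json'}  # hypothetical alias table

    def find_spec(self, fullname, path=None, target=None):
        if fullname in self.aliases:
            return importlib.util.spec_from_loader(fullname, self)
        return None  # not ours; let the other finders try

    def create_module(self, spec):
        # Hand back the real module; the import system caches it under the alias.
        return importlib.import_module(self.aliases[spec.name])

    def exec_module(self, module):
        pass  # the aliased module is already initialized


sys.meta_path.append(AliasFinder())
import jsonalias
assert jsonalias.loads('[1]') == [1]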
Prevent an infinite recursion if a drive or UNC path @@ -71,12 +69,7 @@ def _iglob(pathname, recursive): dirs = _iglob(dirname, recursive) else: dirs = [dirname] - if has_magic(basename): - if recursive and _isrecursive(basename): - glob_in_dir = glob2 - else: - glob_in_dir = glob1 - else: + if not has_magic(basename): glob_in_dir = glob0 for dirname in dirs: for name in glob_in_dir(dirname, basename): diff --git a/setuptools/gui-arm64.exe b/setuptools/gui-arm64.exe Binary files differnew file mode 100644 index 00000000..5730f11d --- /dev/null +++ b/setuptools/gui-arm64.exe diff --git a/setuptools/installer.py b/setuptools/installer.py index e630b874..57e2b587 100644 --- a/setuptools/installer.py +++ b/setuptools/installer.py @@ -7,7 +7,6 @@ from distutils import log from distutils.errors import DistutilsError import pkg_resources -from setuptools.command.easy_install import easy_install from setuptools.wheel import Wheel @@ -19,54 +18,11 @@ def _fixup_find_links(find_links): return find_links -def _legacy_fetch_build_egg(dist, req): - """Fetch an egg needed for building. - - Legacy path using EasyInstall. - """ - tmp_dist = dist.__class__({'script_args': ['easy_install']}) - opts = tmp_dist.get_option_dict('easy_install') - opts.clear() - opts.update( - (k, v) - for k, v in dist.get_option_dict('easy_install').items() - if k in ( - # don't use any other settings - 'find_links', 'site_dirs', 'index_url', - 'optimize', 'site_dirs', 'allow_hosts', - )) - if dist.dependency_links: - links = dist.dependency_links[:] - if 'find_links' in opts: - links = _fixup_find_links(opts['find_links'][1]) + links - opts['find_links'] = ('setup', links) - install_dir = dist.get_egg_cache_dir() - cmd = easy_install( - tmp_dist, args=["x"], install_dir=install_dir, - exclude_scripts=True, - always_copy=False, build_directory=None, editable=False, - upgrade=False, multi_version=True, no_report=True, user=False - ) - cmd.ensure_finalized() - return cmd.easy_install(req) - - -def fetch_build_egg(dist, req): +def fetch_build_egg(dist, req): # noqa: C901 # is too complex (16) # FIXME """Fetch an egg needed for building. Use pip/wheel to fetch/build a wheel.""" - # Check pip is available. - try: - pkg_resources.get_distribution('pip') - except pkg_resources.DistributionNotFound: - dist.announce( - 'WARNING: The pip package is not available, falling back ' - 'to EasyInstall for handling setup_requires/test_requires; ' - 'this is deprecated and will be removed in a future version.', - log.WARN - ) - return _legacy_fetch_build_egg(dist, req) - # Warn if wheel is not. 
+ # Warn if wheel is not available try: pkg_resources.get_distribution('wheel') except pkg_resources.DistributionNotFound: @@ -80,20 +36,17 @@ def fetch_build_egg(dist, req): if 'allow_hosts' in opts: raise DistutilsError('the `allow-hosts` option is not supported ' 'when using pip to install requirements.') - if 'PIP_QUIET' in os.environ or 'PIP_VERBOSE' in os.environ: - quiet = False - else: - quiet = True + quiet = 'PIP_QUIET' not in os.environ and 'PIP_VERBOSE' not in os.environ if 'PIP_INDEX_URL' in os.environ: index_url = None elif 'index_url' in opts: index_url = opts['index_url'][1] else: index_url = None - if 'find_links' in opts: - find_links = _fixup_find_links(opts['find_links'][1])[:] - else: - find_links = [] + find_links = ( + _fixup_find_links(opts['find_links'][1])[:] if 'find_links' in opts + else [] + ) if dist.dependency_links: find_links.extend(dist.dependency_links) eggs_dir = os.path.realpath(dist.get_egg_cache_dir()) @@ -112,16 +65,12 @@ def fetch_build_egg(dist, req): cmd.append('--quiet') if index_url is not None: cmd.extend(('--index-url', index_url)) - if find_links is not None: - for link in find_links: - cmd.extend(('--find-links', link)) + for link in find_links or []: + cmd.extend(('--find-links', link)) # If requirement is a PEP 508 direct URL, directly pass # the URL to pip, as `req @ url` does not work on the # command line. - if req.url: - cmd.append(req.url) - else: - cmd.append(str(req)) + cmd.append(req.url or str(req)) try: subprocess.check_call(cmd) except subprocess.CalledProcessError as e: diff --git a/setuptools/lib2to3_ex.py b/setuptools/lib2to3_ex.py deleted file mode 100644 index c176abf6..00000000 --- a/setuptools/lib2to3_ex.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -Customized Mixin2to3 support: - - - adds support for converting doctests -""" - -import warnings -from distutils.util import Mixin2to3 as _Mixin2to3 -from distutils import log -from lib2to3.refactor import RefactoringTool, get_fixers_from_package - -import setuptools -from ._deprecation_warning import SetuptoolsDeprecationWarning - - -class DistutilsRefactoringTool(RefactoringTool): - def log_error(self, msg, *args, **kw): - log.error(msg, *args) - - def log_message(self, msg, *args): - log.info(msg, *args) - - def log_debug(self, msg, *args): - log.debug(msg, *args) - - -class Mixin2to3(_Mixin2to3): - def run_2to3(self, files, doctests=False): - # See of the distribution option has been set, otherwise check the - # setuptools default. - if self.distribution.use_2to3 is not True: - return - if not files: - return - - warnings.warn( - "2to3 support is deprecated. 
If the project still " - "requires Python 2 support, please migrate to " - "a single-codebase solution or employ an " - "independent conversion process.", - SetuptoolsDeprecationWarning) - log.info("Fixing " + " ".join(files)) - self.__build_fixer_names() - self.__exclude_fixers() - if doctests: - if setuptools.run_2to3_on_doctests: - r = DistutilsRefactoringTool(self.fixer_names) - r.refactor(files, write=True, doctests_only=True) - else: - _Mixin2to3.run_2to3(self, files) - - def __build_fixer_names(self): - if self.fixer_names: - return - self.fixer_names = [] - for p in setuptools.lib2to3_fixer_packages: - self.fixer_names.extend(get_fixers_from_package(p)) - if self.distribution.use_2to3_fixers is not None: - for p in self.distribution.use_2to3_fixers: - self.fixer_names.extend(get_fixers_from_package(p)) - - def __exclude_fixers(self): - excluded_fixers = getattr(self, 'exclude_fixers', []) - if self.distribution.use_2to3_exclude_fixers is not None: - excluded_fixers.extend(self.distribution.use_2to3_exclude_fixers) - for fixer_name in excluded_fixers: - if fixer_name in self.fixer_names: - self.fixer_names.remove(fixer_name) diff --git a/setuptools/msvc.py b/setuptools/msvc.py index 1ead72b4..281ea1c2 100644 --- a/setuptools/msvc.py +++ b/setuptools/msvc.py @@ -24,11 +24,13 @@ from io import open from os import listdir, pathsep from os.path import join, isfile, isdir, dirname import sys +import contextlib import platform import itertools import subprocess import distutils.errors from setuptools.extern.packaging.version import LegacyVersion +from setuptools.extern.more_itertools import unique_everseen from .monkey import get_unpatched @@ -192,7 +194,9 @@ def _msvc14_find_vc2017(): join(root, "Microsoft Visual Studio", "Installer", "vswhere.exe"), "-latest", "-prerelease", + "-requiresAny", "-requires", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", + "-requires", "Microsoft.VisualStudio.Workload.WDExpress", "-property", "installationPath", "-products", "*", ]).decode(encoding="mbcs", errors="strict").strip() @@ -724,28 +728,23 @@ class SystemInfo: ms = self.ri.microsoft vckeys = (self.ri.vc, self.ri.vc_for_python, self.ri.vs) vs_vers = [] - for hkey in self.ri.HKEYS: - for key in vckeys: - try: - bkey = winreg.OpenKey(hkey, ms(key), 0, winreg.KEY_READ) - except (OSError, IOError): - continue - with bkey: - subkeys, values, _ = winreg.QueryInfoKey(bkey) - for i in range(values): - try: - ver = float(winreg.EnumValue(bkey, i)[0]) - if ver not in vs_vers: - vs_vers.append(ver) - except ValueError: - pass - for i in range(subkeys): - try: - ver = float(winreg.EnumKey(bkey, i)) - if ver not in vs_vers: - vs_vers.append(ver) - except ValueError: - pass + for hkey, key in itertools.product(self.ri.HKEYS, vckeys): + try: + bkey = winreg.OpenKey(hkey, ms(key), 0, winreg.KEY_READ) + except (OSError, IOError): + continue + with bkey: + subkeys, values, _ = winreg.QueryInfoKey(bkey) + for i in range(values): + with contextlib.suppress(ValueError): + ver = float(winreg.EnumValue(bkey, i)[0]) + if ver not in vs_vers: + vs_vers.append(ver) + for i in range(subkeys): + with contextlib.suppress(ValueError): + ver = float(winreg.EnumKey(bkey, i)) + if ver not in vs_vers: + vs_vers.append(ver) return sorted(vs_vers) def find_programdata_vs_vers(self): @@ -925,8 +924,8 @@ class SystemInfo: """ return self._use_last_dir_name(join(self.WindowsSdkDir, 'lib')) - @property - def WindowsSdkDir(self): + @property # noqa: C901 + def WindowsSdkDir(self): # noqa: C901 # is too complex (12) # FIXME """ 
Microsoft Windows SDK directory. @@ -1802,29 +1801,5 @@ class EnvironmentInfo: if not extant_paths: msg = "%s environment variable is empty" % name.upper() raise distutils.errors.DistutilsPlatformError(msg) - unique_paths = self._unique_everseen(extant_paths) + unique_paths = unique_everseen(extant_paths) return pathsep.join(unique_paths) - - # from Python docs - @staticmethod - def _unique_everseen(iterable, key=None): - """ - List unique elements, preserving order. - Remember all elements ever seen. - - _unique_everseen('AAAABBBCCDAABBB') --> A B C D - - _unique_everseen('ABBCcAD', str.lower) --> A B C D - """ - seen = set() - seen_add = seen.add - if key is None: - for element in itertools.filterfalse(seen.__contains__, iterable): - seen_add(element) - yield element - else: - for element in iterable: - k = key(element) - if k not in seen: - seen_add(k) - yield element diff --git a/setuptools/package_index.py b/setuptools/package_index.py index bc0ba7a6..c9254289 100644 --- a/setuptools/package_index.py +++ b/setuptools/package_index.py @@ -23,11 +23,12 @@ from pkg_resources import ( Environment, find_distributions, safe_name, safe_version, to_filename, Requirement, DEVELOP_DIST, EGG_DIST, ) -from setuptools import ssl_support from distutils import log from distutils.errors import DistutilsError from fnmatch import translate from setuptools.wheel import Wheel +from setuptools.extern.more_itertools import unique_everseen + EGG_FRAGMENT = re.compile(r'^egg=([-A-Za-z0-9_.+!]+)$') HREF = re.compile(r"""href\s*=\s*['"]?([^'"> ]+)""", re.I) @@ -172,25 +173,6 @@ def interpret_distro_name( ) -# From Python 2.7 docs -def unique_everseen(iterable, key=None): - "List unique elements, preserving order. Remember all elements ever seen." - # unique_everseen('AAAABBBCCDAABBB') --> A B C D - # unique_everseen('ABBCcAD', str.lower) --> A B C D - seen = set() - seen_add = seen.add - if key is None: - for element in itertools.filterfalse(seen.__contains__, iterable): - seen_add(element) - yield element - else: - for element in iterable: - k = key(element) - if k not in seen: - seen_add(k) - yield element - - def unique_values(func): """ Wrap a function returning an iterable such that the resulting iterable @@ -299,17 +281,10 @@ class PackageIndex(Environment): self.package_pages = {} self.allows = re.compile('|'.join(map(translate, hosts))).match self.to_scan = [] - use_ssl = ( - verify_ssl - and ssl_support.is_available - and (ca_bundle or ssl_support.find_ca_bundle()) - ) - if use_ssl: - self.opener = ssl_support.opener_for(ca_bundle) - else: - self.opener = urllib.request.urlopen + self.opener = urllib.request.urlopen - def process_url(self, url, retrieve=False): + # FIXME: 'PackageIndex.process_url' is too complex (14) + def process_url(self, url, retrieve=False): # noqa: C901 """Evaluate a URL as a possible download, and maybe retrieve it""" if url in self.scanned_urls and not retrieve: return @@ -417,49 +392,53 @@ class PackageIndex(Environment): dist.precedence = SOURCE_DIST self.add(dist) + def _scan(self, link): + # Process a URL to see if it's for a package page + NO_MATCH_SENTINEL = None, None + if not link.startswith(self.index_url): + return NO_MATCH_SENTINEL + + parts = list(map( + urllib.parse.unquote, link[len(self.index_url):].split('/') + )) + if len(parts) != 2 or '#' in parts[1]: + return NO_MATCH_SENTINEL + + # it's a package page, sanitize and index it + pkg = safe_name(parts[0]) + ver = safe_version(parts[1]) + self.package_pages.setdefault(pkg.lower(), {})[link] = True + return 
to_filename(pkg), to_filename(ver) + def process_index(self, url, page): """Process the contents of a PyPI page""" - def scan(link): - # Process a URL to see if it's for a package page - if link.startswith(self.index_url): - parts = list(map( - urllib.parse.unquote, link[len(self.index_url):].split('/') - )) - if len(parts) == 2 and '#' not in parts[1]: - # it's a package page, sanitize and index it - pkg = safe_name(parts[0]) - ver = safe_version(parts[1]) - self.package_pages.setdefault(pkg.lower(), {})[link] = True - return to_filename(pkg), to_filename(ver) - return None, None - # process an index page into the package-page index for match in HREF.finditer(page): try: - scan(urllib.parse.urljoin(url, htmldecode(match.group(1)))) + self._scan(urllib.parse.urljoin(url, htmldecode(match.group(1)))) except ValueError: pass - pkg, ver = scan(url) # ensure this page is in the page index - if pkg: - # process individual package page - for new_url in find_external_links(url, page): - # Process the found URL - base, frag = egg_info_for_url(new_url) - if base.endswith('.py') and not frag: - if ver: - new_url += '#egg=%s-%s' % (pkg, ver) - else: - self.need_version_info(url) - self.scan_url(new_url) - - return PYPI_MD5.sub( - lambda m: '<a href="%s#md5=%s">%s</a>' % m.group(1, 3, 2), page - ) - else: + pkg, ver = self._scan(url) # ensure this page is in the page index + if not pkg: return "" # no sense double-scanning non-package pages + # process individual package page + for new_url in find_external_links(url, page): + # Process the found URL + base, frag = egg_info_for_url(new_url) + if base.endswith('.py') and not frag: + if ver: + new_url += '#egg=%s-%s' % (pkg, ver) + else: + self.need_version_info(url) + self.scan_url(new_url) + + return PYPI_MD5.sub( + lambda m: '<a href="%s#md5=%s">%s</a>' % m.group(1, 3, 2), page + ) + def need_version_info(self, url): self.scan_all( "Page at %s links to .py file(s) without version info; an index " @@ -580,7 +559,7 @@ class PackageIndex(Environment): spec = parse_requirement_arg(spec) return getattr(self.fetch_distribution(spec, tmpdir), 'location', None) - def fetch_distribution( + def fetch_distribution( # noqa: C901 # is too complex (14) # FIXME self, requirement, tmpdir, force_scan=False, source=False, develop_ok=False, local_index=None): """Obtain a distribution suitable for fulfilling `requirement` @@ -751,7 +730,8 @@ class PackageIndex(Environment): def reporthook(self, url, filename, blocknum, blksize, size): pass # no-op - def open_url(self, url, warning=None): + # FIXME: + def open_url(self, url, warning=None): # noqa: C901 # is too complex (12) if url.startswith('file:'): return local_open(url) try: diff --git a/setuptools/sandbox.py b/setuptools/sandbox.py index 91b960d8..034fc80d 100644 --- a/setuptools/sandbox.py +++ b/setuptools/sandbox.py @@ -26,7 +26,10 @@ _open = open __all__ = [ - "AbstractSandbox", "DirectorySandbox", "SandboxViolation", "run_setup", + "AbstractSandbox", + "DirectorySandbox", + "SandboxViolation", + "run_setup", ] @@ -106,6 +109,7 @@ class UnpickleableException(Exception): except Exception: # get UnpickleableException inside the sandbox from setuptools.sandbox import UnpickleableException as cls + return cls.dump(cls, cls(repr(exc))) @@ -154,7 +158,8 @@ def save_modules(): sys.modules.update(saved) # remove any modules imported since del_modules = ( - mod_name for mod_name in sys.modules + mod_name + for mod_name in sys.modules if mod_name not in saved # exclude any encodings modules. 
See #285 and not mod_name.startswith('encodings.') @@ -265,7 +270,8 @@ class AbstractSandbox: def __init__(self): self._attrs = [ - name for name in dir(_os) + name + for name in dir(_os) if not name.startswith('_') and hasattr(self, name) ] @@ -320,9 +326,25 @@ class AbstractSandbox: _file = _mk_single_path_wrapper('file', _file) _open = _mk_single_path_wrapper('open', _open) for name in [ - "stat", "listdir", "chdir", "open", "chmod", "chown", "mkdir", - "remove", "unlink", "rmdir", "utime", "lchown", "chroot", "lstat", - "startfile", "mkfifo", "mknod", "pathconf", "access" + "stat", + "listdir", + "chdir", + "open", + "chmod", + "chown", + "mkdir", + "remove", + "unlink", + "rmdir", + "utime", + "lchown", + "chroot", + "lstat", + "startfile", + "mkfifo", + "mknod", + "pathconf", + "access", ]: if hasattr(_os, name): locals()[name] = _mk_single_path_wrapper(name) @@ -373,7 +395,7 @@ class AbstractSandbox: """Called for path pairs like rename, link, and symlink operations""" return ( self._remap_input(operation + '-from', src, *args, **kw), - self._remap_input(operation + '-to', dst, *args, **kw) + self._remap_input(operation + '-to', dst, *args, **kw), ) @@ -386,28 +408,38 @@ else: class DirectorySandbox(AbstractSandbox): """Restrict operations to a single subdirectory - pseudo-chroot""" - write_ops = dict.fromkeys([ - "open", "chmod", "chown", "mkdir", "remove", "unlink", "rmdir", - "utime", "lchown", "chroot", "mkfifo", "mknod", "tempnam", - ]) + write_ops = dict.fromkeys( + [ + "open", + "chmod", + "chown", + "mkdir", + "remove", + "unlink", + "rmdir", + "utime", + "lchown", + "chroot", + "mkfifo", + "mknod", + "tempnam", + ] + ) - _exception_patterns = [ - # Allow lib2to3 to attempt to save a pickled grammar object (#121) - r'.*lib2to3.*\.pickle$', - ] + _exception_patterns = [] "exempt writing to paths that match the pattern" def __init__(self, sandbox, exceptions=_EXCEPTIONS): self._sandbox = os.path.normcase(os.path.realpath(sandbox)) self._prefix = os.path.join(self._sandbox, '') self._exceptions = [ - os.path.normcase(os.path.realpath(path)) - for path in exceptions + os.path.normcase(os.path.realpath(path)) for path in exceptions ] AbstractSandbox.__init__(self) def _violation(self, operation, *args, **kw): from setuptools.sandbox import SandboxViolation + raise SandboxViolation(operation, args, kw) if _file: @@ -440,12 +472,10 @@ class DirectorySandbox(AbstractSandbox): def _exempted(self, filepath): start_matches = ( - filepath.startswith(exception) - for exception in self._exceptions + filepath.startswith(exception) for exception in self._exceptions ) pattern_matches = ( - re.match(pattern, filepath) - for pattern in self._exception_patterns + re.match(pattern, filepath) for pattern in self._exception_patterns ) candidates = itertools.chain(start_matches, pattern_matches) return any(candidates) @@ -470,16 +500,19 @@ class DirectorySandbox(AbstractSandbox): WRITE_FLAGS = functools.reduce( - operator.or_, [ - getattr(_os, a, 0) for a in - "O_WRONLY O_RDWR O_APPEND O_CREAT O_TRUNC O_TEMPORARY".split()] + operator.or_, + [ + getattr(_os, a, 0) + for a in "O_WRONLY O_RDWR O_APPEND O_CREAT O_TRUNC O_TEMPORARY".split() + ], ) class SandboxViolation(DistutilsError): """A setup script attempted to modify the filesystem outside the sandbox""" - tmpl = textwrap.dedent(""" + tmpl = textwrap.dedent( + """ SandboxViolation: {cmd}{args!r} {kwargs} The package setup script has attempted to modify files on your system @@ -489,7 +522,8 @@ class SandboxViolation(DistutilsError): support 
alternate installation locations even if you run its setup script by hand. Please inform the package's author and the EasyInstall maintainers to find out if a fix or workaround is available. - """).lstrip() + """ + ).lstrip() def __str__(self): cmd, args, kwargs = self.args diff --git a/setuptools/ssl_support.py b/setuptools/ssl_support.py deleted file mode 100644 index eac5e656..00000000 --- a/setuptools/ssl_support.py +++ /dev/null @@ -1,266 +0,0 @@ -import os -import socket -import atexit -import re -import functools -import urllib.request -import http.client - - -from pkg_resources import ResolutionError, ExtractionError - -try: - import ssl -except ImportError: - ssl = None - -__all__ = [ - 'VerifyingHTTPSHandler', 'find_ca_bundle', 'is_available', 'cert_paths', - 'opener_for' -] - -cert_paths = """ -/etc/pki/tls/certs/ca-bundle.crt -/etc/ssl/certs/ca-certificates.crt -/usr/share/ssl/certs/ca-bundle.crt -/usr/local/share/certs/ca-root.crt -/etc/ssl/cert.pem -/System/Library/OpenSSL/certs/cert.pem -/usr/local/share/certs/ca-root-nss.crt -/etc/ssl/ca-bundle.pem -""".strip().split() - -try: - HTTPSHandler = urllib.request.HTTPSHandler - HTTPSConnection = http.client.HTTPSConnection -except AttributeError: - HTTPSHandler = HTTPSConnection = object - -is_available = ssl is not None and object not in ( - HTTPSHandler, HTTPSConnection) - - -try: - from ssl import CertificateError, match_hostname -except ImportError: - try: - from backports.ssl_match_hostname import CertificateError - from backports.ssl_match_hostname import match_hostname - except ImportError: - CertificateError = None - match_hostname = None - -if not CertificateError: - - class CertificateError(ValueError): - pass - - -if not match_hostname: - - def _dnsname_match(dn, hostname, max_wildcards=1): - """Matching according to RFC 6125, section 6.4.3 - - https://tools.ietf.org/html/rfc6125#section-6.4.3 - """ - pats = [] - if not dn: - return False - - # Ported from python3-syntax: - # leftmost, *remainder = dn.split(r'.') - parts = dn.split(r'.') - leftmost = parts[0] - remainder = parts[1:] - - wildcards = leftmost.count('*') - if wildcards > max_wildcards: - # Issue #17980: avoid denials of service by refusing more - # than one wildcard per fragment. A survey of established - # policy among SSL implementations showed it to be a - # reasonable choice. - raise CertificateError( - "too many wildcards in certificate DNS name: " + repr(dn)) - - # speed up common case w/o wildcards - if not wildcards: - return dn.lower() == hostname.lower() - - # RFC 6125, section 6.4.3, subitem 1. - # The client SHOULD NOT attempt to match a - # presented identifier in which the wildcard - # character comprises a label other than the - # left-most label. - if leftmost == '*': - # When '*' is a fragment by itself, it matches a non-empty dotless - # fragment. - pats.append('[^.]+') - elif leftmost.startswith('xn--') or hostname.startswith('xn--'): - # RFC 6125, section 6.4.3, subitem 3. - # The client SHOULD NOT attempt to match a presented identifier - # where the wildcard character is embedded within an A-label or - # U-label of an internationalized domain name. - pats.append(re.escape(leftmost)) - else: - # Otherwise, '*' matches any dotless string, e.g. 
www* - pats.append(re.escape(leftmost).replace(r'\*', '[^.]*')) - - # add the remaining fragments, ignore any wildcards - for frag in remainder: - pats.append(re.escape(frag)) - - pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) - return pat.match(hostname) - - def match_hostname(cert, hostname): - """Verify that *cert* (in decoded format as returned by - SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125 - rules are followed, but IP addresses are not accepted for *hostname*. - - CertificateError is raised on failure. On success, the function - returns nothing. - """ - if not cert: - raise ValueError("empty or no certificate") - dnsnames = [] - san = cert.get('subjectAltName', ()) - for key, value in san: - if key == 'DNS': - if _dnsname_match(value, hostname): - return - dnsnames.append(value) - if not dnsnames: - # The subject is only checked when there is no dNSName entry - # in subjectAltName - for sub in cert.get('subject', ()): - for key, value in sub: - # XXX according to RFC 2818, the most specific Common Name - # must be used. - if key == 'commonName': - if _dnsname_match(value, hostname): - return - dnsnames.append(value) - if len(dnsnames) > 1: - raise CertificateError( - "hostname %r doesn't match either of %s" - % (hostname, ', '.join(map(repr, dnsnames)))) - elif len(dnsnames) == 1: - raise CertificateError( - "hostname %r doesn't match %r" - % (hostname, dnsnames[0])) - else: - raise CertificateError( - "no appropriate commonName or " - "subjectAltName fields were found") - - -class VerifyingHTTPSHandler(HTTPSHandler): - """Simple verifying handler: no auth, subclasses, timeouts, etc.""" - - def __init__(self, ca_bundle): - self.ca_bundle = ca_bundle - HTTPSHandler.__init__(self) - - def https_open(self, req): - return self.do_open( - lambda host, **kw: VerifyingHTTPSConn(host, self.ca_bundle, **kw), - req - ) - - -class VerifyingHTTPSConn(HTTPSConnection): - """Simple verifying connection: no auth, subclasses, timeouts, etc.""" - - def __init__(self, host, ca_bundle, **kw): - HTTPSConnection.__init__(self, host, **kw) - self.ca_bundle = ca_bundle - - def connect(self): - sock = socket.create_connection( - (self.host, self.port), getattr(self, 'source_address', None) - ) - - # Handle the socket if a (proxy) tunnel is present - if hasattr(self, '_tunnel') and getattr(self, '_tunnel_host', None): - self.sock = sock - self._tunnel() - # http://bugs.python.org/issue7776: Python>=3.4.1 and >=2.7.7 - # change self.host to mean the proxy server host when tunneling is - # being used. Adapt, since we are interested in the destination - # host for the match_hostname() comparison. - actual_host = self._tunnel_host - else: - actual_host = self.host - - if hasattr(ssl, 'create_default_context'): - ctx = ssl.create_default_context(cafile=self.ca_bundle) - self.sock = ctx.wrap_socket(sock, server_hostname=actual_host) - else: - # This is for python < 2.7.9 and < 3.4? 
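# Editor's note: the vendored match_hostname/VerifyingHTTPSConn machinery
# being deleted here predates ssl.create_default_context; a sketch of the
# modern equivalent, which verifies certificates and hostnames by default
# (network access and the example.com host are assumptions of this demo):
import socket
import ssl

context = ssl.create_default_context()  # system CAs, CERT_REQUIRED, host checks

with socket.create_connection(('example.com', 443)) as sock:
    with context.wrap_socket(sock, server_hostname='example.com') as tls:
        print(tls.version())  # e.g. 'TLSv1.3'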
- self.sock = ssl.wrap_socket( - sock, cert_reqs=ssl.CERT_REQUIRED, ca_certs=self.ca_bundle - ) - try: - match_hostname(self.sock.getpeercert(), actual_host) - except CertificateError: - self.sock.shutdown(socket.SHUT_RDWR) - self.sock.close() - raise - - -def opener_for(ca_bundle=None): - """Get a urlopen() replacement that uses ca_bundle for verification""" - return urllib.request.build_opener( - VerifyingHTTPSHandler(ca_bundle or find_ca_bundle()) - ).open - - -# from jaraco.functools -def once(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - if not hasattr(func, 'always_returns'): - func.always_returns = func(*args, **kwargs) - return func.always_returns - return wrapper - - -@once -def get_win_certfile(): - try: - import wincertstore - except ImportError: - return None - - class CertFile(wincertstore.CertFile): - def __init__(self): - super(CertFile, self).__init__() - atexit.register(self.close) - - def close(self): - try: - super(CertFile, self).close() - except OSError: - pass - - _wincerts = CertFile() - _wincerts.addstore('CA') - _wincerts.addstore('ROOT') - return _wincerts.name - - -def find_ca_bundle(): - """Return an existing CA bundle path, or None""" - extant_cert_paths = filter(os.path.isfile, cert_paths) - return ( - get_win_certfile() - or next(extant_cert_paths, None) - or _certifi_where() - ) - - -def _certifi_where(): - try: - return __import__('certifi').where() - except (ImportError, ResolutionError, ExtractionError): - pass diff --git a/setuptools/tests/__init__.py b/setuptools/tests/__init__.py index a7a2112f..564adf2b 100644 --- a/setuptools/tests/__init__.py +++ b/setuptools/tests/__init__.py @@ -3,11 +3,8 @@ import locale import pytest -__all__ = ['fail_on_ascii', 'ack_2to3'] +__all__ = ['fail_on_ascii'] is_ascii = locale.getpreferredencoding() == 'ANSI_X3.4-1968' fail_on_ascii = pytest.mark.xfail(is_ascii, reason="Test fails in this locale") - - -ack_2to3 = pytest.mark.filterwarnings('ignore:2to3 support is deprecated') diff --git a/setuptools/tests/environment.py b/setuptools/tests/environment.py index bd3119ef..c0274c33 100644 --- a/setuptools/tests/environment.py +++ b/setuptools/tests/environment.py @@ -29,7 +29,7 @@ def run_setup_py(cmd, pypath=None, path=None, if pypath is not None: env["PYTHONPATH"] = pypath - # overide the execution path if needed + # override the execution path if needed if path is not None: env["PATH"] = path if not env.get("PATH", ""): diff --git a/setuptools/tests/files.py b/setuptools/tests/files.py deleted file mode 100644 index 71194b9d..00000000 --- a/setuptools/tests/files.py +++ /dev/null @@ -1,38 +0,0 @@ -import os - - -def build_files(file_defs, prefix=""): - """ - Build a set of files/directories, as described by the - file_defs dictionary. - - Each key/value pair in the dictionary is interpreted as - a filename/contents - pair. If the contents value is a dictionary, a directory - is created, and the - dictionary interpreted as the files within it, recursively. 
- - For example: - - {"README.txt": "A README file", - "foo": { - "__init__.py": "", - "bar": { - "__init__.py": "", - }, - "baz.py": "# Some code", - } - } - """ - for name, contents in file_defs.items(): - full_name = os.path.join(prefix, name) - if isinstance(contents, dict): - os.makedirs(full_name, exist_ok=True) - build_files(contents, prefix=full_name) - else: - if isinstance(contents, bytes): - with open(full_name, 'wb') as f: - f.write(contents) - else: - with open(full_name, 'w') as f: - f.write(contents) diff --git a/setuptools/tests/fixtures.py b/setuptools/tests/fixtures.py index e8cb7f52..a5a172e0 100644 --- a/setuptools/tests/fixtures.py +++ b/setuptools/tests/fixtures.py @@ -1,3 +1,8 @@ +import contextlib +import sys +import shutil +import subprocess + import pytest from . import contexts @@ -21,3 +26,49 @@ def user_override(monkeypatch): def tmpdir_cwd(tmpdir): with tmpdir.as_cwd() as orig: yield orig + + +@pytest.fixture +def tmp_src(request, tmp_path): + """Make a copy of the source dir under `$tmp/src`. + + This fixture is useful whenever it's necessary to run `setup.py` + or `pip install` against the source directory when there's no + control over the number of simultaneous invocations. Such + concurrent runs create and delete directories with the same names + under the target directory and so they influence each other's runs + when they are not being executed sequentially. + """ + tmp_src_path = tmp_path / 'src' + shutil.copytree(request.config.rootdir, tmp_src_path) + return tmp_src_path + + +@pytest.fixture(autouse=True, scope="session") +def workaround_xdist_376(request): + """ + Workaround pytest-dev/pytest-xdist#376 + + ``pytest-xdist`` tends to inject '' into ``sys.path``, + which may break certain isolation expectations. + Remove the entry so the import + machinery behaves the same irrespective of xdist. + """ + if not request.config.pluginmanager.has_plugin('xdist'): + return + + with contextlib.suppress(ValueError): + sys.path.remove('') + + +@pytest.fixture +def sample_project(tmp_path): + """ + Clone the 'sampleproject' and return a path to it. + """ + cmd = ['git', 'clone', 'https://github.com/pypa/sampleproject'] + try: + subprocess.check_call(cmd, cwd=str(tmp_path)) + except Exception: + pytest.skip("Unable to clone sampleproject") + return tmp_path / 'sampleproject' diff --git a/setuptools/tests/requirements.txt b/setuptools/tests/requirements.txt index d0d07f70..b2d84a94 100644 --- a/setuptools/tests/requirements.txt +++ b/setuptools/tests/requirements.txt @@ -11,3 +11,4 @@ paver; python_version>="3.6" futures; python_version=="2.7" pip>=19.1 # For proper file:// URLs support. 
jaraco.envs +sphinx diff --git a/setuptools/tests/server.py b/setuptools/tests/server.py index 7e213230..6717c053 100644 --- a/setuptools/tests/server.py +++ b/setuptools/tests/server.py @@ -65,7 +65,7 @@ class MockServer(http.server.HTTPServer, threading.Thread): http.server.HTTPServer.__init__( self, server_address, RequestHandlerClass) threading.Thread.__init__(self) - self.setDaemon(True) + self.daemon = True self.requests = [] def run(self): diff --git a/setuptools/tests/test_bdist_deprecations.py b/setuptools/tests/test_bdist_deprecations.py index 704164aa..28482fd0 100644 --- a/setuptools/tests/test_bdist_deprecations.py +++ b/setuptools/tests/test_bdist_deprecations.py @@ -1,6 +1,7 @@ """develop tests """ import mock +import sys import pytest @@ -8,14 +9,17 @@ from setuptools.dist import Distribution from setuptools import SetuptoolsDeprecationWarning -@mock.patch("distutils.command.bdist_wininst.bdist_wininst") -def test_bdist_wininst_warning(distutils_cmd): - dist = Distribution(dict( - script_name='setup.py', - script_args=['bdist_wininst'], - name='foo', - py_modules=['hi'], - )) +@pytest.mark.skipif(sys.platform == 'win32', reason='non-Windows only') +@mock.patch('distutils.command.bdist_rpm.bdist_rpm') +def test_bdist_rpm_warning(distutils_cmd): + dist = Distribution( + dict( + script_name='setup.py', + script_args=['bdist_rpm'], + name='foo', + py_modules=['hi'], + ) + ) dist.parse_command_line() with pytest.warns(SetuptoolsDeprecationWarning): dist.run_commands() diff --git a/setuptools/tests/test_bdist_egg.py b/setuptools/tests/test_bdist_egg.py index 8760ea30..fb5b90b1 100644 --- a/setuptools/tests/test_bdist_egg.py +++ b/setuptools/tests/test_bdist_egg.py @@ -7,7 +7,6 @@ import zipfile import pytest from setuptools.dist import Distribution -from setuptools import SetuptoolsDeprecationWarning from . import contexts @@ -65,17 +64,3 @@ class Test: names = list(zi.filename for zi in zip.filelist) assert 'hi.pyc' in names assert 'hi.py' not in names - - def test_eggsecutable_warning(self, setup_context, user_override): - dist = Distribution(dict( - script_name='setup.py', - script_args=['bdist_egg'], - name='foo', - py_modules=['hi'], - entry_points={ - 'setuptools.installation': - ['eggsecutable = my_package.some_module:main_func']}, - )) - dist.parse_command_line() - with pytest.warns(SetuptoolsDeprecationWarning): - dist.run_commands() diff --git a/setuptools/tests/test_build_ext.py b/setuptools/tests/test_build_ext.py index 838fdb42..3177a2cd 100644 --- a/setuptools/tests/test_build_ext.py +++ b/setuptools/tests/test_build_ext.py @@ -1,16 +1,21 @@ +import os import sys import distutils.command.build_ext as orig from distutils.sysconfig import get_config_var +from jaraco import path + from setuptools.command.build_ext import build_ext, get_abi3_suffix from setuptools.dist import Distribution from setuptools.extension import Extension from . import environment -from .files import build_files from .textwrap import DALS +IS_PYPY = '__pypy__' in sys.builtin_module_names + + class TestBuildExt: def test_get_ext_filename(self): """ @@ -46,6 +51,38 @@ class TestBuildExt: else: assert 'abi3' in res + def test_ext_suffix_override(self): + """ + SETUPTOOLS_EXT_SUFFIX variable always overrides + default extension options. 
+ """ + dist = Distribution() + cmd = build_ext(dist) + cmd.ext_map['for_abi3'] = ext = Extension( + 'for_abi3', + ['s.c'], + # Override shouldn't affect abi3 modules + py_limited_api=True, + ) + # Mock value needed to pass tests + ext._links_to_dynamic = False + + if not IS_PYPY: + expect = cmd.get_ext_filename('for_abi3') + else: + # PyPy builds do not use ABI3 tag, so they will + # also get the overridden suffix. + expect = 'for_abi3.test-suffix' + + try: + os.environ['SETUPTOOLS_EXT_SUFFIX'] = '.test-suffix' + res = cmd.get_ext_filename('normal') + assert 'normal.test-suffix' == res + res = cmd.get_ext_filename('for_abi3') + assert expect == res + finally: + del os.environ['SETUPTOOLS_EXT_SUFFIX'] + def test_build_ext_config_handling(tmpdir_cwd): files = { @@ -103,10 +140,10 @@ def test_build_ext_config_handling(tmpdir_cwd): 'setup.cfg': DALS( """ [build] - build-base = foo_build + build_base = foo_build """), } - build_files(files) + path.build(files) code, output = environment.run_setup_py( cmd=['build'], data_stream=(0, 2), ) diff --git a/setuptools/tests/test_build_meta.py b/setuptools/tests/test_build_meta.py index 6d3a997e..0f4a1a73 100644 --- a/setuptools/tests/test_build_meta.py +++ b/setuptools/tests/test_build_meta.py @@ -3,15 +3,16 @@ import shutil import tarfile import importlib from concurrent import futures +import re import pytest +from jaraco import path -from .files import build_files from .textwrap import DALS class BuildBackendBase: - def __init__(self, cwd=None, env={}, backend_name='setuptools.build_meta'): + def __init__(self, cwd='.', env={}, backend_name='setuptools.build_meta'): self.cwd = cwd self.env = env self.backend_name = backend_name @@ -126,11 +127,11 @@ class TestBuildMetaBackend: backend_name = 'setuptools.build_meta' def get_build_backend(self): - return BuildBackend(cwd='.', backend_name=self.backend_name) + return BuildBackend(backend_name=self.backend_name) @pytest.fixture(params=defns) def build_backend(self, tmpdir, request): - build_files(request.param, prefix=str(tmpdir)) + path.build(request.param, prefix=str(tmpdir)) with tmpdir.as_cwd(): yield self.get_build_backend() @@ -166,11 +167,11 @@ class TestBuildMetaBackend: 'pyproject.toml': DALS(""" [build-system] requires = ["setuptools", "wheel"] - build-backend = "setuptools.build_meta + build-backend = "setuptools.build_meta" """), } - build_files(files) + path.build(files) dist_dir = os.path.abspath('preexisting-' + build_type) @@ -259,10 +260,10 @@ class TestBuildMetaBackend: 'pyproject.toml': DALS(""" [build-system] requires = ["setuptools", "wheel"] - build-backend = "setuptools.build_meta + build-backend = "setuptools.build_meta" """), } - build_files(files) + path.build(files) build_backend = self.get_build_backend() targz_path = build_backend.build_sdist("temp") with tarfile.open(os.path.join("temp", targz_path)) as tar: @@ -271,7 +272,7 @@ class TestBuildMetaBackend: def test_build_sdist_setup_py_exists(self, tmpdir_cwd): # If build_sdist is called from a script other than setup.py, # ensure setup.py is included - build_files(defns[0]) + path.build(defns[0]) build_backend = self.get_build_backend() targz_path = build_backend.build_sdist("temp") @@ -293,7 +294,7 @@ class TestBuildMetaBackend: """) } - build_files(files) + path.build(files) build_backend = self.get_build_backend() targz_path = build_backend.build_sdist("temp") @@ -315,7 +316,7 @@ class TestBuildMetaBackend: """) } - build_files(files) + path.build(files) build_backend = self.get_build_backend() 
build_backend.build_sdist("temp") @@ -335,9 +336,9 @@ class TestBuildMetaBackend: } def test_build_sdist_relative_path_import(self, tmpdir_cwd): - build_files(self._relative_path_import_files) + path.build(self._relative_path_import_files) build_backend = self.get_build_backend() - with pytest.raises(ImportError): + with pytest.raises(ImportError, match="^No module named 'hello'$"): build_backend.build_sdist("temp") @pytest.mark.parametrize('setup_literal, requirements', [ @@ -374,7 +375,7 @@ class TestBuildMetaBackend: """), } - build_files(files) + path.build(files) build_backend = self.get_build_backend() @@ -409,7 +410,7 @@ class TestBuildMetaBackend: """), } - build_files(files) + path.build(files) build_backend = self.get_build_backend() @@ -437,11 +438,21 @@ class TestBuildMetaBackend: } def test_sys_argv_passthrough(self, tmpdir_cwd): - build_files(self._sys_argv_0_passthrough) + path.build(self._sys_argv_0_passthrough) build_backend = self.get_build_backend() with pytest.raises(AssertionError): build_backend.build_sdist("temp") + @pytest.mark.parametrize('build_hook', ('build_sdist', 'build_wheel')) + def test_build_with_empty_setuppy(self, build_backend, build_hook): + files = {'setup.py': ''} + path.build(files) + + with pytest.raises( + ValueError, + match=re.escape('No distribution was found.')): + getattr(build_backend, build_hook)("temp") + class TestBuildMetaLegacyBackend(TestBuildMetaBackend): backend_name = 'setuptools.build_meta:__legacy__' @@ -449,13 +460,13 @@ class TestBuildMetaLegacyBackend(TestBuildMetaBackend): # build_meta_legacy-specific tests def test_build_sdist_relative_path_import(self, tmpdir_cwd): # This must fail in build_meta, but must pass in build_meta_legacy - build_files(self._relative_path_import_files) + path.build(self._relative_path_import_files) build_backend = self.get_build_backend() build_backend.build_sdist("temp") def test_sys_argv_passthrough(self, tmpdir_cwd): - build_files(self._sys_argv_0_passthrough) + path.build(self._sys_argv_0_passthrough) build_backend = self.get_build_backend() build_backend.build_sdist("temp") diff --git a/setuptools/tests/test_config.py b/setuptools/tests/test_config.py index 1dee1271..005742e4 100644 --- a/setuptools/tests/test_config.py +++ b/setuptools/tests/test_config.py @@ -1,3 +1,6 @@ +import types +import sys + import contextlib import configparser @@ -7,6 +10,7 @@ from distutils.errors import DistutilsOptionError, DistutilsFileError from mock import patch from setuptools.dist import Distribution, _Distribution from setuptools.config import ConfigHandler, read_configuration +from distutils.core import Command from .textwrap import DALS @@ -26,14 +30,11 @@ def make_package_dir(name, base_dir, ns=False): def fake_env( - tmpdir, setup_cfg, setup_py=None, - encoding='ascii', package_path='fake_package'): + tmpdir, setup_cfg, setup_py=None, encoding='ascii', package_path='fake_package' +): if setup_py is None: - setup_py = ( - 'from setuptools import setup\n' - 'setup()\n' - ) + setup_py = 'from setuptools import setup\n' 'setup()\n' tmpdir.join('setup.py').write(setup_py) config = tmpdir.join('setup.cfg') @@ -74,7 +75,6 @@ def test_parsers_implemented(): class TestConfigurationReader: - def test_basic(self, tmpdir): _, config = fake_env( tmpdir, @@ -83,7 +83,7 @@ class TestConfigurationReader: 'keywords = one, two\n' '\n' '[options]\n' - 'scripts = bin/a.py, bin/b.py\n' + 'scripts = bin/a.py, bin/b.py\n', ) config_dict = read_configuration('%s' % config) assert config_dict['metadata']['version'] == '10.1.1' 
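Aside (editor's illustration, not part of the patch): the test_config.py hunks above and below exercise setuptools.config.read_configuration. A minimal, self-contained sketch of that API, mirroring the assertions in these tests; the temporary directory and the exact setup.cfg contents are illustrative assumptions:

    import pathlib
    import tempfile

    from setuptools.config import read_configuration

    with tempfile.TemporaryDirectory() as tmp:
        # A setup.cfg shaped like the one fake_env() writes in these tests.
        cfg = pathlib.Path(tmp, 'setup.cfg')
        cfg.write_text(
            '[metadata]\n'
            'version = 10.1.1\n'
            'keywords = one, two\n'
            '\n'
            '[options]\n'
            'scripts = bin/a.py, bin/b.py\n'
        )
        config_dict = read_configuration(str(cfg))
        assert config_dict['metadata']['version'] == '10.1.1'
        assert config_dict['metadata']['keywords'] == ['one', 'two']
        assert config_dict['options']['scripts'] == ['bin/a.py', 'bin/b.py']

As test_ignore_errors below shows, passing ignore_option_errors=True makes unresolvable directives such as `version = attr: none.VERSION` be skipped rather than raised.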
@@ -97,15 +97,12 @@ class TestConfigurationReader: def test_ignore_errors(self, tmpdir): _, config = fake_env( tmpdir, - '[metadata]\n' - 'version = attr: none.VERSION\n' - 'keywords = one, two\n' + '[metadata]\n' 'version = attr: none.VERSION\n' 'keywords = one, two\n', ) with pytest.raises(ImportError): read_configuration('%s' % config) - config_dict = read_configuration( - '%s' % config, ignore_option_errors=True) + config_dict = read_configuration('%s' % config, ignore_option_errors=True) assert config_dict['metadata']['keywords'] == ['one', 'two'] assert 'version' not in config_dict['metadata'] @@ -114,7 +111,6 @@ class TestConfigurationReader: class TestMetadata: - def test_basic(self, tmpdir): fake_env( @@ -129,7 +125,7 @@ class TestMetadata: 'provides = package, package.sub\n' 'license = otherlic\n' 'download_url = http://test.test.com/test/\n' - 'maintainer_email = test@test.com\n' + 'maintainer_email = test@test.com\n', ) tmpdir.join('README').write('readme contents\nline2') @@ -156,12 +152,14 @@ class TestMetadata: def test_license_cfg(self, tmpdir): fake_env( tmpdir, - DALS(""" + DALS( + """ [metadata] name=foo version=0.0.1 license=Apache 2.0 - """) + """ + ), ) with get_dist(tmpdir) as dist: @@ -175,9 +173,7 @@ class TestMetadata: fake_env( tmpdir, - '[metadata]\n' - 'long_description = file: README.rst, CHANGES.rst\n' - '\n' + '[metadata]\n' 'long_description = file: README.rst, CHANGES.rst\n' '\n', ) tmpdir.join('README.rst').write('readme contents\nline2') @@ -185,17 +181,12 @@ class TestMetadata: with get_dist(tmpdir) as dist: assert dist.metadata.long_description == ( - 'readme contents\nline2\n' - 'changelog contents\nand stuff' + 'readme contents\nline2\n' 'changelog contents\nand stuff' ) def test_file_sandboxed(self, tmpdir): - fake_env( - tmpdir, - '[metadata]\n' - 'long_description = file: ../../README\n' - ) + fake_env(tmpdir, '[metadata]\n' 'long_description = file: ../../README\n') with get_dist(tmpdir, parse=False) as dist: with pytest.raises(DistutilsOptionError): @@ -206,13 +197,13 @@ class TestMetadata: fake_env( tmpdir, '[metadata]\n' - 'author-email = test@test.com\n' - 'home-page = http://test.test.com/test/\n' + 'author_email = test@test.com\n' + 'home_page = http://test.test.com/test/\n' 'summary = Short summary\n' 'platform = a, b\n' 'classifier =\n' ' Framework :: Django\n' - ' Programming Language :: Python :: 3.5\n' + ' Programming Language :: Python :: 3.5\n', ) with get_dist(tmpdir) as dist: @@ -237,7 +228,7 @@ class TestMetadata: ' two\n' 'classifiers =\n' ' Framework :: Django\n' - ' Programming Language :: Python :: 3.5\n' + ' Programming Language :: Python :: 3.5\n', ) with get_dist(tmpdir) as dist: metadata = dist.metadata @@ -254,7 +245,7 @@ class TestMetadata: '[metadata]\n' 'project_urls =\n' ' Link One = https://example.com/one/\n' - ' Link Two = https://example.com/two/\n' + ' Link Two = https://example.com/two/\n', ) with get_dist(tmpdir) as dist: metadata = dist.metadata @@ -266,9 +257,7 @@ class TestMetadata: def test_version(self, tmpdir): package_dir, config = fake_env( - tmpdir, - '[metadata]\n' - 'version = attr: fake_package.VERSION\n' + tmpdir, '[metadata]\n' 'version = attr: fake_package.VERSION\n' ) sub_a = package_dir.mkdir('subpkg_a') @@ -278,37 +267,28 @@ class TestMetadata: sub_b = package_dir.mkdir('subpkg_b') sub_b.join('__init__.py').write('') sub_b.join('mod.py').write( - 'import third_party_module\n' - 'VERSION = (2016, 11, 26)' + 'import third_party_module\n' 'VERSION = (2016, 11, 26)' ) with get_dist(tmpdir) as 
dist: assert dist.metadata.version == '1.2.3' - config.write( - '[metadata]\n' - 'version = attr: fake_package.get_version\n' - ) + config.write('[metadata]\n' 'version = attr: fake_package.get_version\n') with get_dist(tmpdir) as dist: assert dist.metadata.version == '3.4.5.dev' - config.write( - '[metadata]\n' - 'version = attr: fake_package.VERSION_MAJOR\n' - ) + config.write('[metadata]\n' 'version = attr: fake_package.VERSION_MAJOR\n') with get_dist(tmpdir) as dist: assert dist.metadata.version == '1' config.write( - '[metadata]\n' - 'version = attr: fake_package.subpkg_a.mod.VERSION\n' + '[metadata]\n' 'version = attr: fake_package.subpkg_a.mod.VERSION\n' ) with get_dist(tmpdir) as dist: assert dist.metadata.version == '2016.11.26' config.write( - '[metadata]\n' - 'version = attr: fake_package.subpkg_b.mod.VERSION\n' + '[metadata]\n' 'version = attr: fake_package.subpkg_b.mod.VERSION\n' ) with get_dist(tmpdir) as dist: assert dist.metadata.version == '2016.11.26' @@ -316,9 +296,7 @@ class TestMetadata: def test_version_file(self, tmpdir): _, config = fake_env( - tmpdir, - '[metadata]\n' - 'version = file: fake_package/version.txt\n' + tmpdir, '[metadata]\n' 'version = file: fake_package/version.txt\n' ) tmpdir.join('fake_package', 'version.txt').write('1.2.3\n') @@ -339,7 +317,7 @@ class TestMetadata: '[options]\n' 'package_dir =\n' ' = src\n', - package_path='src/fake_package_simple' + package_path='src/fake_package_simple', ) with get_dist(tmpdir) as dist: @@ -354,7 +332,7 @@ class TestMetadata: '[options]\n' 'package_dir =\n' ' fake_package_rename = fake_dir\n', - package_path='fake_dir' + package_path='fake_dir', ) with get_dist(tmpdir) as dist: @@ -369,7 +347,7 @@ class TestMetadata: '[options]\n' 'package_dir =\n' ' fake_package_complex = src/fake_dir\n', - package_path='src/fake_dir' + package_path='src/fake_dir', ) with get_dist(tmpdir) as dist: @@ -377,39 +355,28 @@ class TestMetadata: def test_unknown_meta_item(self, tmpdir): - fake_env( - tmpdir, - '[metadata]\n' - 'name = fake_name\n' - 'unknown = some\n' - ) + fake_env(tmpdir, '[metadata]\n' 'name = fake_name\n' 'unknown = some\n') with get_dist(tmpdir, parse=False) as dist: dist.parse_config_files() # Skip unknown. def test_usupported_section(self, tmpdir): - fake_env( - tmpdir, - '[metadata.some]\n' - 'key = val\n' - ) + fake_env(tmpdir, '[metadata.some]\n' 'key = val\n') with get_dist(tmpdir, parse=False) as dist: with pytest.raises(DistutilsOptionError): dist.parse_config_files() def test_classifiers(self, tmpdir): - expected = set([ - 'Framework :: Django', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.5', - ]) + expected = set( + [ + 'Framework :: Django', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.5', + ] + ) # From file. 
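# [Aside, not part of the patch: the `file:` directive used below loads
# the named file, resolved relative to (and sandboxed under) the
# setup.cfg directory per test_file_sandboxed, and its lines become the
# classifiers list.]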
- _, config = fake_env( - tmpdir, - '[metadata]\n' - 'classifiers = file: classifiers\n' - ) + _, config = fake_env(tmpdir, '[metadata]\n' 'classifiers = file: classifiers\n') tmpdir.join('classifiers').write( 'Framework :: Django\n' @@ -437,7 +404,7 @@ class TestMetadata: '[metadata]\n' 'version = 10.1.1\n' 'description = Some description\n' - 'requires = some, requirement\n' + 'requires = some, requirement\n', ) with pytest.deprecated_call(): @@ -449,41 +416,26 @@ class TestMetadata: assert metadata.requires == ['some', 'requirement'] def test_interpolation(self, tmpdir): - fake_env( - tmpdir, - '[metadata]\n' - 'description = %(message)s\n' - ) + fake_env(tmpdir, '[metadata]\n' 'description = %(message)s\n') with pytest.raises(configparser.InterpolationMissingOptionError): with get_dist(tmpdir): pass def test_non_ascii_1(self, tmpdir): - fake_env( - tmpdir, - '[metadata]\n' - 'description = éàïôñ\n', - encoding='utf-8' - ) + fake_env(tmpdir, '[metadata]\n' 'description = éàïôñ\n', encoding='utf-8') with get_dist(tmpdir): pass def test_non_ascii_3(self, tmpdir): - fake_env( - tmpdir, - '\n' - '# -*- coding: invalid\n' - ) + fake_env(tmpdir, '\n' '# -*- coding: invalid\n') with get_dist(tmpdir): pass def test_non_ascii_4(self, tmpdir): fake_env( tmpdir, - '# -*- coding: utf-8\n' - '[metadata]\n' - 'description = éàïôñ\n', - encoding='utf-8' + '# -*- coding: utf-8\n' '[metadata]\n' 'description = éàïôñ\n', + encoding='utf-8', ) with get_dist(tmpdir) as dist: assert dist.metadata.description == 'éàïôñ' @@ -497,29 +449,63 @@ class TestMetadata: '# vim: set fileencoding=iso-8859-15 :\n' '[metadata]\n' 'description = éàïôñ\n', - encoding='iso-8859-15' + encoding='iso-8859-15', ) with pytest.raises(UnicodeDecodeError): with get_dist(tmpdir): pass + def test_warn_dash_deprecation(self, tmpdir): + # warn_dash_deprecation() is a method in setuptools.dist + # remove this test and the method when no longer needed + fake_env( + tmpdir, + '[metadata]\n' + 'author-email = test@test.com\n' + 'maintainer_email = foo@foo.com\n', + ) + msg = ( + "Usage of dash-separated 'author-email' will not be supported " + "in future versions. " + "Please use the underscore name 'author_email' instead" + ) + with pytest.warns(UserWarning, match=msg): + with get_dist(tmpdir) as dist: + metadata = dist.metadata + + assert metadata.author_email == 'test@test.com' + assert metadata.maintainer_email == 'foo@foo.com' -class TestOptions: + def test_make_option_lowercase(self, tmpdir): + # remove this test and the method make_option_lowercase() in setuptools.dist + # when no longer needed + fake_env( + tmpdir, '[metadata]\n' 'Name = foo\n' 'description = Some description\n' + ) + msg = ( + "Usage of uppercase key 'Name' in 'metadata' will be deprecated in " + "future versions. 
" + "Please use lowercase 'name' instead" + ) + with pytest.warns(UserWarning, match=msg): + with get_dist(tmpdir) as dist: + metadata = dist.metadata + + assert metadata.name == 'foo' + assert metadata.description == 'Some description' + +class TestOptions: def test_basic(self, tmpdir): fake_env( tmpdir, '[options]\n' 'zip_safe = True\n' - 'use_2to3 = 1\n' 'include_package_data = yes\n' 'package_dir = b=c, =src\n' 'packages = pack_a, pack_b.subpack\n' 'namespace_packages = pack1, pack2\n' - 'use_2to3_fixers = your.fixers, or.here\n' - 'use_2to3_exclude_fixers = one.here, two.there\n' - 'convert_2to3_doctests = src/tests/one.txt, src/two.txt\n' 'scripts = bin/one.py, bin/two.py\n' 'eager_resources = bin/one.py, bin/two.py\n' 'install_requires = docutils>=0.3; pack ==1.1, ==1.3; hey\n' @@ -528,34 +514,24 @@ class TestOptions: 'dependency_links = http://some.com/here/1, ' 'http://some.com/there/2\n' 'python_requires = >=1.0, !=2.8\n' - 'py_modules = module1, module2\n' + 'py_modules = module1, module2\n', ) with get_dist(tmpdir) as dist: assert dist.zip_safe - assert dist.use_2to3 assert dist.include_package_data assert dist.package_dir == {'': 'src', 'b': 'c'} assert dist.packages == ['pack_a', 'pack_b.subpack'] assert dist.namespace_packages == ['pack1', 'pack2'] - assert dist.use_2to3_fixers == ['your.fixers', 'or.here'] - assert dist.use_2to3_exclude_fixers == ['one.here', 'two.there'] - assert dist.convert_2to3_doctests == ([ - 'src/tests/one.txt', 'src/two.txt']) assert dist.scripts == ['bin/one.py', 'bin/two.py'] - assert dist.dependency_links == ([ - 'http://some.com/here/1', - 'http://some.com/there/2' - ]) - assert dist.install_requires == ([ - 'docutils>=0.3', - 'pack==1.1,==1.3', - 'hey' - ]) - assert dist.setup_requires == ([ - 'docutils>=0.3', - 'spack ==1.1, ==1.3', - 'there' - ]) + assert dist.dependency_links == ( + ['http://some.com/here/1', 'http://some.com/there/2'] + ) + assert dist.install_requires == ( + ['docutils>=0.3', 'pack==1.1,==1.3', 'hey'] + ) + assert dist.setup_requires == ( + ['docutils>=0.3', 'spack ==1.1, ==1.3', 'there'] + ) assert dist.tests_require == ['mock==0.7.2', 'pytest'] assert dist.python_requires == '>=1.0, !=2.8' assert dist.py_modules == ['module1', 'module2'] @@ -573,15 +549,6 @@ class TestOptions: 'namespace_packages = \n' ' pack1\n' ' pack2\n' - 'use_2to3_fixers = \n' - ' your.fixers\n' - ' or.here\n' - 'use_2to3_exclude_fixers = \n' - ' one.here\n' - ' two.there\n' - 'convert_2to3_doctests = \n' - ' src/tests/one.txt\n' - ' src/two.txt\n' 'scripts = \n' ' bin/one.py\n' ' bin/two.py\n' @@ -601,39 +568,26 @@ class TestOptions: ' there\n' 'dependency_links = \n' ' http://some.com/here/1\n' - ' http://some.com/there/2\n' + ' http://some.com/there/2\n', ) with get_dist(tmpdir) as dist: assert dist.package_dir == {'': 'src', 'b': 'c'} assert dist.packages == ['pack_a', 'pack_b.subpack'] assert dist.namespace_packages == ['pack1', 'pack2'] - assert dist.use_2to3_fixers == ['your.fixers', 'or.here'] - assert dist.use_2to3_exclude_fixers == ['one.here', 'two.there'] - assert dist.convert_2to3_doctests == ( - ['src/tests/one.txt', 'src/two.txt']) assert dist.scripts == ['bin/one.py', 'bin/two.py'] - assert dist.dependency_links == ([ - 'http://some.com/here/1', - 'http://some.com/there/2' - ]) - assert dist.install_requires == ([ - 'docutils>=0.3', - 'pack==1.1,==1.3', - 'hey' - ]) - assert dist.setup_requires == ([ - 'docutils>=0.3', - 'spack ==1.1, ==1.3', - 'there' - ]) + assert dist.dependency_links == ( + ['http://some.com/here/1', 
'http://some.com/there/2'] + ) + assert dist.install_requires == ( + ['docutils>=0.3', 'pack==1.1,==1.3', 'hey'] + ) + assert dist.setup_requires == ( + ['docutils>=0.3', 'spack ==1.1, ==1.3', 'there'] + ) assert dist.tests_require == ['mock==0.7.2', 'pytest'] def test_package_dir_fail(self, tmpdir): - fake_env( - tmpdir, - '[options]\n' - 'package_dir = a b\n' - ) + fake_env(tmpdir, '[options]\n' 'package_dir = a b\n') with get_dist(tmpdir, parse=False) as dist: with pytest.raises(DistutilsOptionError): dist.parse_config_files() @@ -647,7 +601,7 @@ class TestOptions: '\n' '[options.exclude_package_data]\n' '* = fake1.txt, fake2.txt\n' - 'hello = *.dat\n' + 'hello = *.dat\n', ) with get_dist(tmpdir) as dist: @@ -661,29 +615,21 @@ class TestOptions: } def test_packages(self, tmpdir): - fake_env( - tmpdir, - '[options]\n' - 'packages = find:\n' - ) + fake_env(tmpdir, '[options]\n' 'packages = find:\n') with get_dist(tmpdir) as dist: assert dist.packages == ['fake_package'] def test_find_directive(self, tmpdir): - dir_package, config = fake_env( - tmpdir, - '[options]\n' - 'packages = find:\n' - ) + dir_package, config = fake_env(tmpdir, '[options]\n' 'packages = find:\n') dir_sub_one, _ = make_package_dir('sub_one', dir_package) dir_sub_two, _ = make_package_dir('sub_two', dir_package) with get_dist(tmpdir) as dist: - assert set(dist.packages) == set([ - 'fake_package', 'fake_package.sub_two', 'fake_package.sub_one' - ]) + assert set(dist.packages) == set( + ['fake_package', 'fake_package.sub_two', 'fake_package.sub_one'] + ) config.write( '[options]\n' @@ -707,14 +653,11 @@ class TestOptions: ' fake_package.sub_one\n' ) with get_dist(tmpdir) as dist: - assert set(dist.packages) == set( - ['fake_package', 'fake_package.sub_two']) + assert set(dist.packages) == set(['fake_package', 'fake_package.sub_two']) def test_find_namespace_directive(self, tmpdir): dir_package, config = fake_env( - tmpdir, - '[options]\n' - 'packages = find_namespace:\n' + tmpdir, '[options]\n' 'packages = find_namespace:\n' ) dir_sub_one, _ = make_package_dir('sub_one', dir_package) @@ -722,7 +665,9 @@ class TestOptions: with get_dist(tmpdir) as dist: assert set(dist.packages) == { - 'fake_package', 'fake_package.sub_two', 'fake_package.sub_one' + 'fake_package', + 'fake_package.sub_two', + 'fake_package.sub_one', } config.write( @@ -747,9 +692,7 @@ class TestOptions: ' fake_package.sub_one\n' ) with get_dist(tmpdir) as dist: - assert set(dist.packages) == { - 'fake_package', 'fake_package.sub_two' - } + assert set(dist.packages) == {'fake_package', 'fake_package.sub_two'} def test_extras_require(self, tmpdir): fake_env( @@ -758,23 +701,29 @@ class TestOptions: 'pdf = ReportLab>=1.2; RXP\n' 'rest = \n' ' docutils>=0.3\n' - ' pack ==1.1, ==1.3\n' + ' pack ==1.1, ==1.3\n', ) with get_dist(tmpdir) as dist: assert dist.extras_require == { 'pdf': ['ReportLab>=1.2', 'RXP'], - 'rest': ['docutils>=0.3', 'pack==1.1,==1.3'] + 'rest': ['docutils>=0.3', 'pack==1.1,==1.3'], } assert dist.metadata.provides_extras == set(['pdf', 'rest']) + def test_dash_preserved_extras_require(self, tmpdir): + fake_env(tmpdir, '[options.extras_require]\n' 'foo-a = foo\n' 'foo_b = test\n') + + with get_dist(tmpdir) as dist: + assert dist.extras_require == {'foo-a': ['foo'], 'foo_b': ['test']} + def test_entry_points(self, tmpdir): _, config = fake_env( tmpdir, '[options.entry_points]\n' 'group1 = point1 = pack.module:func, ' '.point2 = pack.module2:func_rest [rest]\n' - 'group2 = point3 = pack.module:func2\n' + 'group2 = point3 = 
pack.module:func2\n', ) with get_dist(tmpdir) as dist: @@ -783,7 +732,7 @@ class TestOptions: 'point1 = pack.module:func', '.point2 = pack.module2:func_rest [rest]', ], - 'group2': ['point3 = pack.module:func2'] + 'group2': ['point3 = pack.module:func2'], } expected = ( @@ -794,14 +743,29 @@ class TestOptions: tmpdir.join('entry_points').write(expected) # From file. - config.write( - '[options]\n' - 'entry_points = file: entry_points\n' - ) + config.write('[options]\n' 'entry_points = file: entry_points\n') with get_dist(tmpdir) as dist: assert dist.entry_points == expected + def test_case_sensitive_entry_points(self, tmpdir): + _, config = fake_env( + tmpdir, + '[options.entry_points]\n' + 'GROUP1 = point1 = pack.module:func, ' + '.point2 = pack.module2:func_rest [rest]\n' + 'group2 = point3 = pack.module:func2\n', + ) + + with get_dist(tmpdir) as dist: + assert dist.entry_points == { + 'GROUP1': [ + 'point1 = pack.module:func', + '.point2 = pack.module2:func_rest [rest]', + ], + 'group2': ['point3 = pack.module:func2'], + } + def test_data_files(self, tmpdir): fake_env( tmpdir, @@ -809,7 +773,7 @@ class TestOptions: 'cfg =\n' ' a/b.conf\n' ' c/d.conf\n' - 'data = e/f.dat, g/h.dat\n' + 'data = e/f.dat, g/h.dat\n', ) with get_dist(tmpdir) as dist: @@ -819,13 +783,50 @@ class TestOptions: ] assert sorted(dist.data_files) == sorted(expected) + def test_data_files_globby(self, tmpdir): + fake_env( + tmpdir, + '[options.data_files]\n' + 'cfg =\n' + ' a/b.conf\n' + ' c/d.conf\n' + 'data = *.dat\n' + 'icons = \n' + ' *.ico\n' + 'audio = \n' + ' *.wav\n' + ' sounds.db\n' + ) + + # Create dummy files for glob()'s sake: + tmpdir.join('a.dat').write('') + tmpdir.join('b.dat').write('') + tmpdir.join('c.dat').write('') + tmpdir.join('a.ico').write('') + tmpdir.join('b.ico').write('') + tmpdir.join('c.ico').write('') + tmpdir.join('beep.wav').write('') + tmpdir.join('boop.wav').write('') + tmpdir.join('sounds.db').write('') + + with get_dist(tmpdir) as dist: + expected = [ + ('cfg', ['a/b.conf', 'c/d.conf']), + ('data', ['a.dat', 'b.dat', 'c.dat']), + ('icons', ['a.ico', 'b.ico', 'c.ico']), + ('audio', ['beep.wav', 'boop.wav', 'sounds.db']), + ] + assert sorted(dist.data_files) == sorted(expected) + def test_python_requires_simple(self, tmpdir): fake_env( tmpdir, - DALS(""" + DALS( + """ [options] python_requires=>=2.7 - """), + """ + ), ) with get_dist(tmpdir) as dist: dist.parse_config_files() @@ -833,10 +834,12 @@ class TestOptions: def test_python_requires_compound(self, tmpdir): fake_env( tmpdir, - DALS(""" + DALS( + """ [options] python_requires=>=2.7,!=3.0.* - """), + """ + ), ) with get_dist(tmpdir) as dist: dist.parse_config_files() @@ -844,15 +847,35 @@ class TestOptions: def test_python_requires_invalid(self, tmpdir): fake_env( tmpdir, - DALS(""" + DALS( + """ [options] python_requires=invalid - """), + """ + ), ) with pytest.raises(Exception): with get_dist(tmpdir) as dist: dist.parse_config_files() + def test_cmdclass(self, tmpdir): + class CustomCmd(Command): + pass + + m = types.ModuleType('custom_build', 'test package') + + m.__dict__['CustomCmd'] = CustomCmd + + sys.modules['custom_build'] = m + + fake_env( + tmpdir, + '[options]\n' 'cmdclass =\n' ' customcmd = custom_build.CustomCmd\n', + ) + + with get_dist(tmpdir) as dist: + assert dist.cmdclass == {'customcmd': CustomCmd} + saved_dist_init = _Distribution.__init__ @@ -871,24 +894,23 @@ class TestExternalSetters: def _fake_distribution_init(self, dist, attrs): saved_dist_init(dist, attrs) # see self._DISTUTUILS_UNSUPPORTED_METADATA 
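# [Aside, not part of the patch: _fake_distribution_init first runs the
# saved parent __init__, then presets metadata attributes;
# test_external_setters below verifies that attrs passed to
# Distribution() do not overwrite those preset values.]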
- setattr(dist.metadata, 'long_description_content_type', - 'text/something') + setattr(dist.metadata, 'long_description_content_type', 'text/something') # Test overwrite setup() args - setattr(dist.metadata, 'project_urls', { - 'Link One': 'https://example.com/one/', - 'Link Two': 'https://example.com/two/', - }) + setattr( + dist.metadata, + 'project_urls', + { + 'Link One': 'https://example.com/one/', + 'Link Two': 'https://example.com/two/', + }, + ) return None @patch.object(_Distribution, '__init__', autospec=True) def test_external_setters(self, mock_parent_init, tmpdir): mock_parent_init.side_effect = self._fake_distribution_init - dist = Distribution(attrs={ - 'project_urls': { - 'will_be': 'ignored' - } - }) + dist = Distribution(attrs={'project_urls': {'will_be': 'ignored'}}) assert dist.metadata.long_description_content_type == 'text/something' assert dist.metadata.project_urls == { diff --git a/setuptools/tests/test_develop.py b/setuptools/tests/test_develop.py index 2766da2f..70c5794c 100644 --- a/setuptools/tests/test_develop.py +++ b/setuptools/tests/test_develop.py @@ -2,11 +2,11 @@ """ import os -import site import sys -import io import subprocess import platform +import pathlib +import textwrap from setuptools.command import test @@ -14,7 +14,6 @@ import pytest from setuptools.command.develop import develop from setuptools.dist import Distribution -from setuptools.tests import ack_2to3 from . import contexts from . import namespaces @@ -23,7 +22,6 @@ from setuptools import setup setup(name='foo', packages=['foo'], - use_2to3=True, ) """ @@ -60,43 +58,6 @@ class TestDevelop: in_virtualenv = hasattr(sys, 'real_prefix') in_venv = hasattr(sys, 'base_prefix') and sys.base_prefix != sys.prefix - @pytest.mark.skipif( - in_virtualenv or in_venv, - reason="Cannot run when invoked in a virtualenv or venv") - @ack_2to3 - def test_2to3_user_mode(self, test_env): - settings = dict( - name='foo', - packages=['foo'], - use_2to3=True, - version='0.0', - ) - dist = Distribution(settings) - dist.script_name = 'setup.py' - cmd = develop(dist) - cmd.user = 1 - cmd.ensure_finalized() - cmd.install_dir = site.USER_SITE - cmd.user = 1 - with contexts.quiet(): - cmd.run() - - # let's see if we got our egg link at the right place - content = os.listdir(site.USER_SITE) - content.sort() - assert content == ['easy-install.pth', 'foo.egg-link'] - - # Check that we are using the right code. - fn = os.path.join(site.USER_SITE, 'foo.egg-link') - with io.open(fn) as egg_link_file: - path = egg_link_file.read().split()[0].strip() - fn = os.path.join(path, 'foo', '__init__.py') - with io.open(fn) as init_file: - init = init_file.read().strip() - - expected = 'print("foo")' - assert init == expected - def test_console_scripts(self, tmpdir): """ Test that console scripts are installed and that they reference @@ -104,7 +65,8 @@ class TestDevelop: """ pytest.skip( "TODO: needs a fixture to cause 'develop' " - "to be invoked without mutating environment.") + "to be invoked without mutating environment." + ) settings = dict( name='foo', packages=['foo'], @@ -130,6 +92,7 @@ class TestResolver: of what _resolve_setup_path is intending to do. Come up with more meaningful cases that look like real-world scenarios. """ + def test_resolve_setup_path_cwd(self): assert develop._resolve_setup_path('.', '.', '.') == '.' 
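Aside (editor's illustration, not part of the patch): a sketch of the develop-install layout these tests assert, and of the invocation that install_develop in the next hunk wraps. Names and paths here (the helper, the `foo` project, the directories) are assumptions. A develop install drops easy-install.pth plus a <name>.egg-link into the target directory, and the first whitespace-separated token of the egg-link is the project directory, which is exactly what the removed test_2to3_user_mode read back:

    import os
    import pathlib
    import subprocess
    import sys

    def develop_install(src_dir, target):
        # src_dir must contain setup.py; same command as install_develop below.
        cmd = [
            sys.executable,
            'setup.py',
            'develop',
            '--install-dir',
            str(target),
        ]
        # The real tests expose `target` via test.test.paths_on_pythonpath().
        env = dict(os.environ, PYTHONPATH=str(target))
        subprocess.check_call(cmd, cwd=str(src_dir), env=env)
        # The first token of each egg-link names the project directory.
        for egg_link in pathlib.Path(target).glob('*.egg-link'):
            print(egg_link.name, '->', egg_link.read_text().split()[0].strip())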
@@ -141,7 +104,6 @@ class TestResolver: class TestNamespaces: - @staticmethod def install_develop(src_dir, target): @@ -149,7 +111,8 @@ class TestNamespaces: sys.executable, 'setup.py', 'develop', - '--install-dir', str(target), + '--install-dir', + str(target), ] with src_dir.as_cwd(): with test.test.paths_on_pythonpath([str(target)]): @@ -180,14 +143,16 @@ class TestNamespaces: 'pip', 'install', str(pkg_A), - '-t', str(target), + '-t', + str(target), ] subprocess.check_call(install_cmd) self.install_develop(pkg_B, target) namespaces.make_site_dir(target) try_import = [ sys.executable, - '-c', 'import myns.pkgA; import myns.pkgB', + '-c', + 'import myns.pkgA; import myns.pkgB', ] with test.test.paths_on_pythonpath([str(target)]): subprocess.check_call(try_import) @@ -195,7 +160,65 @@ class TestNamespaces: # additionally ensure that pkg_resources import works pkg_resources_imp = [ sys.executable, - '-c', 'import pkg_resources', + '-c', + 'import pkg_resources', ] with test.test.paths_on_pythonpath([str(target)]): subprocess.check_call(pkg_resources_imp) + + @staticmethod + def install_workaround(site_packages): + site_packages.mkdir(parents=True) + sc = site_packages / 'sitecustomize.py' + sc.write_text( + textwrap.dedent( + """ + import site + import pathlib + here = pathlib.Path(__file__).parent + site.addsitedir(str(here)) + """ + ).lstrip() + ) + + @pytest.mark.xfail( + platform.python_implementation() == 'PyPy', + reason="Workaround fails on PyPy (why?)", + ) + def test_editable_prefix(self, tmp_path, sample_project): + """ + Editable install to a prefix should be discoverable. + """ + prefix = tmp_path / 'prefix' + prefix.mkdir() + + # figure out where pip will likely install the package + site_packages = prefix / next( + pathlib.Path(path).relative_to(sys.prefix) + for path in sys.path + if 'site-packages' in path and path.startswith(sys.prefix) + ) + + # install the workaround + self.install_workaround(site_packages) + + env = dict(os.environ, PYTHONPATH=str(site_packages)) + cmd = [ + sys.executable, + '-m', + 'pip', + 'install', + '--editable', + str(sample_project), + '--prefix', + str(prefix), + '--no-build-isolation', + ] + subprocess.check_call(cmd, env=env) + + # now run 'sample' with the prefix on the PYTHONPATH + bin = 'Scripts' if platform.system() == 'Windows' else 'bin' + exe = prefix / bin / 'sample' + if sys.version_info < (3, 7) and platform.system() == 'Windows': + exe = str(exe) + subprocess.check_call([exe], env=env) diff --git a/setuptools/tests/test_dist.py b/setuptools/tests/test_dist.py index cb47fb58..c4279f0b 100644 --- a/setuptools/tests/test_dist.py +++ b/setuptools/tests/test_dist.py @@ -9,6 +9,9 @@ from setuptools.dist import ( _get_unpatched, check_package_data, DistDeprecationWarning, + check_specifier, + rfc822_escape, + rfc822_unescape, ) from setuptools import sic from setuptools import Distribution @@ -81,11 +84,8 @@ def __read_test_cases(): test_cases = [ ('Metadata version 1.0', params()), - ('Metadata version 1.1: Provides', params( - provides=['package'], - )), - ('Metadata version 1.1: Obsoletes', params( - obsoletes=['foo'], + ('Metadata Version 1.0: Short long description', params( + long_description='Short long description', )), ('Metadata version 1.1: Classifiers', params( classifiers=[ @@ -110,6 +110,10 @@ def __read_test_cases(): ('Metadata Version 2.1: Long Description Content Type', params( long_description_content_type='text/x-rst; charset=UTF-8', )), + ('License', params(license='MIT', )), + ('License multiline', params( + 
license='This is a long license \nover multiple lines', + )), pytest.param( 'Metadata Version 2.1: Provides Extra', params(provides_extras=['foo', 'bar']), @@ -161,6 +165,7 @@ def test_read_metadata(name, attrs): ('metadata_version', dist_class.get_metadata_version), ('provides', dist_class.get_provides), ('description', dist_class.get_description), + ('long_description', dist_class.get_long_description), ('download_url', dist_class.get_download_url), ('keywords', dist_class.get_keywords), ('platforms', dist_class.get_platforms), @@ -246,8 +251,8 @@ def test_maintainer_author(name, attrs, tmpdir): with io.open(str(fn.join('PKG-INFO')), 'r', encoding='utf-8') as f: raw_pkg_lines = f.readlines() - # Drop blank lines - pkg_lines = list(filter(None, raw_pkg_lines)) + # Drop blank lines and strip lines from default description + pkg_lines = list(filter(None, raw_pkg_lines[:-2])) pkg_lines_set = set(pkg_lines) @@ -323,3 +328,49 @@ def test_check_package_data(package_data, expected_message): with pytest.raises( DistutilsSetupError, match=re.escape(expected_message)): check_package_data(None, str('package_data'), package_data) + + +def test_check_specifier(): + # valid specifier value + attrs = {'name': 'foo', 'python_requires': '>=3.0, !=3.1'} + dist = Distribution(attrs) + check_specifier(dist, attrs, attrs['python_requires']) + + # invalid specifier value + attrs = {'name': 'foo', 'python_requires': ['>=3.0', '!=3.1']} + with pytest.raises(DistutilsSetupError): + dist = Distribution(attrs) + + +@pytest.mark.parametrize( + 'content, result', + ( + pytest.param( + "Just a single line", + None, + id="single_line", + ), + pytest.param( + "Multiline\nText\nwithout\nextra indents\n", + None, + id="multiline", + ), + pytest.param( + "Multiline\n With\n\nadditional\n indentation", + None, + id="multiline_with_indentation", + ), + pytest.param( + " Leading whitespace", + "Leading whitespace", + id="remove_leading_whitespace", + ), + pytest.param( + " Leading whitespace\nIn\n Multiline comment", + "Leading whitespace\nIn\n Multiline comment", + id="remove_leading_whitespace_multiline", + ), + ) +) +def test_rfc822_unescape(content, result): + assert (result or content) == rfc822_unescape(rfc822_escape(content)) diff --git a/setuptools/tests/test_distutils_adoption.py b/setuptools/tests/test_distutils_adoption.py index a53773df..0e89921c 100644 --- a/setuptools/tests/test_distutils_adoption.py +++ b/setuptools/tests/test_distutils_adoption.py @@ -21,10 +21,10 @@ class VirtualEnv(jaraco.envs.VirtualEnv): @pytest.fixture -def venv(tmpdir): +def venv(tmp_path, tmp_src): env = VirtualEnv() - env.root = path.Path(tmpdir) - env.req = os.getcwd() + env.root = path.Path(tmp_path / 'venv') + env.req = str(tmp_src) return env.create() diff --git a/setuptools/tests/test_easy_install.py b/setuptools/tests/test_easy_install.py index dc00e697..6840d03b 100644 --- a/setuptools/tests/test_easy_install.py +++ b/setuptools/tests/test_easy_install.py @@ -15,8 +15,11 @@ import zipfile import mock import time import re +import subprocess +import pathlib import pytest +from jaraco import path from setuptools import sandbox from setuptools.sandbox import run_setup @@ -25,7 +28,6 @@ from setuptools.command.easy_install import ( EasyInstallDeprecationWarning, ScriptWriter, PthDistributions, WindowsScriptWriter, ) -from setuptools.command import easy_install as easy_install_pkg from setuptools.dist import Distribution from pkg_resources import normalize_path, working_set from pkg_resources import Distribution as PRDistribution @@ 
-34,10 +36,19 @@ from setuptools.tests import fail_on_ascii import pkg_resources from . import contexts -from .files import build_files from .textwrap import DALS +@pytest.fixture(autouse=True) +def pip_disable_index(monkeypatch): + """ + Important: Disable the default index for pip to avoid + querying packages in the index and potentially resolving + and installing packages there. + """ + monkeypatch.setenv('PIP_NO_INDEX', 'true') + + class FakeDist: def get_entry_map(self, group): if group != 'console_scripts': @@ -445,22 +456,22 @@ class TestSetupRequires: """ monkeypatch.setenv(str('PIP_RETRIES'), str('0')) monkeypatch.setenv(str('PIP_TIMEOUT'), str('0')) + monkeypatch.setenv('PIP_NO_INDEX', 'false') with contexts.quiet(): # create an sdist that has a build-time dependency. with TestSetupRequires.create_sdist() as dist_file: with contexts.tempdir() as temp_install_dir: with contexts.environment(PYTHONPATH=temp_install_dir): - ei_params = [ + cmd = [ + sys.executable, + '-m', 'setup', + 'easy_install', '--index-url', mock_index.url, '--exclude-scripts', '--install-dir', temp_install_dir, dist_file, ] - with sandbox.save_argv(['easy_install']): - # attempt to install the dist. It should - # fail because it doesn't exist. - with pytest.raises(SystemExit): - easy_install_pkg.main(ei_params) + subprocess.Popen(cmd).wait() # there should have been one requests to the server assert [r.path for r in mock_index.requests] == ['/does-not-exist/'] @@ -618,6 +629,7 @@ class TestSetupRequires: def test_setup_requires_honors_pip_env(self, mock_index, monkeypatch): monkeypatch.setenv(str('PIP_RETRIES'), str('0')) monkeypatch.setenv(str('PIP_TIMEOUT'), str('0')) + monkeypatch.setenv('PIP_NO_INDEX', 'false') monkeypatch.setenv(str('PIP_INDEX_URL'), mock_index.url) with contexts.save_pkg_resources_state(): with contexts.tempdir() as temp_dir: @@ -648,7 +660,7 @@ class TestSetupRequires: dep_url = path_to_url(dep_sdist, authority='localhost') test_pkg = create_setup_requires_package( temp_dir, - # Ignored (overriden by setup_attrs) + # Ignored (overridden by setup_attrs) 'python-xlib', '0.19', setup_attrs=dict( setup_requires='dependency @ %s' % dep_url)) @@ -658,26 +670,24 @@ class TestSetupRequires: def test_setup_requires_with_allow_hosts(self, mock_index): ''' The `allow-hosts` option in not supported anymore. 
''' + files = { + 'test_pkg': { + 'setup.py': DALS(''' + from setuptools import setup + setup(setup_requires='python-xlib') + '''), + 'setup.cfg': DALS(''' + [easy_install] + allow_hosts = * + '''), + } + } with contexts.save_pkg_resources_state(): with contexts.tempdir() as temp_dir: - test_pkg = os.path.join(temp_dir, 'test_pkg') - test_setup_py = os.path.join(test_pkg, 'setup.py') - test_setup_cfg = os.path.join(test_pkg, 'setup.cfg') - os.mkdir(test_pkg) - with open(test_setup_py, 'w') as fp: - fp.write(DALS( - ''' - from setuptools import setup - setup(setup_requires='python-xlib') - ''')) - with open(test_setup_cfg, 'w') as fp: - fp.write(DALS( - ''' - [easy_install] - allow_hosts = * - ''')) + path.build(files, prefix=temp_dir) + setup_py = str(pathlib.Path(temp_dir, 'test_pkg', 'setup.py')) with pytest.raises(distutils.errors.DistutilsError): - run_setup(test_setup_py, [str('--version')]) + run_setup(setup_py, [str('--version')]) assert len(mock_index.requests) == 0 def test_setup_requires_with_python_requires(self, monkeypatch, tmpdir): @@ -720,7 +730,7 @@ class TestSetupRequires: with contexts.save_pkg_resources_state(): test_pkg = create_setup_requires_package( str(tmpdir), - 'python-xlib', '0.19', # Ignored (overriden by setup_attrs). + 'python-xlib', '0.19', # Ignored (overridden by setup_attrs). setup_attrs=dict( setup_requires='dep', dependency_links=[index_url])) test_setup_py = os.path.join(test_pkg, 'setup.py') @@ -730,10 +740,10 @@ class TestSetupRequires: assert eggs == ['dep 1.0'] @pytest.mark.parametrize( - 'use_legacy_installer,with_dependency_links_in_setup_py', - itertools.product((False, True), (False, True))) + 'with_dependency_links_in_setup_py', + (False, True)) def test_setup_requires_with_find_links_in_setup_cfg( - self, monkeypatch, use_legacy_installer, + self, monkeypatch, with_dependency_links_in_setup_py): monkeypatch.setenv(str('PIP_RETRIES'), str('0')) monkeypatch.setenv(str('PIP_TIMEOUT'), str('0')) @@ -755,11 +765,9 @@ class TestSetupRequires: fp.write(DALS( ''' from setuptools import installer, setup - if {use_legacy_installer}: - installer.fetch_build_egg = installer._legacy_fetch_build_egg setup(setup_requires='python-xlib==42', dependency_links={dependency_links!r}) - ''').format(use_legacy_installer=use_legacy_installer, # noqa + ''').format( dependency_links=dependency_links)) with open(test_setup_cfg, 'w') as fp: fp.write(DALS( @@ -785,7 +793,7 @@ class TestSetupRequires: # Create source tree for `dep`. dep_pkg = os.path.join(temp_dir, 'dep') os.mkdir(dep_pkg) - build_files({ + path.build({ 'setup.py': DALS(""" import setuptools diff --git a/setuptools/tests/test_egg_info.py b/setuptools/tests/test_egg_info.py index 1047468b..ee07b5a1 100644 --- a/setuptools/tests/test_egg_info.py +++ b/setuptools/tests/test_egg_info.py @@ -5,16 +5,17 @@ import glob import re import stat import time +from typing import List, Tuple + +import pytest +from jaraco import path from setuptools.command.egg_info import ( egg_info, manifest_maker, EggInfoDeprecationWarning, get_pkg_info_revision, ) from setuptools.dist import Distribution -import pytest - from . import environment -from .files import build_files from .textwrap import DALS from . 
import contexts @@ -37,7 +38,7 @@ class TestEggInfo: """) def _create_project(self): - build_files({ + path.build({ 'setup.py': self.setup_script, 'hello.py': DALS(""" def run(): @@ -45,6 +46,11 @@ class TestEggInfo: """) }) + @staticmethod + def _extract_mv_version(pkg_info_lines: List[str]) -> Tuple[int, int]: + version_str = pkg_info_lines[0].split(' ')[1] + return tuple(map(int, version_str.split('.')[:2])) + @pytest.fixture def env(self): with contexts.tempdir(prefix='setuptools-test.') as env_dir: @@ -56,7 +62,7 @@ class TestEggInfo: for dirname in subs ) list(map(os.mkdir, env.paths.values())) - build_files({ + path.build({ env.paths['home']: { '.pydistutils.cfg': DALS(""" [egg_info] @@ -106,7 +112,7 @@ class TestEggInfo: the file should remain unchanged. """ setup_cfg = os.path.join(env.paths['home'], 'setup.cfg') - build_files({ + path.build({ setup_cfg: DALS(""" [egg_info] tag_build = @@ -159,8 +165,10 @@ class TestEggInfo: setup() """) - build_files({'setup.py': setup_script, - 'setup.cfg': setup_config}) + path.build({ + 'setup.py': setup_script, + 'setup.cfg': setup_config, + }) # This command should fail with a ValueError, but because it's # currently configured to use a subprocess, the actual traceback @@ -193,7 +201,7 @@ class TestEggInfo: def test_manifest_template_is_read(self, tmpdir_cwd, env): self._create_project() - build_files({ + path.build({ 'MANIFEST.in': DALS(""" recursive-include docs *.rst """), @@ -216,8 +224,10 @@ class TestEggInfo: ''' ) % ('' if use_setup_cfg else requires) setup_config = requires if use_setup_cfg else '' - build_files({'setup.py': setup_script, - 'setup.cfg': setup_config}) + path.build({ + 'setup.py': setup_script, + 'setup.cfg': setup_config, + }) mismatch_marker = "python_version<'{this_ver}'".format( this_ver=sys.version_info[0], @@ -533,7 +543,7 @@ class TestEggInfo: 'setup.cfg': DALS(""" """), 'LICENSE': "Test license" - }, False), # no license_file attribute + }, True), # no license_file attribute, LICENSE auto-included ({ 'setup.cfg': DALS(""" [metadata] @@ -541,12 +551,20 @@ class TestEggInfo: """), 'MANIFEST.in': "exclude LICENSE", 'LICENSE': "Test license" - }, False) # license file is manually excluded + }, True), # manifest is overwritten by license_file + pytest.param({ + 'setup.cfg': DALS(""" + [metadata] + license_file = LICEN[CS]E* + """), + 'LICENSE': "Test license", + }, True, + id="glob_pattern"), ]) def test_setup_cfg_license_file( self, tmpdir_cwd, env, files, license_in_sources): self._create_project() - build_files(files) + path.build(files) environment.run_setup_py( cmd=['egg_info'], @@ -621,7 +639,7 @@ class TestEggInfo: 'setup.cfg': DALS(""" """), 'LICENSE': "Test license" - }, [], ['LICENSE']), # no license_files attribute + }, ['LICENSE'], []), # no license_files attribute, LICENSE auto-included ({ 'setup.cfg': DALS(""" [metadata] @@ -629,7 +647,7 @@ class TestEggInfo: """), 'MANIFEST.in': "exclude LICENSE", 'LICENSE': "Test license" - }, [], ['LICENSE']), # license file is manually excluded + }, ['LICENSE'], []), # manifest is overwritten by license_files ({ 'setup.cfg': DALS(""" [metadata] @@ -640,12 +658,53 @@ class TestEggInfo: 'MANIFEST.in': "exclude LICENSE-XYZ", 'LICENSE-ABC': "ABC license", 'LICENSE-XYZ': "XYZ license" - }, ['LICENSE-ABC'], ['LICENSE-XYZ']) # subset is manually excluded + # manifest is overwritten by license_files + }, ['LICENSE-ABC', 'LICENSE-XYZ'], []), + pytest.param({ + 'setup.cfg': "", + 'LICENSE-ABC': "ABC license", + 'COPYING-ABC': "ABC copying", + 'NOTICE-ABC': "ABC notice", + 
'AUTHORS-ABC': "ABC authors", + 'LICENCE-XYZ': "XYZ license", + 'LICENSE': "License", + 'INVALID-LICENSE': "Invalid license", + }, [ + 'LICENSE-ABC', + 'COPYING-ABC', + 'NOTICE-ABC', + 'AUTHORS-ABC', + 'LICENCE-XYZ', + 'LICENSE', + ], ['INVALID-LICENSE'], + # ('LICEN[CS]E*', 'COPYING*', 'NOTICE*', 'AUTHORS*') + id="default_glob_patterns"), + pytest.param({ + 'setup.cfg': DALS(""" + [metadata] + license_files = + LICENSE* + """), + 'LICENSE-ABC': "ABC license", + 'NOTICE-XYZ': "XYZ notice", + }, ['LICENSE-ABC'], ['NOTICE-XYZ'], + id="no_default_glob_patterns"), + pytest.param({ + 'setup.cfg': DALS(""" + [metadata] + license_files = + LICENSE-ABC + LICENSE* + """), + 'LICENSE-ABC': "ABC license", + }, ['LICENSE-ABC'], [], + id="files_only_added_once", + ), ]) def test_setup_cfg_license_files( self, tmpdir_cwd, env, files, incl_licenses, excl_licenses): self._create_project() - build_files(files) + path.build(files) environment.run_setup_py( cmd=['egg_info'], @@ -744,13 +803,34 @@ 'LICENSE-ABC': "ABC license", 'LICENSE-PQR': "PQR license", 'LICENSE-XYZ': "XYZ license" - # manually excluded - }, ['LICENSE-XYZ'], ['LICENSE-ABC', 'LICENSE-PQR']) + # manifest is overwritten + }, ['LICENSE-ABC', 'LICENSE-PQR', 'LICENSE-XYZ'], []), + pytest.param({ + 'setup.cfg': DALS(""" + [metadata] + license_file = LICENSE* + """), + 'LICENSE-ABC': "ABC license", + 'NOTICE-XYZ': "XYZ notice", + }, ['LICENSE-ABC'], ['NOTICE-XYZ'], + id="no_default_glob_patterns"), + pytest.param({ + 'setup.cfg': DALS(""" + [metadata] + license_file = LICENSE* + license_files = + NOTICE* + """), + 'LICENSE-ABC': "ABC license", + 'NOTICE-ABC': "ABC notice", + 'AUTHORS-ABC': "ABC authors", + }, ['LICENSE-ABC', 'NOTICE-ABC'], ['AUTHORS-ABC'], + id="combined_glob_patterns"), ]) def test_setup_cfg_license_file_license_files( self, tmpdir_cwd, env, files, incl_licenses, excl_licenses): self._create_project() - build_files(files) + path.build(files) environment.run_setup_py( cmd=['egg_info'], @@ -767,6 +847,52 @@ for lf in excl_licenses: assert sources_lines.count(lf) == 0 + def test_license_file_attr_pkg_info(self, tmpdir_cwd, env): + """All matched license files should have a corresponding License-File.""" + self._create_project() + path.build({ + "setup.cfg": DALS(""" + [metadata] + license_files = + NOTICE* + LICENSE* + """), + "LICENSE-ABC": "ABC license", + "LICENSE-XYZ": "XYZ license", + "NOTICE": "included", + "IGNORE": "not include", + }) + + environment.run_setup_py( + cmd=['egg_info'], + pypath=os.pathsep.join([env.paths['lib'], str(tmpdir_cwd)]) + ) + egg_info_dir = os.path.join('.', 'foo.egg-info') + with open(os.path.join(egg_info_dir, 'PKG-INFO')) as pkginfo_file: + pkg_info_lines = pkginfo_file.read().split('\n') + license_file_lines = [ + line for line in pkg_info_lines if line.startswith('License-File:')] + + # Only 'NOTICE', 'LICENSE-ABC', and 'LICENSE-XYZ' should have been matched + # Also assert that the order from license_files is kept + assert "License-File: NOTICE" == license_file_lines[0] + assert "License-File: LICENSE-ABC" in license_file_lines[1:] + assert "License-File: LICENSE-XYZ" in license_file_lines[1:] + + def test_metadata_version(self, tmpdir_cwd, env): + """Make sure latest metadata version is used by default.""" + self._setup_script_with_requires("") + code, data = environment.run_setup_py( + cmd=['egg_info'], + pypath=os.pathsep.join([env.paths['lib'], str(tmpdir_cwd)]), + data_stream=1, + ) + egg_info_dir = os.path.join('.', 'foo.egg-info') + with 
open(os.path.join(egg_info_dir, 'PKG-INFO')) as pkginfo_file: + pkg_info_lines = pkginfo_file.read().split('\n') + # Update metadata version if changed + assert self._extract_mv_version(pkg_info_lines) == (2, 1) + def test_long_description_content_type(self, tmpdir_cwd, env): # Test that specifying a `long_description_content_type` keyword arg to # the `setup` function results in writing a `Description-Content-Type` @@ -793,6 +919,29 @@ class TestEggInfo: assert expected_line in pkg_info_lines assert 'Metadata-Version: 2.1' in pkg_info_lines + def test_long_description(self, tmpdir_cwd, env): + # Test that specifying `long_description` and `long_description_content_type` + # keyword args to the `setup` function results in writing + # the description in the message payload of the `PKG-INFO` file + # in the `<distribution>.egg-info` directory. + self._setup_script_with_requires( + "long_description='This is a long description\\nover multiple lines'," + "long_description_content_type='text/markdown'," + ) + code, data = environment.run_setup_py( + cmd=['egg_info'], + pypath=os.pathsep.join([env.paths['lib'], str(tmpdir_cwd)]), + data_stream=1, + ) + egg_info_dir = os.path.join('.', 'foo.egg-info') + with open(os.path.join(egg_info_dir, 'PKG-INFO')) as pkginfo_file: + pkg_info_lines = pkginfo_file.read().split('\n') + assert 'Metadata-Version: 2.1' in pkg_info_lines + assert '' == pkg_info_lines[-1] # last line should be empty + long_desc_lines = pkg_info_lines[pkg_info_lines.index(''):] + assert 'This is a long description' in long_desc_lines + assert 'over multiple lines' in long_desc_lines + def test_project_urls(self, tmpdir_cwd, env): # Test that specifying a `project_urls` dict to the `setup` # function results in writing multiple `Project-URL` lines to @@ -822,7 +971,40 @@ class TestEggInfo: assert expected_line in pkg_info_lines expected_line = 'Project-URL: Link Two, https://example.com/two/' assert expected_line in pkg_info_lines - assert 'Metadata-Version: 1.2' in pkg_info_lines + assert self._extract_mv_version(pkg_info_lines) >= (1, 2) + + def test_license(self, tmpdir_cwd, env): + """Test single line license.""" + self._setup_script_with_requires( + "license='MIT'," + ) + code, data = environment.run_setup_py( + cmd=['egg_info'], + pypath=os.pathsep.join([env.paths['lib'], str(tmpdir_cwd)]), + data_stream=1, + ) + egg_info_dir = os.path.join('.', 'foo.egg-info') + with open(os.path.join(egg_info_dir, 'PKG-INFO')) as pkginfo_file: + pkg_info_lines = pkginfo_file.read().split('\n') + assert 'License: MIT' in pkg_info_lines + + def test_license_escape(self, tmpdir_cwd, env): + """Test license is escaped correctly if longer than one line.""" + self._setup_script_with_requires( + "license='This is a long license text \\nover multiple lines'," + ) + code, data = environment.run_setup_py( + cmd=['egg_info'], + pypath=os.pathsep.join([env.paths['lib'], str(tmpdir_cwd)]), + data_stream=1, + ) + egg_info_dir = os.path.join('.', 'foo.egg-info') + with open(os.path.join(egg_info_dir, 'PKG-INFO')) as pkginfo_file: + pkg_info_lines = pkginfo_file.read().split('\n') + + assert 'License: This is a long license text ' in pkg_info_lines + assert ' over multiple lines' in pkg_info_lines + assert 'text \n over multiple' in '\n'.join(pkg_info_lines) def test_python_requires_egg_info(self, tmpdir_cwd, env): self._setup_script_with_requires( @@ -840,7 +1022,7 @@ class TestEggInfo: with open(os.path.join(egg_info_dir, 'PKG-INFO')) as pkginfo_file: pkg_info_lines = pkginfo_file.read().split('\n') assert 
'Requires-Python: >=2.7.12' in pkg_info_lines - assert 'Metadata-Version: 1.2' in pkg_info_lines + assert self._extract_mv_version(pkg_info_lines) >= (1, 2) def test_manifest_maker_warning_suppression(self): fixtures = [ @@ -886,7 +1068,7 @@ class TestEggInfo: def test_egg_info_tag_only_once(self, tmpdir_cwd, env): self._create_project() - build_files({ + path.build({ 'setup.cfg': DALS(""" [egg_info] tag_build = dev diff --git a/setuptools/tests/test_glob.py b/setuptools/tests/test_glob.py index a0728c5d..e99587f5 100644 --- a/setuptools/tests/test_glob.py +++ b/setuptools/tests/test_glob.py @@ -1,9 +1,8 @@ import pytest +from jaraco import path from setuptools.glob import glob -from .files import build_files - @pytest.mark.parametrize('tree, pattern, matches', ( ('', b'', []), @@ -31,5 +30,5 @@ from .files import build_files )) def test_glob(monkeypatch, tmpdir, tree, pattern, matches): monkeypatch.chdir(tmpdir) - build_files({name: '' for name in tree.split()}) + path.build({name: '' for name in tree.split()}) assert list(sorted(glob(pattern))) == list(sorted(matches)) diff --git a/setuptools/tests/test_namespaces.py b/setuptools/tests/test_namespaces.py index 6c8c522d..270f90c9 100644 --- a/setuptools/tests/test_namespaces.py +++ b/setuptools/tests/test_namespaces.py @@ -62,8 +62,9 @@ class TestNamespaces: target.mkdir() install_cmd = [ sys.executable, - '-m', 'easy_install', - '-d', str(target), + '-m', 'pip', + 'install', + '-t', str(target), str(pkg), ] with test.test.paths_on_pythonpath([str(target)]): diff --git a/setuptools/tests/test_setopt.py b/setuptools/tests/test_setopt.py index 0163f9af..36008632 100644 --- a/setuptools/tests/test_setopt.py +++ b/setuptools/tests/test_setopt.py @@ -28,3 +28,14 @@ class TestEdit: parser = self.parse_config(str(config)) assert parser.get('names', 'jaraco') == 'джарако' assert parser.get('names', 'other') == 'yes' + + def test_case_retained(self, tmpdir): + """ + When editing a file, case of keys should be retained. + """ + config = tmpdir.join('setup.cfg') + self.write_text(str(config), '[names]\nFoO=bAr') + setopt.edit_config(str(config), dict(names=dict(oTher='yes'))) + actual = config.read_text(encoding='ascii') + assert 'FoO' in actual + assert 'oTher' in actual diff --git a/setuptools/tests/test_sphinx_upload_docs.py b/setuptools/tests/test_sphinx_upload_docs.py new file mode 100644 index 00000000..cc5b8293 --- /dev/null +++ b/setuptools/tests/test_sphinx_upload_docs.py @@ -0,0 +1,38 @@ +import pytest + +from jaraco import path + +from setuptools.command.upload_docs import upload_docs +from setuptools.dist import Distribution + + +@pytest.fixture +def sphinx_doc_sample_project(tmpdir_cwd): + path.build({ + 'setup.py': 'from setuptools import setup; setup()', + 'build': { + 'docs': { + 'conf.py': 'project="test"', + 'index.rst': ".. 
toctree::\ + :maxdepth: 2\ + :caption: Contents:", + }, + }, + }) + + +@pytest.mark.usefixtures('sphinx_doc_sample_project') +class TestSphinxUploadDocs: + def test_sphinx_doc(self): + params = dict( + name='foo', + packages=['test'], + ) + dist = Distribution(params) + + cmd = upload_docs(dist) + + cmd.initialize_options() + assert cmd.upload_dir is None + assert cmd.has_sphinx() is True + cmd.finalize_options() diff --git a/setuptools/tests/test_test.py b/setuptools/tests/test_test.py index 180562e2..6bce8e20 100644 --- a/setuptools/tests/test_test.py +++ b/setuptools/tests/test_test.py @@ -1,4 +1,3 @@ -import mock from distutils import log import os @@ -6,12 +5,12 @@ import pytest from setuptools.command.test import test from setuptools.dist import Distribution -from setuptools.tests import ack_2to3 from .textwrap import DALS -SETUP_PY = DALS(""" +SETUP_PY = DALS( + """ from setuptools import setup setup(name='foo', @@ -19,9 +18,11 @@ SETUP_PY = DALS(""" namespace_packages=['name'], test_suite='name.space.tests.test_suite', ) - """) + """ +) -NS_INIT = DALS(""" +NS_INIT = DALS( + """ # -*- coding: Latin-1 -*- # Söme Arbiträry Ünicode to test Distribute Issüé 310 try: @@ -29,17 +30,20 @@ NS_INIT = DALS(""" except ImportError: from pkgutil import extend_path __path__ = extend_path(__path__, __name__) - """) + """ +) -TEST_PY = DALS(""" +TEST_PY = DALS( + """ import unittest class TestTest(unittest.TestCase): def test_test(self): - print "Foo" # Should fail under Python 3 unless 2to3 is used + print "Foo" # Should fail under Python 3 test_suite = unittest.makeSuite(TestTest) - """) + """ +) @pytest.fixture @@ -70,25 +74,6 @@ def quiet_log(): log.set_verbosity(0) -@pytest.mark.usefixtures('sample_test', 'quiet_log') -@ack_2to3 -def test_test(capfd): - params = dict( - name='foo', - packages=['name', 'name.space', 'name.space.tests'], - namespace_packages=['name'], - test_suite='name.space.tests.test_suite', - use_2to3=True, - ) - dist = Distribution(params) - dist.script_name = 'setup.py' - cmd = test(dist) - cmd.ensure_finalized() - cmd.run() - out, err = capfd.readouterr() - assert out == 'Foo\n' - - @pytest.mark.usefixtures('tmpdir_cwd', 'quiet_log') def test_tests_are_run_once(capfd): params = dict( @@ -104,13 +89,16 @@ def test_tests_are_run_once(capfd): with open('dummy/__init__.py', 'wt'): pass with open('dummy/test_dummy.py', 'wt') as f: - f.write(DALS( - """ + f.write( + DALS( + """ import unittest class TestTest(unittest.TestCase): def test_test(self): print('Foo') - """)) + """ + ) + ) dist = Distribution(params) dist.script_name = 'setup.py' cmd = test(dist) @@ -118,54 +106,3 @@ def test_tests_are_run_once(capfd): cmd.run() out, err = capfd.readouterr() assert out == 'Foo\n' - - -@pytest.mark.usefixtures('sample_test') -@ack_2to3 -def test_warns_deprecation(capfd): - params = dict( - name='foo', - packages=['name', 'name.space', 'name.space.tests'], - namespace_packages=['name'], - test_suite='name.space.tests.test_suite', - use_2to3=True - ) - dist = Distribution(params) - dist.script_name = 'setup.py' - cmd = test(dist) - cmd.ensure_finalized() - cmd.announce = mock.Mock() - cmd.run() - capfd.readouterr() - msg = ( - "WARNING: Testing via this command is deprecated and will be " - "removed in a future version. Users looking for a generic test " - "entry point independent of test runner are encouraged to use " - "tox." 
- ) - cmd.announce.assert_any_call(msg, log.WARN) - - -@pytest.mark.usefixtures('sample_test') -@ack_2to3 -def test_deprecation_stderr(capfd): - params = dict( - name='foo', - packages=['name', 'name.space', 'name.space.tests'], - namespace_packages=['name'], - test_suite='name.space.tests.test_suite', - use_2to3=True - ) - dist = Distribution(params) - dist.script_name = 'setup.py' - cmd = test(dist) - cmd.ensure_finalized() - cmd.run() - out, err = capfd.readouterr() - msg = ( - "WARNING: Testing via this command is deprecated and will be " - "removed in a future version. Users looking for a generic test " - "entry point independent of test runner are encouraged to use " - "tox.\n" - ) - assert msg in err diff --git a/setuptools/tests/test_upload_docs.py b/setuptools/tests/test_upload_docs.py index a26e32a6..55978aad 100644 --- a/setuptools/tests/test_upload_docs.py +++ b/setuptools/tests/test_upload_docs.py @@ -3,6 +3,7 @@ import zipfile import contextlib import pytest +from jaraco import path from setuptools.command.upload_docs import upload_docs from setuptools.dist import Distribution @@ -10,28 +11,20 @@ from setuptools.dist import Distribution from .textwrap import DALS from . import contexts -SETUP_PY = DALS( - """ - from setuptools import setup - - setup(name='foo') - """) - @pytest.fixture def sample_project(tmpdir_cwd): - # setup.py - with open('setup.py', 'wt') as f: - f.write(SETUP_PY) - - os.mkdir('build') - - # A test document. - with open('build/index.html', 'w') as f: - f.write("Hello world.") - - # An empty folder. - os.mkdir('build/empty') + path.build({ + 'setup.py': DALS(""" + from setuptools import setup + + setup(name='foo') + """), + 'build': { + 'index.html': 'Hello world.', + 'empty': {}, + } + }) @pytest.mark.usefixtures('sample_project') diff --git a/setuptools/tests/test_virtualenv.py b/setuptools/tests/test_virtualenv.py index 21dea5bb..462e20c7 100644 --- a/setuptools/tests/test_virtualenv.py +++ b/setuptools/tests/test_virtualenv.py @@ -1,6 +1,7 @@ import glob import os import sys +import itertools import pathlib @@ -40,14 +41,12 @@ def bare_virtualenv(): yield venv -SOURCE_DIR = os.path.join(os.path.dirname(__file__), '../..') - - -def test_clean_env_install(bare_virtualenv): +def test_clean_env_install(bare_virtualenv, tmp_src): """ Check setuptools can be installed in a clean environment. 
""" - bare_virtualenv.run(['python', 'setup.py', 'install'], cd=SOURCE_DIR) + cmd = [bare_virtualenv.python, 'setup.py', 'install'] + bare_virtualenv.run(cmd, cd=tmp_src) def _get_pip_versions(): @@ -68,24 +67,38 @@ def _get_pip_versions(): # No network, disable most of these tests network = False + def mark(param, *marks): + if not isinstance(param, type(pytest.param(''))): + param = pytest.param(param) + return param._replace(marks=param.marks + marks) + + def skip_network(param): + return param if network else mark(param, pytest.mark.skip(reason="no network")) + + issue2599 = pytest.mark.skipif( + sys.version_info > (3, 10), + reason="pypa/setuptools#2599", + ) + network_versions = [ - 'pip==9.0.3', - 'pip==10.0.1', - 'pip==18.1', - 'pip==19.0.1', - 'https://github.com/pypa/pip/archive/master.zip', + mark('pip==9.0.3', issue2599), + mark('pip==10.0.1', issue2599), + mark('pip==18.1', issue2599), + mark('pip==19.3.1', pytest.mark.xfail(reason='pypa/pip#6599')), + 'pip==20.0.2', + 'https://github.com/pypa/pip/archive/main.zip', ] - versions = [None] + [ - pytest.param(v, **({} if network else {'marks': pytest.mark.skip})) - for v in network_versions - ] + versions = itertools.chain( + [None], + map(skip_network, network_versions) + ) - return versions + return list(versions) @pytest.mark.parametrize('pip_version', _get_pip_versions()) -def test_pip_upgrade_from_source(pip_version, virtualenv): +def test_pip_upgrade_from_source(pip_version, tmp_src, virtualenv): """ Check pip can upgrade setuptools from source. """ @@ -104,7 +117,7 @@ def test_pip_upgrade_from_source(pip_version, virtualenv): virtualenv.run(' && '.join(( 'python setup.py -q sdist -d {dist}', 'python setup.py -q bdist_wheel -d {dist}', - )).format(dist=dist_dir), cd=SOURCE_DIR) + )).format(dist=dist_dir), cd=tmp_src) sdist = glob.glob(os.path.join(dist_dir, '*.zip'))[0] wheel = glob.glob(os.path.join(dist_dir, '*.whl'))[0] # Then update from wheel. @@ -113,12 +126,12 @@ def test_pip_upgrade_from_source(pip_version, virtualenv): virtualenv.run('pip install --no-cache-dir --upgrade ' + sdist) -def _check_test_command_install_requirements(virtualenv, tmpdir): +def _check_test_command_install_requirements(virtualenv, tmpdir, cwd): """ Check the test command will install all required dependencies. """ # Install setuptools. - virtualenv.run('python setup.py develop', cd=SOURCE_DIR) + virtualenv.run('python setup.py develop', cd=cwd) def sdist(distname, version): dist_path = tmpdir.join('%s-%s.tar.gz' % (distname, version)) @@ -175,22 +188,21 @@ def _check_test_command_install_requirements(virtualenv, tmpdir): assert tmpdir.join('success').check() -def test_test_command_install_requirements(virtualenv, tmpdir): +def test_test_command_install_requirements(virtualenv, tmpdir, request): # Ensure pip/wheel packages are installed. 
virtualenv.run( "python -c \"__import__('pkg_resources').require(['pip', 'wheel'])\"") - _check_test_command_install_requirements(virtualenv, tmpdir) - - -def test_test_command_install_requirements_when_using_easy_install( - bare_virtualenv, tmpdir): - _check_test_command_install_requirements(bare_virtualenv, tmpdir) + # uninstall setuptools so that 'setup.py develop' works + virtualenv.run("python -m pip uninstall -y setuptools") + # disable index URL so bits and bobs aren't requested from PyPI + virtualenv.env['PIP_NO_INDEX'] = '1' + _check_test_command_install_requirements(virtualenv, tmpdir, request.config.rootdir) -def test_no_missing_dependencies(bare_virtualenv): +def test_no_missing_dependencies(bare_virtualenv, request): """ Quick and dirty test to ensure all external dependencies are vendored. """ for command in ('upload',): # sorted(distutils.command.__all__): - bare_virtualenv.run( - ['python', 'setup.py', command, '-h'], cd=SOURCE_DIR) + cmd = [bare_virtualenv.python, 'setup.py', command, '-h'] + bare_virtualenv.run(cmd, cd=request.config.rootdir) diff --git a/setuptools/tests/test_wheel.py b/setuptools/tests/test_wheel.py index e56eac14..7345b135 100644 --- a/setuptools/tests/test_wheel.py +++ b/setuptools/tests/test_wheel.py @@ -15,6 +15,7 @@ import sys import zipfile import pytest +from jaraco import path from pkg_resources import Distribution, PathMetadata, PY_MAJOR from setuptools.extern.packaging.utils import canonicalize_name @@ -22,7 +23,6 @@ from setuptools.extern.packaging.tags import parse_tag from setuptools.wheel import Wheel from .contexts import tempdir -from .files import build_files from .textwrap import DALS @@ -91,7 +91,7 @@ def build_wheel(extra_file_defs=None, **kwargs): if extra_file_defs: file_defs.update(extra_file_defs) with tempdir() as source_dir: - build_files(file_defs, source_dir) + path.build(file_defs, source_dir) subprocess.check_call((sys.executable, 'setup.py', '-q', 'bdist_wheel'), cwd=source_dir) yield glob.glob(os.path.join(source_dir, 'dist', '*.whl'))[0] diff --git a/setuptools/tests/test_windows_wrappers.py b/setuptools/tests/test_windows_wrappers.py index fa647de8..8ac9bd07 100644 --- a/setuptools/tests/test_windows_wrappers.py +++ b/setuptools/tests/test_windows_wrappers.py @@ -13,6 +13,7 @@ are to wrap. """ import sys +import platform import textwrap import subprocess @@ -51,10 +52,20 @@ class WrapperTester: f.write(w) +def win_launcher_exe(prefix): + """ A simple routine to select launcher script based on platform.""" + assert prefix in ('cli', 'gui') + if platform.machine() == "ARM64": + return "{}-arm64.exe".format(prefix) + else: + return "{}-32.exe".format(prefix) + + class TestCLI(WrapperTester): script_name = 'foo-script.py' - wrapper_source = 'cli-32.exe' wrapper_name = 'foo.exe' + wrapper_source = win_launcher_exe('cli') + script_tmpl = textwrap.dedent(""" #!%(python_exe)s import sys @@ -155,7 +166,7 @@ class TestGUI(WrapperTester): ----------------------- """ script_name = 'bar-script.pyw' - wrapper_source = 'gui-32.exe' + wrapper_source = win_launcher_exe('gui') wrapper_name = 'bar.exe' script_tmpl = textwrap.dedent(""" @@ -167,7 +178,7 @@ class TestGUI(WrapperTester): """).strip() def test_basic(self, tmpdir): - """Test the GUI version with the simple scipt, bar-script.py""" + """Test the GUI version with the simple script, bar-script.py""" self.create_script(tmpdir) cmd = [ |
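
The egg_info tests in this diff repeatedly compare metadata versions via a `_extract_mv_version` helper whose definition falls outside the hunks shown here. A minimal sketch of what such a helper could look like, inferred only from how the tests call it (the standalone name below is illustrative, not the committed method, and it assumes PKG-INFO has already been read and split on newlines):

    def extract_mv_version(pkg_info_lines):
        """Parse the 'Metadata-Version: X.Y' header out of PKG-INFO lines."""
        prefix = 'Metadata-Version: '
        for line in pkg_info_lines:
            if line.startswith(prefix):
                # '2.1' -> (2, 1), so tests can compare with >= or ==
                return tuple(int(part) for part in line[len(prefix):].split('.'))
        raise ValueError('PKG-INFO has no Metadata-Version header')

    # extract_mv_version(['Metadata-Version: 2.1', 'Name: foo']) -> (2, 1)

Returning a tuple rather than the raw string is what lets the tests write ordered checks like `>= (1, 2)` instead of matching one exact version line.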
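`test_license_escape` expects a multi-line license to be folded into a single PKG-INFO header with indented continuation lines. That folding behavior matches distutils' `rfc822_escape` utility, which is presumably what produces it here — an assumption based on the expected output, not something this diff shows. A quick way to observe the folding directly:

    from distutils.util import rfc822_escape

    # rfc822_escape folds a multi-line header value by indenting each
    # continuation line, keeping PKG-INFO parseable as an RFC 822 message.
    escaped = rfc822_escape('This is a long license text \nover multiple lines')
    print('License: ' + escaped)

Without this escaping, a bare newline in the license value would terminate the header and corrupt every field that follows it, which is exactly the regression the new test guards against.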
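Several files in this diff (test_glob.py, test_upload_docs.py, test_wheel.py, test_sphinx_upload_docs.py) replace the test-local `build_files` helper with `jaraco.path.build`, which materializes a directory tree from a nested dict: string values become file contents, dict values become subdirectories. A minimal usage sketch mirroring the `sample_project` fixture above, meant to be run from an empty scratch directory:

    import os
    from jaraco import path

    path.build({
        'setup.py': 'from setuptools import setup\nsetup(name="foo")',
        'build': {
            'index.html': 'Hello world.',  # a file with contents
            'empty': {},                   # an empty subdirectory
        },
    })
    assert os.path.isfile(os.path.join('build', 'index.html'))

Building the whole layout from one declarative dict replaces the sequence of `open`/`os.mkdir` calls the old fixtures used, which is why each converted test shrinks to a single `path.build(...)` call.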