diff options
Diffstat (limited to 'git/objects')
-rw-r--r-- | git/objects/commit.py | 2 | ||||
-rw-r--r-- | git/objects/fun.py | 61 | ||||
-rw-r--r-- | git/objects/submodule/base.py | 89 | ||||
-rw-r--r-- | git/objects/submodule/root.py | 34 | ||||
-rw-r--r-- | git/objects/submodule/util.py | 23 | ||||
-rw-r--r-- | git/objects/tree.py | 49 | ||||
-rw-r--r-- | git/objects/util.py | 57 |
7 files changed, 201 insertions, 114 deletions
diff --git a/git/objects/commit.py b/git/objects/commit.py index 81978ae8..65a87591 100644 --- a/git/objects/commit.py +++ b/git/objects/commit.py @@ -80,7 +80,7 @@ class Commit(base.Object, TraversableIterableObj, Diffable, Serializable): "message", "parents", "encoding", "gpgsig") _id_attribute_ = "hexsha" - def __init__(self, repo: 'Repo', binsha: bytes, tree: 'Tree' = None, + def __init__(self, repo: 'Repo', binsha: bytes, tree: Union['Tree', None] = None, author: Union[Actor, None] = None, authored_date: Union[int, None] = None, author_tz_offset: Union[None, float] = None, diff --git a/git/objects/fun.py b/git/objects/fun.py index 339a53b8..fc2ea1e7 100644 --- a/git/objects/fun.py +++ b/git/objects/fun.py @@ -1,6 +1,7 @@ """Module with functions which are supposed to be as fast as possible""" from stat import S_ISDIR + from git.compat import ( safe_decode, defenc @@ -8,8 +9,14 @@ from git.compat import ( # typing ---------------------------------------------- -from typing import List, Tuple +from typing import Callable, List, MutableSequence, Sequence, Tuple, TYPE_CHECKING, Union, overload + +if TYPE_CHECKING: + from _typeshed import ReadableBuffer + from git import GitCmdObjectDB +EntryTup = Tuple[bytes, int, str] # same as TreeCacheTup in tree.py +EntryTupOrNone = Union[EntryTup, None] # --------------------------------------------------- @@ -18,7 +25,7 @@ __all__ = ('tree_to_stream', 'tree_entries_from_data', 'traverse_trees_recursive 'traverse_tree_recursive') -def tree_to_stream(entries, write): +def tree_to_stream(entries: Sequence[EntryTup], write: Callable[['ReadableBuffer'], Union[int, None]]) -> None: """Write the give list of entries into a stream using its write method :param entries: **sorted** list of tuples with (binsha, mode, name) :param write: write method which takes a data string""" @@ -42,12 +49,14 @@ def tree_to_stream(entries, write): # According to my tests, this is exactly what git does, that is it just # takes the input literally, which appears to be utf8 on linux. if isinstance(name, str): - name = name.encode(defenc) - write(b''.join((mode_str, b' ', name, b'\0', binsha))) + name_bytes = name.encode(defenc) + else: + name_bytes = name + write(b''.join((mode_str, b' ', name_bytes, b'\0', binsha))) # END for each item -def tree_entries_from_data(data: bytes) -> List[Tuple[bytes, int, str]]: +def tree_entries_from_data(data: bytes) -> List[EntryTup]: """Reads the binary representation of a tree and returns tuples of Tree items :param data: data block with tree data (as bytes) :return: list(tuple(binsha, mode, tree_relative_path), ...)""" @@ -93,11 +102,13 @@ def tree_entries_from_data(data: bytes) -> List[Tuple[bytes, int, str]]: return out -def _find_by_name(tree_data, name, is_dir, start_at): +def _find_by_name(tree_data: MutableSequence[EntryTupOrNone], name: str, is_dir: bool, start_at: int + ) -> EntryTupOrNone: """return data entry matching the given name and tree mode or None. Before the item is returned, the respective data item is set None in the tree_data list to mark it done""" + try: item = tree_data[start_at] if item and item[2] == name and S_ISDIR(item[1]) == is_dir: @@ -115,16 +126,27 @@ def _find_by_name(tree_data, name, is_dir, start_at): return None -def _to_full_path(item, path_prefix): +@ overload +def _to_full_path(item: None, path_prefix: str) -> None: + ... + + +@ overload +def _to_full_path(item: EntryTup, path_prefix: str) -> EntryTup: + ... + + +def _to_full_path(item: EntryTupOrNone, path_prefix: str) -> EntryTupOrNone: """Rebuild entry with given path prefix""" if not item: return item return (item[0], item[1], path_prefix + item[2]) -def traverse_trees_recursive(odb, tree_shas, path_prefix): +def traverse_trees_recursive(odb: 'GitCmdObjectDB', tree_shas: Sequence[Union[bytes, None]], + path_prefix: str) -> List[Tuple[EntryTupOrNone, ...]]: """ - :return: list with entries according to the given binary tree-shas. + :return: list of list with entries according to the given binary tree-shas. The result is encoded in a list of n tuple|None per blob/commit, (n == len(tree_shas)), where * [0] == 20 byte sha @@ -137,28 +159,31 @@ def traverse_trees_recursive(odb, tree_shas, path_prefix): :param path_prefix: a prefix to be added to the returned paths on this level, set it '' for the first iteration :note: The ordering of the returned items will be partially lost""" - trees_data = [] + trees_data: List[List[EntryTupOrNone]] = [] + nt = len(tree_shas) for tree_sha in tree_shas: if tree_sha is None: - data = [] + data: List[EntryTupOrNone] = [] else: - data = tree_entries_from_data(odb.stream(tree_sha).read()) + # make new list for typing as list invariant + data = [x for x in tree_entries_from_data(odb.stream(tree_sha).read())] # END handle muted trees trees_data.append(data) # END for each sha to get data for - out = [] - out_append = out.append + out: List[Tuple[EntryTupOrNone, ...]] = [] # find all matching entries and recursively process them together if the match # is a tree. If the match is a non-tree item, put it into the result. # Processed items will be set None for ti, tree_data in enumerate(trees_data): + for ii, item in enumerate(tree_data): if not item: continue # END skip already done items + entries: List[EntryTupOrNone] entries = [None for _ in range(nt)] entries[ti] = item _sha, mode, name = item @@ -170,16 +195,16 @@ def traverse_trees_recursive(odb, tree_shas, path_prefix): for tio in range(ti + 1, ti + nt): tio = tio % nt entries[tio] = _find_by_name(trees_data[tio], name, is_dir, ii) - # END for each other item data + # END for each other item data # if we are a directory, enter recursion if is_dir: out.extend(traverse_trees_recursive( odb, [((ei and ei[0]) or None) for ei in entries], path_prefix + name + '/')) else: - out_append(tuple(_to_full_path(e, path_prefix) for e in entries)) - # END handle recursion + out.append(tuple(_to_full_path(e, path_prefix) for e in entries)) + # END handle recursion # finally mark it done tree_data[ii] = None # END for each item @@ -190,7 +215,7 @@ def traverse_trees_recursive(odb, tree_shas, path_prefix): return out -def traverse_tree_recursive(odb, tree_sha, path_prefix): +def traverse_tree_recursive(odb: 'GitCmdObjectDB', tree_sha: bytes, path_prefix: str) -> List[EntryTup]: """ :return: list of entries of the tree pointed to by the binary tree_sha. An entry has the following format: diff --git a/git/objects/submodule/base.py b/git/objects/submodule/base.py index c95b66f2..b485dbf6 100644 --- a/git/objects/submodule/base.py +++ b/git/objects/submodule/base.py @@ -49,12 +49,13 @@ from .util import ( # typing ---------------------------------------------------------------------- -from typing import Dict, TYPE_CHECKING +from typing import Callable, Dict, Mapping, Sequence, TYPE_CHECKING, cast from typing import Any, Iterator, Union -from git.types import Commit_ish, PathLike +from git.types import Commit_ish, PathLike, TBD if TYPE_CHECKING: + from git.index import IndexFile from git.repo import Repo @@ -114,7 +115,7 @@ class Submodule(IndexObject, TraversableIterableObj): path: Union[PathLike, None] = None, name: Union[str, None] = None, parent_commit: Union[Commit_ish, None] = None, - url: str = None, + url: Union[str, None] = None, branch_path: Union[PathLike, None] = None ) -> None: """Initialize this instance with its attributes. We only document the ones @@ -131,14 +132,14 @@ class Submodule(IndexObject, TraversableIterableObj): if url is not None: self._url = url if branch_path is not None: - assert isinstance(branch_path, str) + # assert isinstance(branch_path, str) self._branch_path = branch_path if name is not None: self._name = name def _set_cache_(self, attr: str) -> None: if attr in ('path', '_url', '_branch_path'): - reader = self.config_reader() + reader: SectionConstraint = self.config_reader() # default submodule values try: self.path = reader.get('path') @@ -226,7 +227,7 @@ class Submodule(IndexObject, TraversableIterableObj): return SubmoduleConfigParser(fp_module, read_only=read_only) - def _clear_cache(self): + def _clear_cache(self) -> None: # clear the possibly changed values for name in self._cache_attrs: try: @@ -246,7 +247,7 @@ class Submodule(IndexObject, TraversableIterableObj): def _config_parser_constrained(self, read_only: bool) -> SectionConstraint: """:return: Config Parser constrained to our submodule in read or write mode""" try: - pc = self.parent_commit + pc: Union['Commit_ish', None] = self.parent_commit except ValueError: pc = None # end handle empty parent repository @@ -255,10 +256,12 @@ class Submodule(IndexObject, TraversableIterableObj): return SectionConstraint(parser, sm_section(self.name)) @classmethod - def _module_abspath(cls, parent_repo, path, name): + def _module_abspath(cls, parent_repo: 'Repo', path: PathLike, name: str) -> PathLike: if cls._need_gitfile_submodules(parent_repo.git): return osp.join(parent_repo.git_dir, 'modules', name) - return osp.join(parent_repo.working_tree_dir, path) + if parent_repo.working_tree_dir: + return osp.join(parent_repo.working_tree_dir, path) + raise NotADirectoryError() # end @classmethod @@ -286,7 +289,7 @@ class Submodule(IndexObject, TraversableIterableObj): return clone @classmethod - def _to_relative_path(cls, parent_repo, path): + def _to_relative_path(cls, parent_repo: 'Repo', path: PathLike) -> PathLike: """:return: a path guaranteed to be relative to the given parent - repository :raise ValueError: if path is not contained in the parent repository's working tree""" path = to_native_path_linux(path) @@ -294,7 +297,7 @@ class Submodule(IndexObject, TraversableIterableObj): path = path[:-1] # END handle trailing slash - if osp.isabs(path): + if osp.isabs(path) and parent_repo.working_tree_dir: working_tree_linux = to_native_path_linux(parent_repo.working_tree_dir) if not path.startswith(working_tree_linux): raise ValueError("Submodule checkout path '%s' needs to be within the parents repository at '%s'" @@ -308,7 +311,7 @@ class Submodule(IndexObject, TraversableIterableObj): return path @classmethod - def _write_git_file_and_module_config(cls, working_tree_dir, module_abspath): + def _write_git_file_and_module_config(cls, working_tree_dir: PathLike, module_abspath: PathLike) -> None: """Writes a .git file containing a(preferably) relative path to the actual git module repository. It is an error if the module_abspath cannot be made into a relative path, relative to the working_tree_dir :note: will overwrite existing files ! @@ -335,7 +338,8 @@ class Submodule(IndexObject, TraversableIterableObj): @classmethod def add(cls, repo: 'Repo', name: str, path: PathLike, url: Union[str, None] = None, - branch=None, no_checkout: bool = False, depth=None, env=None, clone_multi_options=None + branch: Union[str, None] = None, no_checkout: bool = False, depth: Union[int, None] = None, + env: Union[Mapping[str, str], None] = None, clone_multi_options: Union[Sequence[TBD], None] = None ) -> 'Submodule': """Add a new submodule to the given repository. This will alter the index as well as the .gitmodules file, but will not create a new commit. @@ -391,7 +395,7 @@ class Submodule(IndexObject, TraversableIterableObj): if sm.exists(): # reretrieve submodule from tree try: - sm = repo.head.commit.tree[path] + sm = repo.head.commit.tree[str(path)] sm._name = name return sm except KeyError: @@ -414,12 +418,14 @@ class Submodule(IndexObject, TraversableIterableObj): # END check url # END verify urls match - mrepo = None + mrepo: Union[Repo, None] = None + if url is None: if not has_module: raise ValueError("A URL was not given and a repository did not exist at %s" % path) # END check url mrepo = sm.module() + # assert isinstance(mrepo, git.Repo) urls = [r.url for r in mrepo.remotes] if not urls: raise ValueError("Didn't find any remote url in repository at %s" % sm.abspath) @@ -427,7 +433,7 @@ class Submodule(IndexObject, TraversableIterableObj): url = urls[0] else: # clone new repo - kwargs: Dict[str, Union[bool, int]] = {'n': no_checkout} + kwargs: Dict[str, Union[bool, int, Sequence[TBD]]] = {'n': no_checkout} if not branch_is_default: kwargs['b'] = br.name # END setup checkout-branch @@ -451,6 +457,8 @@ class Submodule(IndexObject, TraversableIterableObj): # otherwise there is a '-' character in front of the submodule listing # a38efa84daef914e4de58d1905a500d8d14aaf45 mymodule (v0.9.0-1-ga38efa8) # -a38efa84daef914e4de58d1905a500d8d14aaf45 submodules/intermediate/one + writer: Union[GitConfigParser, SectionConstraint] + with sm.repo.config_writer() as writer: writer.set_value(sm_section(name), 'url', url) @@ -467,13 +475,15 @@ class Submodule(IndexObject, TraversableIterableObj): sm._branch_path = br.path # we deliberately assume that our head matches our index ! - sm.binsha = mrepo.head.commit.binsha + sm.binsha = mrepo.head.commit.binsha # type: ignore index.add([sm], write=True) return sm - def update(self, recursive=False, init=True, to_latest_revision=False, progress=None, dry_run=False, - force=False, keep_going=False, env=None, clone_multi_options=None): + def update(self, recursive: bool = False, init: bool = True, to_latest_revision: bool = False, + progress: Union['UpdateProgress', None] = None, dry_run: bool = False, + force: bool = False, keep_going: bool = False, env: Union[Mapping[str, str], None] = None, + clone_multi_options: Union[Sequence[TBD], None] = None): """Update the repository of this submodule to point to the checkout we point at with the binsha of this instance. @@ -580,6 +590,7 @@ class Submodule(IndexObject, TraversableIterableObj): if not dry_run: # see whether we have a valid branch to checkout try: + mrepo = cast('Repo', mrepo) # find a remote which has our branch - we try to be flexible remote_branch = find_first_remote_branch(mrepo.remotes, self.branch_name) local_branch = mkhead(mrepo, self.branch_path) @@ -640,7 +651,7 @@ class Submodule(IndexObject, TraversableIterableObj): may_reset = True if mrepo.head.commit.binsha != self.NULL_BIN_SHA: base_commit = mrepo.merge_base(mrepo.head.commit, hexsha) - if len(base_commit) == 0 or base_commit[0].hexsha == hexsha: + if len(base_commit) == 0 or (base_commit[0] is not None and base_commit[0].hexsha == hexsha): if force: msg = "Will force checkout or reset on local branch that is possibly in the future of" msg += "the commit it will be checked out to, effectively 'forgetting' new commits" @@ -807,7 +818,8 @@ class Submodule(IndexObject, TraversableIterableObj): return self @unbare_repo - def remove(self, module=True, force=False, configuration=True, dry_run=False): + def remove(self, module: bool = True, force: bool = False, + configuration: bool = True, dry_run: bool = False) -> 'Submodule': """Remove this submodule from the repository. This will remove our entry from the .gitmodules file and the entry in the .git / config file. @@ -861,7 +873,7 @@ class Submodule(IndexObject, TraversableIterableObj): # TODO: If we run into permission problems, we have a highly inconsistent # state. Delete the .git folders last, start with the submodules first mp = self.abspath - method = None + method: Union[None, Callable[[PathLike], None]] = None if osp.islink(mp): method = os.remove elif osp.isdir(mp): @@ -914,7 +926,7 @@ class Submodule(IndexObject, TraversableIterableObj): import gc gc.collect() try: - rmtree(wtd) + rmtree(str(wtd)) except Exception as ex: if HIDE_WINDOWS_KNOWN_ERRORS: raise SkipTest("FIXME: fails with: PermissionError\n {}".format(ex)) from ex @@ -928,7 +940,7 @@ class Submodule(IndexObject, TraversableIterableObj): rmtree(git_dir) except Exception as ex: if HIDE_WINDOWS_KNOWN_ERRORS: - raise SkipTest("FIXME: fails with: PermissionError\n %s", ex) from ex + raise SkipTest(f"FIXME: fails with: PermissionError\n {ex}") from ex else: raise # end handle separate bare repository @@ -952,6 +964,8 @@ class Submodule(IndexObject, TraversableIterableObj): # now git config - need the config intact, otherwise we can't query # information anymore + writer: Union[GitConfigParser, SectionConstraint] + with self.repo.config_writer() as writer: writer.remove_section(sm_section(self.name)) @@ -961,7 +975,7 @@ class Submodule(IndexObject, TraversableIterableObj): return self - def set_parent_commit(self, commit: Union[Commit_ish, None], check=True): + def set_parent_commit(self, commit: Union[Commit_ish, None], check: bool = True) -> 'Submodule': """Set this instance to use the given commit whose tree is supposed to contain the .gitmodules blob. @@ -1009,7 +1023,7 @@ class Submodule(IndexObject, TraversableIterableObj): return self @unbare_repo - def config_writer(self, index=None, write=True): + def config_writer(self, index: Union['IndexFile', None] = None, write: bool = True) -> SectionConstraint: """:return: a config writer instance allowing you to read and write the data belonging to this submodule into the .gitmodules file. @@ -1030,7 +1044,7 @@ class Submodule(IndexObject, TraversableIterableObj): return writer @unbare_repo - def rename(self, new_name): + def rename(self, new_name: str) -> 'Submodule': """Rename this submodule :note: This method takes care of renaming the submodule in various places, such as @@ -1065,13 +1079,14 @@ class Submodule(IndexObject, TraversableIterableObj): destination_module_abspath = self._module_abspath(self.repo, self.path, new_name) source_dir = mod.git_dir # Let's be sure the submodule name is not so obviously tied to a directory - if destination_module_abspath.startswith(mod.git_dir): + if str(destination_module_abspath).startswith(str(mod.git_dir)): tmp_dir = self._module_abspath(self.repo, self.path, str(uuid.uuid4())) os.renames(source_dir, tmp_dir) source_dir = tmp_dir # end handle self-containment os.renames(source_dir, destination_module_abspath) - self._write_git_file_and_module_config(mod.working_tree_dir, destination_module_abspath) + if mod.working_tree_dir: + self._write_git_file_and_module_config(mod.working_tree_dir, destination_module_abspath) # end move separate git repository return self @@ -1081,7 +1096,7 @@ class Submodule(IndexObject, TraversableIterableObj): #{ Query Interface @unbare_repo - def module(self): + def module(self) -> 'Repo': """:return: Repo instance initialized from the repository at our submodule path :raise InvalidGitRepositoryError: if a repository was not available. This could also mean that it was not yet initialized""" @@ -1098,7 +1113,7 @@ class Submodule(IndexObject, TraversableIterableObj): raise InvalidGitRepositoryError("Repository at %r was not yet checked out" % module_checkout_abspath) # END handle exceptions - def module_exists(self): + def module_exists(self) -> bool: """:return: True if our module exists and is a valid git repository. See module() method""" try: self.module() @@ -1107,7 +1122,7 @@ class Submodule(IndexObject, TraversableIterableObj): return False # END handle exception - def exists(self): + def exists(self) -> bool: """ :return: True if the submodule exists, False otherwise. Please note that a submodule may exist ( in the .gitmodules file) even though its module @@ -1148,26 +1163,26 @@ class Submodule(IndexObject, TraversableIterableObj): return mkhead(self.module(), self._branch_path) @property - def branch_path(self): + def branch_path(self) -> PathLike: """ :return: full(relative) path as string to the branch we would checkout from the remote and track""" return self._branch_path @property - def branch_name(self): + def branch_name(self) -> str: """:return: the name of the branch, which is the shortest possible branch name""" # use an instance method, for this we create a temporary Head instance # which uses a repository that is available at least ( it makes no difference ) return git.Head(self.repo, self._branch_path).name @property - def url(self): + def url(self) -> str: """:return: The url to the repository which our module - repository refers to""" return self._url @property - def parent_commit(self): + def parent_commit(self) -> 'Commit_ish': """:return: Commit instance with the tree containing the .gitmodules file :note: will always point to the current head's commit if it was not set explicitly""" if self._parent_commit is None: @@ -1175,7 +1190,7 @@ class Submodule(IndexObject, TraversableIterableObj): return self._parent_commit @property - def name(self): + def name(self) -> str: """:return: The name of this submodule. It is used to identify it within the .gitmodules file. :note: by default, the name is the path at which to find the submodule, but diff --git a/git/objects/submodule/root.py b/git/objects/submodule/root.py index 0af48710..bcac5419 100644 --- a/git/objects/submodule/root.py +++ b/git/objects/submodule/root.py @@ -10,6 +10,18 @@ import git import logging +# typing ------------------------------------------------------------------- + +from typing import TYPE_CHECKING, Union + +from git.types import Commit_ish + +if TYPE_CHECKING: + from git.repo import Repo + from git.util import IterableList + +# ---------------------------------------------------------------------------- + __all__ = ["RootModule", "RootUpdateProgress"] log = logging.getLogger('git.objects.submodule.root') @@ -42,7 +54,7 @@ class RootModule(Submodule): k_root_name = '__ROOT__' - def __init__(self, repo): + def __init__(self, repo: 'Repo'): # repo, binsha, mode=None, path=None, name = None, parent_commit=None, url=None, ref=None) super(RootModule, self).__init__( repo, @@ -55,15 +67,17 @@ class RootModule(Submodule): branch_path=git.Head.to_full_path(self.k_head_default) ) - def _clear_cache(self): + def _clear_cache(self) -> None: """May not do anything""" pass #{ Interface - def update(self, previous_commit=None, recursive=True, force_remove=False, init=True, - to_latest_revision=False, progress=None, dry_run=False, force_reset=False, - keep_going=False): + def update(self, previous_commit: Union[Commit_ish, None] = None, # type: ignore[override] + recursive: bool = True, force_remove: bool = False, init: bool = True, + to_latest_revision: bool = False, progress: Union[None, 'RootUpdateProgress'] = None, + dry_run: bool = False, force_reset: bool = False, keep_going: bool = False + ) -> 'RootModule': """Update the submodules of this repository to the current HEAD commit. This method behaves smartly by determining changes of the path of a submodules repository, next to changes to the to-be-checked-out commit or the branch to be @@ -128,8 +142,8 @@ class RootModule(Submodule): previous_commit = repo.commit(previous_commit) # obtain commit object # END handle previous commit - psms = self.list_items(repo, parent_commit=previous_commit) - sms = self.list_items(repo) + psms: 'IterableList[Submodule]' = self.list_items(repo, parent_commit=previous_commit) + sms: 'IterableList[Submodule]' = self.list_items(repo) spsms = set(psms) ssms = set(sms) @@ -162,8 +176,8 @@ class RootModule(Submodule): csms = (spsms & ssms) len_csms = len(csms) for i, csm in enumerate(csms): - psm = psms[csm.name] - sm = sms[csm.name] + psm: 'Submodule' = psms[csm.name] + sm: 'Submodule' = sms[csm.name] # PATH CHANGES ############## @@ -343,7 +357,7 @@ class RootModule(Submodule): return self - def module(self): + def module(self) -> 'Repo': """:return: the actual repository containing the submodules""" return self.repo #} END interface diff --git a/git/objects/submodule/util.py b/git/objects/submodule/util.py index 5290000b..a776af88 100644 --- a/git/objects/submodule/util.py +++ b/git/objects/submodule/util.py @@ -5,11 +5,20 @@ from io import BytesIO import weakref -from typing import Any, TYPE_CHECKING, Union +# typing ----------------------------------------------------------------------- + +from typing import Any, Sequence, TYPE_CHECKING, Union + +from git.types import PathLike if TYPE_CHECKING: from .base import Submodule from weakref import ReferenceType + from git.repo import Repo + from git.refs import Head + from git import Remote + from git.refs import RemoteReference + __all__ = ('sm_section', 'sm_name', 'mkhead', 'find_first_remote_branch', 'SubmoduleConfigParser') @@ -17,23 +26,23 @@ __all__ = ('sm_section', 'sm_name', 'mkhead', 'find_first_remote_branch', #{ Utilities -def sm_section(name): +def sm_section(name: str) -> str: """:return: section title used in .gitmodules configuration file""" - return 'submodule "%s"' % name + return f'submodule "{name}"' -def sm_name(section): +def sm_name(section: str) -> str: """:return: name of the submodule as parsed from the section name""" section = section.strip() return section[11:-1] -def mkhead(repo, path): +def mkhead(repo: 'Repo', path: PathLike) -> 'Head': """:return: New branch/head instance""" return git.Head(repo, git.Head.to_full_path(path)) -def find_first_remote_branch(remotes, branch_name): +def find_first_remote_branch(remotes: Sequence['Remote'], branch_name: str) -> 'RemoteReference': """Find the remote branch matching the name of the given branch or raise InvalidGitRepositoryError""" for remote in remotes: try: @@ -92,7 +101,7 @@ class SubmoduleConfigParser(GitConfigParser): #{ Overridden Methods def write(self) -> None: - rval = super(SubmoduleConfigParser, self).write() + rval: None = super(SubmoduleConfigParser, self).write() self.flush_to_index() return rval # END overridden methods diff --git a/git/objects/tree.py b/git/objects/tree.py index 2e8d8a79..a9656c1d 100644 --- a/git/objects/tree.py +++ b/git/objects/tree.py @@ -4,8 +4,8 @@ # This module is part of GitPython and is released under # the BSD License: http://www.opensource.org/licenses/bsd-license.php -from git.util import join_path -import git.diff as diff +from git.util import IterableList, join_path +import git.diff as git_diff from git.util import to_bin_sha from . import util @@ -21,8 +21,8 @@ from .fun import ( # typing ------------------------------------------------- -from typing import (Callable, Dict, Generic, Iterable, Iterator, List, - Tuple, Type, TypeVar, Union, cast, TYPE_CHECKING) +from typing import (Any, Callable, Dict, Iterable, Iterator, List, + Tuple, Type, Union, cast, TYPE_CHECKING) from git.types import PathLike, TypeGuard @@ -30,10 +30,15 @@ if TYPE_CHECKING: from git.repo import Repo from io import BytesIO -T_Tree_cache = TypeVar('T_Tree_cache', bound=Tuple[bytes, int, str]) +TreeCacheTup = Tuple[bytes, int, str] + TraversedTreeTup = Union[Tuple[Union['Tree', None], IndexObjUnion, Tuple['Submodule', 'Submodule']]] + +def is_tree_cache(inp: Tuple[bytes, int, str]) -> TypeGuard[TreeCacheTup]: + return isinstance(inp[0], bytes) and isinstance(inp[1], int) and isinstance([inp], str) + #-------------------------------------------------------- @@ -42,9 +47,9 @@ cmp: Callable[[str, str], int] = lambda a, b: (a > b) - (a < b) __all__ = ("TreeModifier", "Tree") -def git_cmp(t1: T_Tree_cache, t2: T_Tree_cache) -> int: +def git_cmp(t1: TreeCacheTup, t2: TreeCacheTup) -> int: a, b = t1[2], t2[2] - assert isinstance(a, str) and isinstance(b, str) # Need as mypy 9.0 cannot unpack TypeVar properly + # assert isinstance(a, str) and isinstance(b, str) len_a, len_b = len(a), len(b) min_len = min(len_a, len_b) min_cmp = cmp(a[:min_len], b[:min_len]) @@ -55,8 +60,8 @@ def git_cmp(t1: T_Tree_cache, t2: T_Tree_cache) -> int: return len_a - len_b -def merge_sort(a: List[T_Tree_cache], - cmp: Callable[[T_Tree_cache, T_Tree_cache], int]) -> None: +def merge_sort(a: List[TreeCacheTup], + cmp: Callable[[TreeCacheTup, TreeCacheTup], int]) -> None: if len(a) < 2: return None @@ -91,7 +96,7 @@ def merge_sort(a: List[T_Tree_cache], k = k + 1 -class TreeModifier(Generic[T_Tree_cache], object): +class TreeModifier(object): """A utility class providing methods to alter the underlying cache in a list-like fashion. @@ -99,7 +104,7 @@ class TreeModifier(Generic[T_Tree_cache], object): the cache of a tree, will be sorted. Assuring it will be in a serializable state""" __slots__ = '_cache' - def __init__(self, cache: List[T_Tree_cache]) -> None: + def __init__(self, cache: List[TreeCacheTup]) -> None: self._cache = cache def _index_by_name(self, name: str) -> int: @@ -141,11 +146,8 @@ class TreeModifier(Generic[T_Tree_cache], object): sha = to_bin_sha(sha) index = self._index_by_name(name) - def is_tree_cache(inp: Tuple[bytes, int, str]) -> TypeGuard[T_Tree_cache]: - return isinstance(inp[0], bytes) and isinstance(inp[1], int) and isinstance([inp], str) - item = (sha, mode, name) - assert is_tree_cache(item) + # assert is_tree_cache(item) if index == -1: self._cache.append(item) @@ -167,7 +169,7 @@ class TreeModifier(Generic[T_Tree_cache], object): For more information on the parameters, see ``add`` :param binsha: 20 byte binary sha""" assert isinstance(binsha, bytes) and isinstance(mode, int) and isinstance(name, str) - tree_cache = cast(T_Tree_cache, (binsha, mode, name)) + tree_cache = (binsha, mode, name) self._cache.append(tree_cache) @@ -180,7 +182,7 @@ class TreeModifier(Generic[T_Tree_cache], object): #} END mutators -class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable): +class Tree(IndexObject, git_diff.Diffable, util.Traversable, util.Serializable): """Tree objects represent an ordered list of Blobs and other Trees. @@ -216,7 +218,6 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable): def _get_intermediate_items(cls, index_object: 'Tree', ) -> Union[Tuple['Tree', ...], Tuple[()]]: if index_object.type == "tree": - index_object = cast('Tree', index_object) return tuple(index_object._iter_convert_to_object(index_object._cache)) return () @@ -224,12 +225,12 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable): if attr == "_cache": # Set the data when we need it ostream = self.repo.odb.stream(self.binsha) - self._cache: List[Tuple[bytes, int, str]] = tree_entries_from_data(ostream.read()) + self._cache: List[TreeCacheTup] = tree_entries_from_data(ostream.read()) else: super(Tree, self)._set_cache_(attr) # END handle attribute - def _iter_convert_to_object(self, iterable: Iterable[Tuple[bytes, int, str]] + def _iter_convert_to_object(self, iterable: Iterable[TreeCacheTup] ) -> Iterator[IndexObjUnion]: """Iterable yields tuples of (binsha, mode, name), which will be converted to the respective object representation""" @@ -324,6 +325,14 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable): super(Tree, self).traverse(predicate, prune, depth, # type: ignore branch_first, visit_once, ignore_self)) + def list_traverse(self, *args: Any, **kwargs: Any) -> IterableList[IndexObjUnion]: + """ + :return: IterableList with the results of the traversal as produced by + traverse() + Tree -> IterableList[Union['Submodule', 'Tree', 'Blob']] + """ + return super(Tree, self).list_traverse(* args, **kwargs) + # List protocol def __getslice__(self, i: int, j: int) -> List[IndexObjUnion]: diff --git a/git/objects/util.py b/git/objects/util.py index 0b449b7b..fbe3d9de 100644 --- a/git/objects/util.py +++ b/git/objects/util.py @@ -23,7 +23,7 @@ from datetime import datetime, timedelta, tzinfo from typing import (Any, Callable, Deque, Iterator, NamedTuple, overload, Sequence, TYPE_CHECKING, Tuple, Type, TypeVar, Union, cast) -from git.types import Literal, TypeGuard +from git.types import Has_id_attribute, Literal if TYPE_CHECKING: from io import BytesIO, StringIO @@ -32,8 +32,15 @@ if TYPE_CHECKING: from .tag import TagObject from .tree import Tree, TraversedTreeTup from subprocess import Popen + from .submodule.base import Submodule + + +class TraverseNT(NamedTuple): + depth: int + item: Union['Traversable', 'Blob'] + src: Union['Traversable', None] + - T_TIobj = TypeVar('T_TIobj', bound='TraversableIterableObj') # for TraversableIterableObj.traverse() TraversedTup = Union[Tuple[Union['Traversable', None], 'Traversable'], # for commit, submodule @@ -306,20 +313,29 @@ class Traversable(object): """ raise NotImplementedError("To be implemented in subclass") - def list_traverse(self, *args: Any, **kwargs: Any) -> IterableList['TraversableIterableObj']: + def list_traverse(self, *args: Any, **kwargs: Any) -> IterableList[Union['Commit', 'Submodule', 'Tree', 'Blob']]: """ :return: IterableList with the results of the traversal as produced by traverse() - List objects must be IterableObj and Traversable e.g. Commit, Submodule""" - - def is_TraversableIterableObj(inp: 'Traversable') -> TypeGuard['TraversableIterableObj']: - # return isinstance(self, TraversableIterableObj) - # Can it be anythin else? - return isinstance(self, Traversable) - - assert is_TraversableIterableObj(self), f"{type(self)}" - out: IterableList['TraversableIterableObj'] = IterableList(self._id_attribute_) - out.extend(self.traverse(*args, **kwargs)) + Commit -> IterableList['Commit'] + Submodule -> IterableList['Submodule'] + Tree -> IterableList[Union['Submodule', 'Tree', 'Blob']] + """ + # Commit and Submodule have id.__attribute__ as IterableObj + # Tree has id.__attribute__ inherited from IndexObject + if isinstance(self, (TraversableIterableObj, Has_id_attribute)): + id = self._id_attribute_ + else: + id = "" # shouldn't reach here, unless Traversable subclass created with no _id_attribute_ + # could add _id_attribute_ to Traversable, or make all Traversable also Iterable? + + out: IterableList[Union['Commit', 'Submodule', 'Tree', 'Blob']] = IterableList(id) + # overloads in subclasses (mypy does't allow typing self: subclass) + # Union[IterableList['Commit'], IterableList['Submodule'], IterableList[Union['Submodule', 'Tree', 'Blob']]] + + # NOTE: if is_edge=True, self.traverse returns a Tuple, so should be prevented or flattened? + kwargs['as_edge'] = False + out.extend(self.traverse(*args, **kwargs)) # type: ignore return out def traverse(self, @@ -364,15 +380,11 @@ class Traversable(object): Submodule -> Iterator[Submodule, Tuple[Submodule, Submodule]] Tree -> Iterator[Union[Blob, Tree, Submodule, Tuple[Union[Submodule, Tree], Union[Blob, Tree, Submodule]]] - + ignore_self=True is_edge=True -> Iterator[item] ignore_self=True is_edge=False --> Iterator[item] ignore_self=False is_edge=True -> Iterator[item] | Iterator[Tuple[src, item]] ignore_self=False is_edge=False -> Iterator[Tuple[src, item]]""" - class TraverseNT(NamedTuple): - depth: int - item: Union['Traversable', 'Blob'] - src: Union['Traversable', None] visited = set() stack = deque() # type: Deque[TraverseNT] @@ -447,7 +459,10 @@ class TraversableIterableObj(Traversable, IterableObj): TIobj_tuple = Tuple[Union[T_TIobj, None], T_TIobj] - @overload # type: ignore + def list_traverse(self: T_TIobj, *args: Any, **kwargs: Any) -> IterableList[T_TIobj]: # type: ignore[override] + return super(TraversableIterableObj, self).list_traverse(* args, **kwargs) + + @ overload # type: ignore def traverse(self: T_TIobj, predicate: Callable[[Union[T_TIobj, Tuple[Union[T_TIobj, None], T_TIobj]], int], bool], prune: Callable[[Union[T_TIobj, Tuple[Union[T_TIobj, None], T_TIobj]], int], bool], @@ -457,7 +472,7 @@ class TraversableIterableObj(Traversable, IterableObj): ) -> Iterator[T_TIobj]: ... - @overload + @ overload def traverse(self: T_TIobj, predicate: Callable[[Union[T_TIobj, Tuple[Union[T_TIobj, None], T_TIobj]], int], bool], prune: Callable[[Union[T_TIobj, Tuple[Union[T_TIobj, None], T_TIobj]], int], bool], @@ -467,7 +482,7 @@ class TraversableIterableObj(Traversable, IterableObj): ) -> Iterator[Tuple[Union[T_TIobj, None], T_TIobj]]: ... - @overload + @ overload def traverse(self: T_TIobj, predicate: Callable[[Union[T_TIobj, TIobj_tuple], int], bool], prune: Callable[[Union[T_TIobj, TIobj_tuple], int], bool], |