From 2cbbee3154d9011cee873ae3a020cd17c669f6df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Barrois?= Date: Tue, 28 Feb 2023 11:51:59 +0100 Subject: Simplify subclassing Version Fixes: #112 --- ChangeLog | 6 ++++++ semantic_version/base.py | 24 ++++++++++++------------ tests/test_base.py | 9 +++++++++ 3 files changed, 27 insertions(+), 12 deletions(-) diff --git a/ChangeLog b/ChangeLog index 6c0f593..2bf2e2c 100644 --- a/ChangeLog +++ b/ChangeLog @@ -4,6 +4,12 @@ ChangeLog 2.10.1 (unreleased) ------------------- +*Minor:* + + * `112 `_: + Functions returning a new ``Version`` instance reuse the current class, + helping with subclassing. + *Bugfix:* * `141 `_: diff --git a/semantic_version/base.py b/semantic_version/base.py index 1c10155..6be5624 100644 --- a/semantic_version/base.py +++ b/semantic_version/base.py @@ -132,14 +132,14 @@ class Version(object): def next_major(self): if self.prerelease and self.minor == self.patch == 0: - return Version( + return self.__class__( major=self.major, minor=0, patch=0, partial=self.partial, ) else: - return Version( + return self.__class__( major=self.major + 1, minor=0, patch=0, @@ -148,14 +148,14 @@ class Version(object): def next_minor(self): if self.prerelease and self.patch == 0: - return Version( + return self.__class__( major=self.major, minor=self.minor, patch=0, partial=self.partial, ) else: - return Version( + return self.__class__( major=self.major, minor=self.minor + 1, patch=0, @@ -164,14 +164,14 @@ class Version(object): def next_patch(self): if self.prerelease: - return Version( + return self.__class__( major=self.major, minor=self.minor, patch=self.patch, partial=self.partial, ) else: - return Version( + return self.__class__( major=self.major, minor=self.minor, patch=self.patch + 1, @@ -181,7 +181,7 @@ class Version(object): def truncate(self, level='patch'): """Return a new Version object, truncated up to the selected level.""" if level == 'build': - return Version( + return self.__class__( major=self.major, minor=self.minor, patch=self.patch, @@ -190,7 +190,7 @@ class Version(object): partial=self.partial, ) elif level == 'prerelease': - return Version( + return self.__class__( major=self.major, minor=self.minor, patch=self.patch, @@ -198,21 +198,21 @@ class Version(object): partial=self.partial, ) elif level == 'patch': - return Version( + return self.__class__( major=self.major, minor=self.minor, patch=self.patch, partial=self.partial, ) elif level == 'minor': - return Version( + return self.__class__( major=self.major, minor=self.minor, patch=None if self.partial else 0, partial=self.partial, ) elif level == 'major': - return Version( + return self.__class__( major=self.major, minor=None if self.partial else 0, patch=None if self.partial else 0, @@ -266,7 +266,7 @@ class Version(object): ) if match.end() == len(version_string): - return Version(version, partial=partial) + return cls(version, partial=partial) rest = version_string[match.end():] diff --git a/tests/test_base.py b/tests/test_base.py index 73d1b08..4136045 100755 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -405,6 +405,15 @@ class VersionTestCase(unittest.TestCase): self.assertEqual(v.truncate("minor"), base.Version("3.2.0")) self.assertEqual(v.truncate("major"), base.Version("3.0.0")) + def test_subclass(self): + """Custom subclasses of Version returns instances of themselves.""" + class MyVersion(base.Version): + pass + + v = MyVersion("3.2.1-pre") + subv = v.truncate() + self.assertEqual(type(subv), MyVersion) + class SpecItemTestCase(unittest.TestCase): if sys.version_info[0] <= 2: -- cgit v1.2.1