diff options
-rw-r--r-- | CREDITS | 1 | ||||
-rw-r--r-- | semantic_version/django_fields.py | 2 | ||||
-rw-r--r-- | tests/test_django.py | 39 |
3 files changed, 41 insertions, 1 deletions
@@ -20,6 +20,7 @@ The project has received contributions from (in alphabetical order): * Raphaƫl Barrois <raphael.barrois+semver@polytechnique.org> (https://github.com/rbarrois) * Michael Hrivnak <mhrivnak@hrivnak.org> (https://github.com/mhrivnak) * Rick Eyre <rick.eyre@outlook.com> (https://github.com/rickeyre) +* Hugo Rodger-Brown <hugo@yunojuno.com> (https://github.com/yunojuno) Contributor license agreement diff --git a/semantic_version/django_fields.py b/semantic_version/django_fields.py index a0c6979..740c208 100644 --- a/semantic_version/django_fields.py +++ b/semantic_version/django_fields.py @@ -18,7 +18,7 @@ class BaseSemVerField(models.CharField): super(BaseSemVerField, self).__init__(*args, **kwargs) def get_prep_value(self, obj): - return str(obj) + return None if obj is None else str(obj) def get_db_prep_value(self, value, connection, prepared=False): if not prepared: diff --git a/tests/test_django.py b/tests/test_django.py index 94e2420..959b19c 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -30,6 +30,13 @@ if django_loaded and django.VERSION < (1, 7): # pragma: no cover except ImportError: pass +# the refresh_from_db method only came in with 1.8, so in order to make this +# work will all supported versions we have our own function. +def save_and_refresh(obj): + """Saves an object, and refreshes from the database.""" + obj.save() + obj = obj.__class__.objects.get(id=obj.id) + @unittest.skipIf(not django_loaded, "Django not installed") class DjangoFieldTestCase(unittest.TestCase): @@ -48,6 +55,38 @@ class DjangoFieldTestCase(unittest.TestCase): obj.full_clean() + def test_version_save(self): + """Test saving object with a VersionField.""" + # first test with a null value + obj = models.PartialVersionModel() + self.assertIsNone(obj.id) + self.assertIsNone(obj.optional) + save_and_refresh(obj) + self.assertIsNotNone(obj.id) + self.assertIsNone(obj.optional_spec) + + # now set to something that is not null + spec = semantic_version.Spec('==0,!=0.2') + obj.optional_spec = spec + save_and_refresh(obj) + self.assertEqual(obj.optional_spec, spec) + + def test_spec_save(self): + """Test saving object with a SpecField.""" + # first test with a null value + obj = models.PartialVersionModel() + self.assertIsNone(obj.id) + self.assertIsNone(obj.optional_spec) + save_and_refresh(obj) + self.assertIsNotNone(obj.id) + self.assertIsNone(obj.optional_spec) + + # now set to something that is not null + spec = semantic_version.Spec('==0,!=0.2') + obj.optional_spec = spec + save_and_refresh(obj) + self.assertEqual(obj.optional_spec, spec) + def test_partial_spec(self): obj = models.VersionModel(version='0.1.1', spec='==0,!=0.2') self.assertEqual(semantic_version.Version('0.1.1'), obj.version) |