summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CREDITS1
-rw-r--r--semantic_version/django_fields.py2
-rw-r--r--tests/test_django.py45
3 files changed, 47 insertions, 1 deletions
diff --git a/CREDITS b/CREDITS
index c700c77..3efe47a 100644
--- a/CREDITS
+++ b/CREDITS
@@ -20,6 +20,7 @@ The project has received contributions from (in alphabetical order):
* Raphaƫl Barrois <raphael.barrois+semver@polytechnique.org> (https://github.com/rbarrois)
* Michael Hrivnak <mhrivnak@hrivnak.org> (https://github.com/mhrivnak)
* Rick Eyre <rick.eyre@outlook.com> (https://github.com/rickeyre)
+* Hugo Rodger-Brown <hugo@yunojuno.com> (https://github.com/yunojuno)
Contributor license agreement
diff --git a/semantic_version/django_fields.py b/semantic_version/django_fields.py
index a0c6979..740c208 100644
--- a/semantic_version/django_fields.py
+++ b/semantic_version/django_fields.py
@@ -18,7 +18,7 @@ class BaseSemVerField(models.CharField):
super(BaseSemVerField, self).__init__(*args, **kwargs)
def get_prep_value(self, obj):
- return str(obj)
+ return None if obj is None else str(obj)
def get_db_prep_value(self, value, connection, prepared=False):
if not prepared:
diff --git a/tests/test_django.py b/tests/test_django.py
index 94e2420..8646764 100644
--- a/tests/test_django.py
+++ b/tests/test_django.py
@@ -30,6 +30,11 @@ if django_loaded and django.VERSION < (1, 7): # pragma: no cover
except ImportError:
pass
+# the refresh_from_db method only came in with 1.8, so in order to make this
+# work will all supported versions we have our own function
+def refresh_from_db(obj):
+ return obj.__class__.objects.get(id=obj.id)
+
@unittest.skipIf(not django_loaded, "Django not installed")
class DjangoFieldTestCase(unittest.TestCase):
@@ -48,6 +53,46 @@ class DjangoFieldTestCase(unittest.TestCase):
obj.full_clean()
+ def test_version_save(self):
+ """Test saving object with a VersionField."""
+ # first test with a null value
+ obj = models.PartialVersionModel()
+ self.assertIsNone(obj.id)
+ self.assertIsNone(obj.optional)
+ obj.save()
+
+ # now retrieve from db
+ obj = refresh_from_db(obj)
+ self.assertIsNotNone(obj.id)
+ self.assertIsNone(obj.optional_spec)
+
+ # now set to something that is not null
+ spec = semantic_version.Spec('==0,!=0.2')
+ obj.optional_spec = spec
+ obj.save()
+ obj = refresh_from_db(obj)
+ self.assertEqual(obj.optional_spec, spec)
+
+ def test_spec_save(self):
+ """Test saving object with a SpecField."""
+ # first test with a null value
+ obj = models.PartialVersionModel()
+ self.assertIsNone(obj.id)
+ self.assertIsNone(obj.optional_spec)
+ obj.save()
+
+ # now retrieve from db
+ obj.refresh_from_db()
+ self.assertIsNotNone(obj.id)
+ self.assertIsNone(obj.optional_spec)
+
+ # now set to something that is not null
+ spec = semantic_version.Spec('==0,!=0.2')
+ obj.optional_spec = spec
+ obj.save()
+ obj.refresh_from_db()
+ self.assertEqual(obj.optional_spec, spec)
+
def test_partial_spec(self):
obj = models.VersionModel(version='0.1.1', spec='==0,!=0.2')
self.assertEqual(semantic_version.Version('0.1.1'), obj.version)