summaryrefslogtreecommitdiff
path: root/django/contrib/postgres/fields/jsonb.py
diff options
context:
space:
mode:
Diffstat (limited to 'django/contrib/postgres/fields/jsonb.py')
-rw-r--r--django/contrib/postgres/fields/jsonb.py99
1 files changed, 99 insertions, 0 deletions
diff --git a/django/contrib/postgres/fields/jsonb.py b/django/contrib/postgres/fields/jsonb.py
new file mode 100644
index 0000000000..eaf3e98d36
--- /dev/null
+++ b/django/contrib/postgres/fields/jsonb.py
@@ -0,0 +1,99 @@
+import json
+
+from psycopg2.extras import Json
+
+from django.contrib.postgres import forms, lookups
+from django.core import exceptions
+from django.db.models import Field, Transform
+from django.utils.translation import ugettext_lazy as _
+
+__all__ = ['JSONField']
+
+
+class JSONField(Field):
+ empty_strings_allowed = False
+ description = _('A JSON object')
+ default_error_messages = {
+ 'invalid': _("Value must be valid JSON."),
+ }
+
+ def db_type(self, connection):
+ return 'jsonb'
+
+ def get_transform(self, name):
+ transform = super(JSONField, self).get_transform(name)
+ if transform:
+ return transform
+ return KeyTransformFactory(name)
+
+ def get_prep_value(self, value):
+ if value is not None:
+ return Json(value)
+ return value
+
+ def get_prep_lookup(self, lookup_type, value):
+ if lookup_type in ('has_key', 'has_keys', 'has_any_keys'):
+ return value
+ if isinstance(value, (dict, list)):
+ return Json(value)
+ return super(JSONField, self).get_prep_lookup(lookup_type, value)
+
+ def validate(self, value, model_instance):
+ super(JSONField, self).validate(value, model_instance)
+ try:
+ json.dumps(value)
+ except TypeError:
+ raise exceptions.ValidationError(
+ self.error_messages['invalid'],
+ code='invalid',
+ params={'value': value},
+ )
+
+ def value_to_string(self, obj):
+ value = self._get_val_from_obj(obj)
+ return value
+
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.JSONField}
+ defaults.update(kwargs)
+ return super(JSONField, self).formfield(**defaults)
+
+
+JSONField.register_lookup(lookups.DataContains)
+JSONField.register_lookup(lookups.ContainedBy)
+JSONField.register_lookup(lookups.HasKey)
+JSONField.register_lookup(lookups.HasKeys)
+JSONField.register_lookup(lookups.HasAnyKeys)
+
+
+class KeyTransform(Transform):
+
+ def __init__(self, key_name, *args, **kwargs):
+ super(KeyTransform, self).__init__(*args, **kwargs)
+ self.key_name = key_name
+
+ def as_sql(self, compiler, connection):
+ key_transforms = [self.key_name]
+ previous = self.lhs
+ while isinstance(previous, KeyTransform):
+ key_transforms.insert(0, previous.key_name)
+ previous = previous.lhs
+ lhs, params = compiler.compile(previous)
+ if len(key_transforms) > 1:
+ return "{} #> %s".format(lhs), [key_transforms] + params
+ try:
+ int(self.key_name)
+ except ValueError:
+ lookup = "'%s'" % self.key_name
+ else:
+ lookup = "%s" % self.key_name
+ return "%s -> %s" % (lhs, lookup), params
+
+
+class KeyTransformFactory(object):
+
+ def __init__(self, key_name):
+ self.key_name = key_name
+
+ def __call__(self, *args, **kwargs):
+ return KeyTransform(self.key_name, *args, **kwargs)