diff options
Diffstat (limited to 'django/contrib/postgres/fields/jsonb.py')
-rw-r--r-- | django/contrib/postgres/fields/jsonb.py | 99 |
1 files changed, 99 insertions, 0 deletions
diff --git a/django/contrib/postgres/fields/jsonb.py b/django/contrib/postgres/fields/jsonb.py new file mode 100644 index 0000000000..eaf3e98d36 --- /dev/null +++ b/django/contrib/postgres/fields/jsonb.py @@ -0,0 +1,99 @@ +import json + +from psycopg2.extras import Json + +from django.contrib.postgres import forms, lookups +from django.core import exceptions +from django.db.models import Field, Transform +from django.utils.translation import ugettext_lazy as _ + +__all__ = ['JSONField'] + + +class JSONField(Field): + empty_strings_allowed = False + description = _('A JSON object') + default_error_messages = { + 'invalid': _("Value must be valid JSON."), + } + + def db_type(self, connection): + return 'jsonb' + + def get_transform(self, name): + transform = super(JSONField, self).get_transform(name) + if transform: + return transform + return KeyTransformFactory(name) + + def get_prep_value(self, value): + if value is not None: + return Json(value) + return value + + def get_prep_lookup(self, lookup_type, value): + if lookup_type in ('has_key', 'has_keys', 'has_any_keys'): + return value + if isinstance(value, (dict, list)): + return Json(value) + return super(JSONField, self).get_prep_lookup(lookup_type, value) + + def validate(self, value, model_instance): + super(JSONField, self).validate(value, model_instance) + try: + json.dumps(value) + except TypeError: + raise exceptions.ValidationError( + self.error_messages['invalid'], + code='invalid', + params={'value': value}, + ) + + def value_to_string(self, obj): + value = self._get_val_from_obj(obj) + return value + + def formfield(self, **kwargs): + defaults = {'form_class': forms.JSONField} + defaults.update(kwargs) + return super(JSONField, self).formfield(**defaults) + + +JSONField.register_lookup(lookups.DataContains) +JSONField.register_lookup(lookups.ContainedBy) +JSONField.register_lookup(lookups.HasKey) +JSONField.register_lookup(lookups.HasKeys) +JSONField.register_lookup(lookups.HasAnyKeys) + + +class KeyTransform(Transform): + + def __init__(self, key_name, *args, **kwargs): + super(KeyTransform, self).__init__(*args, **kwargs) + self.key_name = key_name + + def as_sql(self, compiler, connection): + key_transforms = [self.key_name] + previous = self.lhs + while isinstance(previous, KeyTransform): + key_transforms.insert(0, previous.key_name) + previous = previous.lhs + lhs, params = compiler.compile(previous) + if len(key_transforms) > 1: + return "{} #> %s".format(lhs), [key_transforms] + params + try: + int(self.key_name) + except ValueError: + lookup = "'%s'" % self.key_name + else: + lookup = "%s" % self.key_name + return "%s -> %s" % (lhs, lookup), params + + +class KeyTransformFactory(object): + + def __init__(self, key_name): + self.key_name = key_name + + def __call__(self, *args, **kwargs): + return KeyTransform(self.key_name, *args, **kwargs) |