diff options
author | django-bot <ops@djangoproject.com> | 2022-02-03 20:24:19 +0100 |
---|---|---|
committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2022-02-07 20:37:05 +0100 |
commit | 9c19aff7c7561e3a82978a272ecdaad40dda5c00 (patch) | |
tree | f0506b668a013d0063e5fba3dbf4863b466713ba /django/contrib/postgres/fields/array.py | |
parent | f68fa8b45dfac545cfc4111d4e52804c86db68d3 (diff) | |
download | django-9c19aff7c7561e3a82978a272ecdaad40dda5c00.tar.gz |
Refs #33476 -- Reformatted code with Black.
Diffstat (limited to 'django/contrib/postgres/fields/array.py')
-rw-r--r-- | django/contrib/postgres/fields/array.py | 128 |
1 files changed, 69 insertions, 59 deletions
diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py index 9c1bb96b61..7269198674 100644 --- a/django/contrib/postgres/fields/array.py +++ b/django/contrib/postgres/fields/array.py @@ -12,38 +12,43 @@ from django.utils.translation import gettext_lazy as _ from ..utils import prefix_validation_error from .utils import AttributeSetter -__all__ = ['ArrayField'] +__all__ = ["ArrayField"] class ArrayField(CheckFieldDefaultMixin, Field): empty_strings_allowed = False default_error_messages = { - 'item_invalid': _('Item %(nth)s in the array did not validate:'), - 'nested_array_mismatch': _('Nested arrays must have the same length.'), + "item_invalid": _("Item %(nth)s in the array did not validate:"), + "nested_array_mismatch": _("Nested arrays must have the same length."), } - _default_hint = ('list', '[]') + _default_hint = ("list", "[]") def __init__(self, base_field, size=None, **kwargs): self.base_field = base_field self.size = size if self.size: - self.default_validators = [*self.default_validators, ArrayMaxLengthValidator(self.size)] + self.default_validators = [ + *self.default_validators, + ArrayMaxLengthValidator(self.size), + ] # For performance, only add a from_db_value() method if the base field # implements it. - if hasattr(self.base_field, 'from_db_value'): + if hasattr(self.base_field, "from_db_value"): self.from_db_value = self._from_db_value super().__init__(**kwargs) @property def model(self): try: - return self.__dict__['model'] + return self.__dict__["model"] except KeyError: - raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__) + raise AttributeError( + "'%s' object has no attribute 'model'" % self.__class__.__name__ + ) @model.setter def model(self, model): - self.__dict__['model'] = model + self.__dict__["model"] = model self.base_field.model = model @classmethod @@ -55,21 +60,23 @@ class ArrayField(CheckFieldDefaultMixin, Field): if self.base_field.remote_field: errors.append( checks.Error( - 'Base field for array cannot be a related field.', + "Base field for array cannot be a related field.", obj=self, - id='postgres.E002' + id="postgres.E002", ) ) else: # Remove the field name checks as they are not needed here. base_errors = self.base_field.check() if base_errors: - messages = '\n '.join('%s (%s)' % (error.msg, error.id) for error in base_errors) + messages = "\n ".join( + "%s (%s)" % (error.msg, error.id) for error in base_errors + ) errors.append( checks.Error( - 'Base field for array has errors:\n %s' % messages, + "Base field for array has errors:\n %s" % messages, obj=self, - id='postgres.E001' + id="postgres.E001", ) ) return errors @@ -80,32 +87,37 @@ class ArrayField(CheckFieldDefaultMixin, Field): @property def description(self): - return 'Array of %s' % self.base_field.description + return "Array of %s" % self.base_field.description def db_type(self, connection): - size = self.size or '' - return '%s[%s]' % (self.base_field.db_type(connection), size) + size = self.size or "" + return "%s[%s]" % (self.base_field.db_type(connection), size) def cast_db_type(self, connection): - size = self.size or '' - return '%s[%s]' % (self.base_field.cast_db_type(connection), size) + size = self.size or "" + return "%s[%s]" % (self.base_field.cast_db_type(connection), size) def get_placeholder(self, value, compiler, connection): - return '%s::{}'.format(self.db_type(connection)) + return "%s::{}".format(self.db_type(connection)) def get_db_prep_value(self, value, connection, prepared=False): if isinstance(value, (list, tuple)): - return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value] + return [ + self.base_field.get_db_prep_value(i, connection, prepared=False) + for i in value + ] return value def deconstruct(self): name, path, args, kwargs = super().deconstruct() - if path == 'django.contrib.postgres.fields.array.ArrayField': - path = 'django.contrib.postgres.fields.ArrayField' - kwargs.update({ - 'base_field': self.base_field.clone(), - 'size': self.size, - }) + if path == "django.contrib.postgres.fields.array.ArrayField": + path = "django.contrib.postgres.fields.ArrayField" + kwargs.update( + { + "base_field": self.base_field.clone(), + "size": self.size, + } + ) return name, path, args, kwargs def to_python(self, value): @@ -140,7 +152,7 @@ class ArrayField(CheckFieldDefaultMixin, Field): transform = super().get_transform(name) if transform: return transform - if '_' not in name: + if "_" not in name: try: index = int(name) except ValueError: @@ -149,7 +161,7 @@ class ArrayField(CheckFieldDefaultMixin, Field): index += 1 # postgres uses 1-indexing return IndexTransformFactory(index, self.base_field) try: - start, end = name.split('_') + start, end = name.split("_") start = int(start) + 1 end = int(end) # don't add one here because postgres slices are weird except ValueError: @@ -165,15 +177,15 @@ class ArrayField(CheckFieldDefaultMixin, Field): except exceptions.ValidationError as error: raise prefix_validation_error( error, - prefix=self.error_messages['item_invalid'], - code='item_invalid', - params={'nth': index + 1}, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, ) if isinstance(self.base_field, ArrayField): if len({len(i) for i in value}) > 1: raise exceptions.ValidationError( - self.error_messages['nested_array_mismatch'], - code='nested_array_mismatch', + self.error_messages["nested_array_mismatch"], + code="nested_array_mismatch", ) def run_validators(self, value): @@ -184,18 +196,20 @@ class ArrayField(CheckFieldDefaultMixin, Field): except exceptions.ValidationError as error: raise prefix_validation_error( error, - prefix=self.error_messages['item_invalid'], - code='item_invalid', - params={'nth': index + 1}, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, ) def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': SimpleArrayField, - 'base_field': self.base_field.formfield(), - 'max_length': self.size, - **kwargs, - }) + return super().formfield( + **{ + "form_class": SimpleArrayField, + "base_field": self.base_field.formfield(), + "max_length": self.size, + **kwargs, + } + ) class ArrayRHSMixin: @@ -203,21 +217,21 @@ class ArrayRHSMixin: if isinstance(rhs, (tuple, list)): expressions = [] for value in rhs: - if not hasattr(value, 'resolve_expression'): + if not hasattr(value, "resolve_expression"): field = lhs.output_field value = Value(field.base_field.get_prep_value(value)) expressions.append(value) rhs = Func( *expressions, - function='ARRAY', - template='%(function)s[%(expressions)s]', + function="ARRAY", + template="%(function)s[%(expressions)s]", ) super().__init__(lhs, rhs) def process_rhs(self, compiler, connection): rhs, rhs_params = super().process_rhs(compiler, connection) cast_type = self.lhs.output_field.cast_db_type(connection) - return '%s::%s' % (rhs, cast_type), rhs_params + return "%s::%s" % (rhs, cast_type), rhs_params @ArrayField.register_lookup @@ -242,29 +256,29 @@ class ArrayOverlap(ArrayRHSMixin, lookups.Overlap): @ArrayField.register_lookup class ArrayLenTransform(Transform): - lookup_name = 'len' + lookup_name = "len" output_field = IntegerField() def as_sql(self, compiler, connection): lhs, params = compiler.compile(self.lhs) # Distinguish NULL and empty arrays return ( - 'CASE WHEN %(lhs)s IS NULL THEN NULL ELSE ' - 'coalesce(array_length(%(lhs)s, 1), 0) END' - ) % {'lhs': lhs}, params + "CASE WHEN %(lhs)s IS NULL THEN NULL ELSE " + "coalesce(array_length(%(lhs)s, 1), 0) END" + ) % {"lhs": lhs}, params @ArrayField.register_lookup class ArrayInLookup(In): def get_prep_lookup(self): values = super().get_prep_lookup() - if hasattr(values, 'resolve_expression'): + if hasattr(values, "resolve_expression"): return values # In.process_rhs() expects values to be hashable, so convert lists # to tuples. prepared_values = [] for value in values: - if hasattr(value, 'resolve_expression'): + if hasattr(value, "resolve_expression"): prepared_values.append(value) else: prepared_values.append(tuple(value)) @@ -272,7 +286,6 @@ class ArrayInLookup(In): class IndexTransform(Transform): - def __init__(self, index, base_field, *args, **kwargs): super().__init__(*args, **kwargs) self.index = index @@ -280,7 +293,7 @@ class IndexTransform(Transform): def as_sql(self, compiler, connection): lhs, params = compiler.compile(self.lhs) - return '%s[%%s]' % lhs, params + [self.index] + return "%s[%%s]" % lhs, params + [self.index] @property def output_field(self): @@ -288,7 +301,6 @@ class IndexTransform(Transform): class IndexTransformFactory: - def __init__(self, index, base_field): self.index = index self.base_field = base_field @@ -298,7 +310,6 @@ class IndexTransformFactory: class SliceTransform(Transform): - def __init__(self, start, end, *args, **kwargs): super().__init__(*args, **kwargs) self.start = start @@ -306,11 +317,10 @@ class SliceTransform(Transform): def as_sql(self, compiler, connection): lhs, params = compiler.compile(self.lhs) - return '%s[%%s:%%s]' % lhs, params + [self.start, self.end] + return "%s[%%s:%%s]" % lhs, params + [self.start, self.end] class SliceTransformFactory: - def __init__(self, start, end): self.start = start self.end = end |