summaryrefslogtreecommitdiff
path: root/django/contrib/postgres/fields/array.py
diff options
context:
space:
mode:
authordjango-bot <ops@djangoproject.com>2022-02-03 20:24:19 +0100
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2022-02-07 20:37:05 +0100
commit9c19aff7c7561e3a82978a272ecdaad40dda5c00 (patch)
treef0506b668a013d0063e5fba3dbf4863b466713ba /django/contrib/postgres/fields/array.py
parentf68fa8b45dfac545cfc4111d4e52804c86db68d3 (diff)
downloaddjango-9c19aff7c7561e3a82978a272ecdaad40dda5c00.tar.gz
Refs #33476 -- Reformatted code with Black.
Diffstat (limited to 'django/contrib/postgres/fields/array.py')
-rw-r--r--django/contrib/postgres/fields/array.py128
1 files changed, 69 insertions, 59 deletions
diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py
index 9c1bb96b61..7269198674 100644
--- a/django/contrib/postgres/fields/array.py
+++ b/django/contrib/postgres/fields/array.py
@@ -12,38 +12,43 @@ from django.utils.translation import gettext_lazy as _
from ..utils import prefix_validation_error
from .utils import AttributeSetter
-__all__ = ['ArrayField']
+__all__ = ["ArrayField"]
class ArrayField(CheckFieldDefaultMixin, Field):
empty_strings_allowed = False
default_error_messages = {
- 'item_invalid': _('Item %(nth)s in the array did not validate:'),
- 'nested_array_mismatch': _('Nested arrays must have the same length.'),
+ "item_invalid": _("Item %(nth)s in the array did not validate:"),
+ "nested_array_mismatch": _("Nested arrays must have the same length."),
}
- _default_hint = ('list', '[]')
+ _default_hint = ("list", "[]")
def __init__(self, base_field, size=None, **kwargs):
self.base_field = base_field
self.size = size
if self.size:
- self.default_validators = [*self.default_validators, ArrayMaxLengthValidator(self.size)]
+ self.default_validators = [
+ *self.default_validators,
+ ArrayMaxLengthValidator(self.size),
+ ]
# For performance, only add a from_db_value() method if the base field
# implements it.
- if hasattr(self.base_field, 'from_db_value'):
+ if hasattr(self.base_field, "from_db_value"):
self.from_db_value = self._from_db_value
super().__init__(**kwargs)
@property
def model(self):
try:
- return self.__dict__['model']
+ return self.__dict__["model"]
except KeyError:
- raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)
+ raise AttributeError(
+ "'%s' object has no attribute 'model'" % self.__class__.__name__
+ )
@model.setter
def model(self, model):
- self.__dict__['model'] = model
+ self.__dict__["model"] = model
self.base_field.model = model
@classmethod
@@ -55,21 +60,23 @@ class ArrayField(CheckFieldDefaultMixin, Field):
if self.base_field.remote_field:
errors.append(
checks.Error(
- 'Base field for array cannot be a related field.',
+ "Base field for array cannot be a related field.",
obj=self,
- id='postgres.E002'
+ id="postgres.E002",
)
)
else:
# Remove the field name checks as they are not needed here.
base_errors = self.base_field.check()
if base_errors:
- messages = '\n '.join('%s (%s)' % (error.msg, error.id) for error in base_errors)
+ messages = "\n ".join(
+ "%s (%s)" % (error.msg, error.id) for error in base_errors
+ )
errors.append(
checks.Error(
- 'Base field for array has errors:\n %s' % messages,
+ "Base field for array has errors:\n %s" % messages,
obj=self,
- id='postgres.E001'
+ id="postgres.E001",
)
)
return errors
@@ -80,32 +87,37 @@ class ArrayField(CheckFieldDefaultMixin, Field):
@property
def description(self):
- return 'Array of %s' % self.base_field.description
+ return "Array of %s" % self.base_field.description
def db_type(self, connection):
- size = self.size or ''
- return '%s[%s]' % (self.base_field.db_type(connection), size)
+ size = self.size or ""
+ return "%s[%s]" % (self.base_field.db_type(connection), size)
def cast_db_type(self, connection):
- size = self.size or ''
- return '%s[%s]' % (self.base_field.cast_db_type(connection), size)
+ size = self.size or ""
+ return "%s[%s]" % (self.base_field.cast_db_type(connection), size)
def get_placeholder(self, value, compiler, connection):
- return '%s::{}'.format(self.db_type(connection))
+ return "%s::{}".format(self.db_type(connection))
def get_db_prep_value(self, value, connection, prepared=False):
if isinstance(value, (list, tuple)):
- return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value]
+ return [
+ self.base_field.get_db_prep_value(i, connection, prepared=False)
+ for i in value
+ ]
return value
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
- if path == 'django.contrib.postgres.fields.array.ArrayField':
- path = 'django.contrib.postgres.fields.ArrayField'
- kwargs.update({
- 'base_field': self.base_field.clone(),
- 'size': self.size,
- })
+ if path == "django.contrib.postgres.fields.array.ArrayField":
+ path = "django.contrib.postgres.fields.ArrayField"
+ kwargs.update(
+ {
+ "base_field": self.base_field.clone(),
+ "size": self.size,
+ }
+ )
return name, path, args, kwargs
def to_python(self, value):
@@ -140,7 +152,7 @@ class ArrayField(CheckFieldDefaultMixin, Field):
transform = super().get_transform(name)
if transform:
return transform
- if '_' not in name:
+ if "_" not in name:
try:
index = int(name)
except ValueError:
@@ -149,7 +161,7 @@ class ArrayField(CheckFieldDefaultMixin, Field):
index += 1 # postgres uses 1-indexing
return IndexTransformFactory(index, self.base_field)
try:
- start, end = name.split('_')
+ start, end = name.split("_")
start = int(start) + 1
end = int(end) # don't add one here because postgres slices are weird
except ValueError:
@@ -165,15 +177,15 @@ class ArrayField(CheckFieldDefaultMixin, Field):
except exceptions.ValidationError as error:
raise prefix_validation_error(
error,
- prefix=self.error_messages['item_invalid'],
- code='item_invalid',
- params={'nth': index + 1},
+ prefix=self.error_messages["item_invalid"],
+ code="item_invalid",
+ params={"nth": index + 1},
)
if isinstance(self.base_field, ArrayField):
if len({len(i) for i in value}) > 1:
raise exceptions.ValidationError(
- self.error_messages['nested_array_mismatch'],
- code='nested_array_mismatch',
+ self.error_messages["nested_array_mismatch"],
+ code="nested_array_mismatch",
)
def run_validators(self, value):
@@ -184,18 +196,20 @@ class ArrayField(CheckFieldDefaultMixin, Field):
except exceptions.ValidationError as error:
raise prefix_validation_error(
error,
- prefix=self.error_messages['item_invalid'],
- code='item_invalid',
- params={'nth': index + 1},
+ prefix=self.error_messages["item_invalid"],
+ code="item_invalid",
+ params={"nth": index + 1},
)
def formfield(self, **kwargs):
- return super().formfield(**{
- 'form_class': SimpleArrayField,
- 'base_field': self.base_field.formfield(),
- 'max_length': self.size,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "form_class": SimpleArrayField,
+ "base_field": self.base_field.formfield(),
+ "max_length": self.size,
+ **kwargs,
+ }
+ )
class ArrayRHSMixin:
@@ -203,21 +217,21 @@ class ArrayRHSMixin:
if isinstance(rhs, (tuple, list)):
expressions = []
for value in rhs:
- if not hasattr(value, 'resolve_expression'):
+ if not hasattr(value, "resolve_expression"):
field = lhs.output_field
value = Value(field.base_field.get_prep_value(value))
expressions.append(value)
rhs = Func(
*expressions,
- function='ARRAY',
- template='%(function)s[%(expressions)s]',
+ function="ARRAY",
+ template="%(function)s[%(expressions)s]",
)
super().__init__(lhs, rhs)
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
cast_type = self.lhs.output_field.cast_db_type(connection)
- return '%s::%s' % (rhs, cast_type), rhs_params
+ return "%s::%s" % (rhs, cast_type), rhs_params
@ArrayField.register_lookup
@@ -242,29 +256,29 @@ class ArrayOverlap(ArrayRHSMixin, lookups.Overlap):
@ArrayField.register_lookup
class ArrayLenTransform(Transform):
- lookup_name = 'len'
+ lookup_name = "len"
output_field = IntegerField()
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
# Distinguish NULL and empty arrays
return (
- 'CASE WHEN %(lhs)s IS NULL THEN NULL ELSE '
- 'coalesce(array_length(%(lhs)s, 1), 0) END'
- ) % {'lhs': lhs}, params
+ "CASE WHEN %(lhs)s IS NULL THEN NULL ELSE "
+ "coalesce(array_length(%(lhs)s, 1), 0) END"
+ ) % {"lhs": lhs}, params
@ArrayField.register_lookup
class ArrayInLookup(In):
def get_prep_lookup(self):
values = super().get_prep_lookup()
- if hasattr(values, 'resolve_expression'):
+ if hasattr(values, "resolve_expression"):
return values
# In.process_rhs() expects values to be hashable, so convert lists
# to tuples.
prepared_values = []
for value in values:
- if hasattr(value, 'resolve_expression'):
+ if hasattr(value, "resolve_expression"):
prepared_values.append(value)
else:
prepared_values.append(tuple(value))
@@ -272,7 +286,6 @@ class ArrayInLookup(In):
class IndexTransform(Transform):
-
def __init__(self, index, base_field, *args, **kwargs):
super().__init__(*args, **kwargs)
self.index = index
@@ -280,7 +293,7 @@ class IndexTransform(Transform):
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
- return '%s[%%s]' % lhs, params + [self.index]
+ return "%s[%%s]" % lhs, params + [self.index]
@property
def output_field(self):
@@ -288,7 +301,6 @@ class IndexTransform(Transform):
class IndexTransformFactory:
-
def __init__(self, index, base_field):
self.index = index
self.base_field = base_field
@@ -298,7 +310,6 @@ class IndexTransformFactory:
class SliceTransform(Transform):
-
def __init__(self, start, end, *args, **kwargs):
super().__init__(*args, **kwargs)
self.start = start
@@ -306,11 +317,10 @@ class SliceTransform(Transform):
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
- return '%s[%%s:%%s]' % lhs, params + [self.start, self.end]
+ return "%s[%%s:%%s]" % lhs, params + [self.start, self.end]
class SliceTransformFactory:
-
def __init__(self, start, end):
self.start = start
self.end = end