summaryrefslogtreecommitdiff
path: root/django/contrib/postgres/fields/array.py
diff options
context:
space:
mode:
authorMariusz Felisiak <felisiak.mariusz@gmail.com>2019-08-22 11:29:56 +0200
committerCarlton Gibson <carlton@noumenal.es>2019-08-23 10:43:08 +0200
commitb1f669406ff82fcd5fed9f20be258944bd3b3bf9 (patch)
treeaba36fbe2ff6498e07ab5e692fb44a45a2f23fa3 /django/contrib/postgres/fields/array.py
parent33b9b23bbb6c373843ef184d24438c1bdff37c82 (diff)
downloaddjango-b1f669406ff82fcd5fed9f20be258944bd3b3bf9.tar.gz
Reduced code duplication in ArrayField's lookups.
Diffstat (limited to 'django/contrib/postgres/fields/array.py')
-rw-r--r--django/contrib/postgres/fields/array.py35
1 files changed, 15 insertions, 20 deletions
diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py
index f85a280b61..73944056d5 100644
--- a/django/contrib/postgres/fields/array.py
+++ b/django/contrib/postgres/fields/array.py
@@ -190,36 +190,31 @@ class ArrayField(CheckFieldDefaultMixin, Field):
})
+class ArrayCastRHSMixin:
+ def process_rhs(self, compiler, connection):
+ rhs, rhs_params = super().process_rhs(compiler, connection)
+ cast_type = self.lhs.output_field.db_type(connection)
+ return '%s::%s' % (rhs, cast_type), rhs_params
+
+
@ArrayField.register_lookup
-class ArrayContains(lookups.DataContains):
- def as_sql(self, qn, connection):
- sql, params = super().as_sql(qn, connection)
- sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
- return sql, params
+class ArrayContains(ArrayCastRHSMixin, lookups.DataContains):
+ pass
@ArrayField.register_lookup
-class ArrayContainedBy(lookups.ContainedBy):
- def as_sql(self, qn, connection):
- sql, params = super().as_sql(qn, connection)
- sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
- return sql, params
+class ArrayContainedBy(ArrayCastRHSMixin, lookups.ContainedBy):
+ pass
@ArrayField.register_lookup
-class ArrayExact(Exact):
- def as_sql(self, qn, connection):
- sql, params = super().as_sql(qn, connection)
- sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
- return sql, params
+class ArrayExact(ArrayCastRHSMixin, Exact):
+ pass
@ArrayField.register_lookup
-class ArrayOverlap(lookups.Overlap):
- def as_sql(self, qn, connection):
- sql, params = super().as_sql(qn, connection)
- sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
- return sql, params
+class ArrayOverlap(ArrayCastRHSMixin, lookups.Overlap):
+ pass
@ArrayField.register_lookup