summaryrefslogtreecommitdiff
path: root/django/contrib/postgres/fields/array.py
diff options
context:
space:
mode:
authorMarc Tamlyn <marc.tamlyn@gmail.com>2015-01-10 16:11:15 +0000
committerMarc Tamlyn <marc.tamlyn@gmail.com>2015-01-10 16:18:19 +0000
commit916e38802f151b34aaca487dc7e928946e81be73 (patch)
tree891bfce8727a3321b93850d9822250075dcb82e5 /django/contrib/postgres/fields/array.py
parent74f02557e0183812d6d60e2548985c5c40b3d27b (diff)
downloaddjango-916e38802f151b34aaca487dc7e928946e81be73.tar.gz
Move % addition to lookups, refactor postgres lookups.
These refactorings making overriding some text based lookup names on other fields (specifically `contains`) much cleaner. It also removes a bunch of duplication in the contrib.postgres lookups.
Diffstat (limited to 'django/contrib/postgres/fields/array.py')
-rw-r--r--django/contrib/postgres/fields/array.py50
1 files changed, 10 insertions, 40 deletions
diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py
index 65f3dc6f6a..318afabd2c 100644
--- a/django/contrib/postgres/fields/array.py
+++ b/django/contrib/postgres/fields/array.py
@@ -1,9 +1,10 @@
import json
+from django.contrib.postgres import lookups
from django.contrib.postgres.forms import SimpleArrayField
from django.contrib.postgres.validators import ArrayMaxLengthValidator
from django.core import checks, exceptions
-from django.db.models import Field, Lookup, Transform, IntegerField
+from django.db.models import Field, Transform, IntegerField
from django.utils import six
from django.utils.translation import string_concat, ugettext_lazy as _
@@ -74,12 +75,6 @@ class ArrayField(Field):
return [self.base_field.get_prep_value(i) for i in value]
return value
- def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
- if lookup_type == 'contains':
- return [self.get_prep_value(value)]
- return super(ArrayField, self).get_db_prep_lookup(lookup_type, value,
- connection, prepared=False)
-
def deconstruct(self):
name, path, args, kwargs = super(ArrayField, self).deconstruct()
if path == 'django.contrib.postgres.fields.array.ArrayField':
@@ -156,46 +151,21 @@ class ArrayField(Field):
@ArrayField.register_lookup
-class ArrayContainsLookup(Lookup):
- lookup_name = 'contains'
-
- def as_sql(self, compiler, connection):
- lhs, lhs_params = self.process_lhs(compiler, connection)
- rhs, rhs_params = self.process_rhs(compiler, connection)
- params = lhs_params + rhs_params
- type_cast = self.lhs.output_field.db_type(connection)
- return '%s @> %s::%s' % (lhs, rhs, type_cast), params
-
-
-@ArrayField.register_lookup
-class ArrayContainedByLookup(Lookup):
- lookup_name = 'contained_by'
-
- def as_sql(self, compiler, connection):
- lhs, lhs_params = self.process_lhs(compiler, connection)
- rhs, rhs_params = self.process_rhs(compiler, connection)
- params = lhs_params + rhs_params
- return '%s <@ %s' % (lhs, rhs), params
-
+class ArrayContains(lookups.DataContains):
+ def as_sql(self, qn, connection):
+ sql, params = super(ArrayContains, self).as_sql(qn, connection)
+ sql += '::%s' % self.lhs.output_field.db_type(connection)
+ return sql, params
-@ArrayField.register_lookup
-class ArrayOverlapLookup(Lookup):
- lookup_name = 'overlap'
- def as_sql(self, compiler, connection):
- lhs, lhs_params = self.process_lhs(compiler, connection)
- rhs, rhs_params = self.process_rhs(compiler, connection)
- params = lhs_params + rhs_params
- return '%s && %s' % (lhs, rhs), params
+ArrayField.register_lookup(lookups.ContainedBy)
+ArrayField.register_lookup(lookups.Overlap)
@ArrayField.register_lookup
class ArrayLenTransform(Transform):
lookup_name = 'len'
-
- @property
- def output_field(self):
- return IntegerField()
+ output_field = IntegerField()
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)