diff options
author | Marc Tamlyn <marc.tamlyn@gmail.com> | 2015-01-10 16:11:15 +0000 |
---|---|---|
committer | Marc Tamlyn <marc.tamlyn@gmail.com> | 2015-01-10 16:18:19 +0000 |
commit | 916e38802f151b34aaca487dc7e928946e81be73 (patch) | |
tree | 891bfce8727a3321b93850d9822250075dcb82e5 /django/contrib/postgres/fields/array.py | |
parent | 74f02557e0183812d6d60e2548985c5c40b3d27b (diff) | |
download | django-916e38802f151b34aaca487dc7e928946e81be73.tar.gz |
Move % addition to lookups, refactor postgres lookups.
These refactorings making overriding some text based lookup names on
other fields (specifically `contains`) much cleaner. It also removes a
bunch of duplication in the contrib.postgres lookups.
Diffstat (limited to 'django/contrib/postgres/fields/array.py')
-rw-r--r-- | django/contrib/postgres/fields/array.py | 50 |
1 files changed, 10 insertions, 40 deletions
diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py index 65f3dc6f6a..318afabd2c 100644 --- a/django/contrib/postgres/fields/array.py +++ b/django/contrib/postgres/fields/array.py @@ -1,9 +1,10 @@ import json +from django.contrib.postgres import lookups from django.contrib.postgres.forms import SimpleArrayField from django.contrib.postgres.validators import ArrayMaxLengthValidator from django.core import checks, exceptions -from django.db.models import Field, Lookup, Transform, IntegerField +from django.db.models import Field, Transform, IntegerField from django.utils import six from django.utils.translation import string_concat, ugettext_lazy as _ @@ -74,12 +75,6 @@ class ArrayField(Field): return [self.base_field.get_prep_value(i) for i in value] return value - def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): - if lookup_type == 'contains': - return [self.get_prep_value(value)] - return super(ArrayField, self).get_db_prep_lookup(lookup_type, value, - connection, prepared=False) - def deconstruct(self): name, path, args, kwargs = super(ArrayField, self).deconstruct() if path == 'django.contrib.postgres.fields.array.ArrayField': @@ -156,46 +151,21 @@ class ArrayField(Field): @ArrayField.register_lookup -class ArrayContainsLookup(Lookup): - lookup_name = 'contains' - - def as_sql(self, compiler, connection): - lhs, lhs_params = self.process_lhs(compiler, connection) - rhs, rhs_params = self.process_rhs(compiler, connection) - params = lhs_params + rhs_params - type_cast = self.lhs.output_field.db_type(connection) - return '%s @> %s::%s' % (lhs, rhs, type_cast), params - - -@ArrayField.register_lookup -class ArrayContainedByLookup(Lookup): - lookup_name = 'contained_by' - - def as_sql(self, compiler, connection): - lhs, lhs_params = self.process_lhs(compiler, connection) - rhs, rhs_params = self.process_rhs(compiler, connection) - params = lhs_params + rhs_params - return '%s <@ %s' % (lhs, rhs), params - +class ArrayContains(lookups.DataContains): + def as_sql(self, qn, connection): + sql, params = super(ArrayContains, self).as_sql(qn, connection) + sql += '::%s' % self.lhs.output_field.db_type(connection) + return sql, params -@ArrayField.register_lookup -class ArrayOverlapLookup(Lookup): - lookup_name = 'overlap' - def as_sql(self, compiler, connection): - lhs, lhs_params = self.process_lhs(compiler, connection) - rhs, rhs_params = self.process_rhs(compiler, connection) - params = lhs_params + rhs_params - return '%s && %s' % (lhs, rhs), params +ArrayField.register_lookup(lookups.ContainedBy) +ArrayField.register_lookup(lookups.Overlap) @ArrayField.register_lookup class ArrayLenTransform(Transform): lookup_name = 'len' - - @property - def output_field(self): - return IntegerField() + output_field = IntegerField() def as_sql(self, compiler, connection): lhs, params = compiler.compile(self.lhs) |