diff options
-rw-r--r-- | doc/build/changelog/changelog_11.rst | 10 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/json.py | 50 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/json.py | 22 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 42 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_types.py | 12 | ||||
-rw-r--r-- | test/sql/test_types.py | 71 |
7 files changed, 190 insertions, 25 deletions
diff --git a/doc/build/changelog/changelog_11.rst b/doc/build/changelog/changelog_11.rst index fe6be3b17..b94104be8 100644 --- a/doc/build/changelog/changelog_11.rst +++ b/doc/build/changelog/changelog_11.rst @@ -29,6 +29,16 @@ to a CAST expression under MySQL. .. change:: + :tags: bug, sql, postgresql, mysql + :tickets: 3765 + + Fixed regression in JSON datatypes where the "literal processor" for + a JSON index value would not be invoked. The native String and Integer + datatypes are now called upon from within the JSONIndexType + and JSONPathType. This is applied to the generic, Postgresql, and + MySQL JSON types and also has a dependency on :ticket:`3766`. + + .. change:: :tags: change, orm Passing False to :meth:`.Query.order_by` in order to cancel diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 7ab9fad69..e7e533890 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -763,13 +763,13 @@ class MySQLCompiler(compiler.SQLCompiler): def visit_json_getitem_op_binary(self, binary, operator, **kw): return "JSON_EXTRACT(%s, %s)" % ( - self.process(binary.left), - self.process(binary.right)) + self.process(binary.left, **kw), + self.process(binary.right, **kw)) def visit_json_path_getitem_op_binary(self, binary, operator, **kw): return "JSON_EXTRACT(%s, %s)" % ( - self.process(binary.left), - self.process(binary.right)) + self.process(binary.left, **kw), + self.process(binary.right, **kw)) def visit_concat_op_binary(self, binary, operator, **kw): return "concat(%s, %s)" % (self.process(binary.left), diff --git a/lib/sqlalchemy/dialects/mysql/json.py b/lib/sqlalchemy/dialects/mysql/json.py index 3840a7cd6..8dd99bd45 100644 --- a/lib/sqlalchemy/dialects/mysql/json.py +++ b/lib/sqlalchemy/dialects/mysql/json.py @@ -31,25 +31,49 @@ class JSON(sqltypes.JSON): pass -class JSONIndexType(sqltypes.JSON.JSONIndexType): + +class _FormatTypeMixin(object): + def _format_value(self, value): + raise NotImplementedError() + def bind_processor(self, dialect): + super_proc = self.string_bind_processor(dialect) + def process(value): - if isinstance(value, int): - return "$[%s]" % value - else: - return '$."%s"' % value + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value return process + def literal_processor(self, dialect): + super_proc = self.string_literal_processor(dialect) -class JSONPathType(sqltypes.JSON.JSONPathType): - def bind_processor(self, dialect): def process(value): - return "$%s" % ( - "".join([ - "[%s]" % elem if isinstance(elem, int) - else '."%s"' % elem for elem in value - ]) - ) + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value return process + + +class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): + + def _format_value(self, value): + if isinstance(value, int): + value = "$[%s]" % value + else: + value = '$."%s"' % value + return value + + +class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): + def _format_value(self, value): + return "$%s" % ( + "".join([ + "[%s]" % elem if isinstance(elem, int) + else '."%s"' % elem for elem in value + ]) + ) diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index b0f0f7cf0..05c4d014d 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -49,10 +49,28 @@ CONTAINED_BY = operators.custom_op( class JSONPathType(sqltypes.JSON.JSONPathType): def bind_processor(self, dialect): + super_proc = self.string_bind_processor(dialect) + + def process(value): + assert isinstance(value, collections.Sequence) + tokens = [util.text_type(elem)for elem in value] + value = "{%s}" % (", ".join(tokens)) + if super_proc: + value = super_proc(value) + return value + + return process + + def literal_processor(self, dialect): + super_proc = self.string_literal_processor(dialect) + def process(value): assert isinstance(value, collections.Sequence) - tokens = [util.text_type(elem) for elem in value] - return "{%s}" % (", ".join(tokens)) + tokens = [util.text_type(elem)for elem in value] + value = "{%s}" % (", ".join(tokens)) + if super_proc: + value = super_proc(value) + return value return process diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 977231336..b55d435ad 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -1789,7 +1789,45 @@ class JSON(Indexable, TypeEngine): """ self.none_as_null = none_as_null - class JSONIndexType(TypeEngine): + class JSONElementType(TypeEngine): + """common function for index / path elements in a JSON expression.""" + + _integer = Integer() + _string = String() + + def string_bind_processor(self, dialect): + return self._string._cached_bind_processor(dialect) + + def string_literal_processor(self, dialect): + return self._string._cached_literal_processor(dialect) + + def bind_processor(self, dialect): + int_processor = self._integer._cached_bind_processor(dialect) + string_processor = self.string_bind_processor(dialect) + + def process(value): + if int_processor and isinstance(value, int): + value = int_processor(value) + elif string_processor and isinstance(value, util.string_types): + value = string_processor(value) + return value + + return process + + def literal_processor(self, dialect): + int_processor = self._integer._cached_literal_processor(dialect) + string_processor = self.string_literal_processor(dialect) + + def process(value): + if int_processor and isinstance(value, int): + value = int_processor(value) + elif string_processor and isinstance(value, util.string_types): + value = string_processor(value) + return value + + return process + + class JSONIndexType(JSONElementType): """Placeholder for the datatype of a JSON index value. This allows execution-time processing of JSON index values @@ -1797,7 +1835,7 @@ class JSON(Indexable, TypeEngine): """ - class JSONPathType(TypeEngine): + class JSONPathType(JSONElementType): """Placeholder type for JSON path operations. This allows execution-time processing of a path-based diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index d74ef60da..d85531396 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -736,14 +736,18 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): def _test_index_criteria(self, crit, expected): self._criteria_fixture() with config.db.connect() as conn: + stmt = select([self.tables.data_table.c.name]).where(crit) + eq_( - conn.scalar( - select([self.tables.data_table.c.name]). - where(crit) - ), + conn.scalar(stmt), expected ) + literal_sql = str(stmt.compile( + config.db, compile_kwargs={"literal_binds": True})) + + eq_(conn.scalar(literal_sql), expected) + def test_crit_spaces_in_key(self): name = self.tables.data_table.c.name col = self.tables.data_table.c['data'] diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 49a1d8f15..3374a6721 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -1630,6 +1630,77 @@ class JSONTest(fixtures.TestBase): None ) + def _dialect_index_fixture(self, int_processor, str_processor): + class MyInt(Integer): + def bind_processor(self, dialect): + return lambda value: value + 10 + + def literal_processor(self, diaect): + return lambda value: str(value + 15) + + class MyString(String): + def bind_processor(self, dialect): + return lambda value: value + "10" + + def literal_processor(self, diaect): + return lambda value: value + "15" + + class MyDialect(default.DefaultDialect): + colspecs = {} + if int_processor: + colspecs[Integer] = MyInt + if str_processor: + colspecs[String] = MyString + + return MyDialect() + + def test_index_bind_proc_int(self): + expr = self.test_table.c.test_column[5] + + int_dialect = self._dialect_index_fixture(True, True) + non_int_dialect = self._dialect_index_fixture(False, True) + + bindproc = expr.right.type._cached_bind_processor(int_dialect) + eq_(bindproc(expr.right.value), 15) + + bindproc = expr.right.type._cached_bind_processor(non_int_dialect) + eq_(bindproc(expr.right.value), 5) + + def test_index_literal_proc_int(self): + expr = self.test_table.c.test_column[5] + + int_dialect = self._dialect_index_fixture(True, True) + non_int_dialect = self._dialect_index_fixture(False, True) + + bindproc = expr.right.type._cached_literal_processor(int_dialect) + eq_(bindproc(expr.right.value), "20") + + bindproc = expr.right.type._cached_literal_processor(non_int_dialect) + eq_(bindproc(expr.right.value), "5") + + def test_index_bind_proc_str(self): + expr = self.test_table.c.test_column['five'] + + str_dialect = self._dialect_index_fixture(True, True) + non_str_dialect = self._dialect_index_fixture(False, False) + + bindproc = expr.right.type._cached_bind_processor(str_dialect) + eq_(bindproc(expr.right.value), 'five10') + + bindproc = expr.right.type._cached_bind_processor(non_str_dialect) + eq_(bindproc(expr.right.value), 'five') + + def test_index_literal_proc_str(self): + expr = self.test_table.c.test_column['five'] + + str_dialect = self._dialect_index_fixture(True, True) + non_str_dialect = self._dialect_index_fixture(False, False) + + bindproc = expr.right.type._cached_literal_processor(str_dialect) + eq_(bindproc(expr.right.value), "five15") + + bindproc = expr.right.type._cached_literal_processor(non_str_dialect) + eq_(bindproc(expr.right.value), "'five'") class ArrayTest(fixtures.TestBase): |