summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/changelog_11.rst10
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py8
-rw-r--r--lib/sqlalchemy/dialects/mysql/json.py50
-rw-r--r--lib/sqlalchemy/dialects/postgresql/json.py22
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py42
-rw-r--r--lib/sqlalchemy/testing/suite/test_types.py12
-rw-r--r--test/sql/test_types.py71
7 files changed, 190 insertions, 25 deletions
diff --git a/doc/build/changelog/changelog_11.rst b/doc/build/changelog/changelog_11.rst
index fe6be3b17..b94104be8 100644
--- a/doc/build/changelog/changelog_11.rst
+++ b/doc/build/changelog/changelog_11.rst
@@ -29,6 +29,16 @@
to a CAST expression under MySQL.
.. change::
+ :tags: bug, sql, postgresql, mysql
+ :tickets: 3765
+
+ Fixed regression in JSON datatypes where the "literal processor" for
+ a JSON index value would not be invoked. The native String and Integer
+ datatypes are now called upon from within the JSONIndexType
+ and JSONPathType. This is applied to the generic, Postgresql, and
+ MySQL JSON types and also has a dependency on :ticket:`3766`.
+
+ .. change::
:tags: change, orm
Passing False to :meth:`.Query.order_by` in order to cancel
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 7ab9fad69..e7e533890 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -763,13 +763,13 @@ class MySQLCompiler(compiler.SQLCompiler):
def visit_json_getitem_op_binary(self, binary, operator, **kw):
return "JSON_EXTRACT(%s, %s)" % (
- self.process(binary.left),
- self.process(binary.right))
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw))
def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
return "JSON_EXTRACT(%s, %s)" % (
- self.process(binary.left),
- self.process(binary.right))
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw))
def visit_concat_op_binary(self, binary, operator, **kw):
return "concat(%s, %s)" % (self.process(binary.left),
diff --git a/lib/sqlalchemy/dialects/mysql/json.py b/lib/sqlalchemy/dialects/mysql/json.py
index 3840a7cd6..8dd99bd45 100644
--- a/lib/sqlalchemy/dialects/mysql/json.py
+++ b/lib/sqlalchemy/dialects/mysql/json.py
@@ -31,25 +31,49 @@ class JSON(sqltypes.JSON):
pass
-class JSONIndexType(sqltypes.JSON.JSONIndexType):
+
+class _FormatTypeMixin(object):
+ def _format_value(self, value):
+ raise NotImplementedError()
+
def bind_processor(self, dialect):
+ super_proc = self.string_bind_processor(dialect)
+
def process(value):
- if isinstance(value, int):
- return "$[%s]" % value
- else:
- return '$."%s"' % value
+ value = self._format_value(value)
+ if super_proc:
+ value = super_proc(value)
+ return value
return process
+ def literal_processor(self, dialect):
+ super_proc = self.string_literal_processor(dialect)
-class JSONPathType(sqltypes.JSON.JSONPathType):
- def bind_processor(self, dialect):
def process(value):
- return "$%s" % (
- "".join([
- "[%s]" % elem if isinstance(elem, int)
- else '."%s"' % elem for elem in value
- ])
- )
+ value = self._format_value(value)
+ if super_proc:
+ value = super_proc(value)
+ return value
return process
+
+
+class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
+
+ def _format_value(self, value):
+ if isinstance(value, int):
+ value = "$[%s]" % value
+ else:
+ value = '$."%s"' % value
+ return value
+
+
+class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
+ def _format_value(self, value):
+ return "$%s" % (
+ "".join([
+ "[%s]" % elem if isinstance(elem, int)
+ else '."%s"' % elem for elem in value
+ ])
+ )
diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py
index b0f0f7cf0..05c4d014d 100644
--- a/lib/sqlalchemy/dialects/postgresql/json.py
+++ b/lib/sqlalchemy/dialects/postgresql/json.py
@@ -49,10 +49,28 @@ CONTAINED_BY = operators.custom_op(
class JSONPathType(sqltypes.JSON.JSONPathType):
def bind_processor(self, dialect):
+ super_proc = self.string_bind_processor(dialect)
+
+ def process(value):
+ assert isinstance(value, collections.Sequence)
+ tokens = [util.text_type(elem)for elem in value]
+ value = "{%s}" % (", ".join(tokens))
+ if super_proc:
+ value = super_proc(value)
+ return value
+
+ return process
+
+ def literal_processor(self, dialect):
+ super_proc = self.string_literal_processor(dialect)
+
def process(value):
assert isinstance(value, collections.Sequence)
- tokens = [util.text_type(elem) for elem in value]
- return "{%s}" % (", ".join(tokens))
+ tokens = [util.text_type(elem)for elem in value]
+ value = "{%s}" % (", ".join(tokens))
+ if super_proc:
+ value = super_proc(value)
+ return value
return process
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 977231336..b55d435ad 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -1789,7 +1789,45 @@ class JSON(Indexable, TypeEngine):
"""
self.none_as_null = none_as_null
- class JSONIndexType(TypeEngine):
+ class JSONElementType(TypeEngine):
+ """common function for index / path elements in a JSON expression."""
+
+ _integer = Integer()
+ _string = String()
+
+ def string_bind_processor(self, dialect):
+ return self._string._cached_bind_processor(dialect)
+
+ def string_literal_processor(self, dialect):
+ return self._string._cached_literal_processor(dialect)
+
+ def bind_processor(self, dialect):
+ int_processor = self._integer._cached_bind_processor(dialect)
+ string_processor = self.string_bind_processor(dialect)
+
+ def process(value):
+ if int_processor and isinstance(value, int):
+ value = int_processor(value)
+ elif string_processor and isinstance(value, util.string_types):
+ value = string_processor(value)
+ return value
+
+ return process
+
+ def literal_processor(self, dialect):
+ int_processor = self._integer._cached_literal_processor(dialect)
+ string_processor = self.string_literal_processor(dialect)
+
+ def process(value):
+ if int_processor and isinstance(value, int):
+ value = int_processor(value)
+ elif string_processor and isinstance(value, util.string_types):
+ value = string_processor(value)
+ return value
+
+ return process
+
+ class JSONIndexType(JSONElementType):
"""Placeholder for the datatype of a JSON index value.
This allows execution-time processing of JSON index values
@@ -1797,7 +1835,7 @@ class JSON(Indexable, TypeEngine):
"""
- class JSONPathType(TypeEngine):
+ class JSONPathType(JSONElementType):
"""Placeholder type for JSON path operations.
This allows execution-time processing of a path-based
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
index d74ef60da..d85531396 100644
--- a/lib/sqlalchemy/testing/suite/test_types.py
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -736,14 +736,18 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
def _test_index_criteria(self, crit, expected):
self._criteria_fixture()
with config.db.connect() as conn:
+ stmt = select([self.tables.data_table.c.name]).where(crit)
+
eq_(
- conn.scalar(
- select([self.tables.data_table.c.name]).
- where(crit)
- ),
+ conn.scalar(stmt),
expected
)
+ literal_sql = str(stmt.compile(
+ config.db, compile_kwargs={"literal_binds": True}))
+
+ eq_(conn.scalar(literal_sql), expected)
+
def test_crit_spaces_in_key(self):
name = self.tables.data_table.c.name
col = self.tables.data_table.c['data']
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index 49a1d8f15..3374a6721 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -1630,6 +1630,77 @@ class JSONTest(fixtures.TestBase):
None
)
+ def _dialect_index_fixture(self, int_processor, str_processor):
+ class MyInt(Integer):
+ def bind_processor(self, dialect):
+ return lambda value: value + 10
+
+ def literal_processor(self, diaect):
+ return lambda value: str(value + 15)
+
+ class MyString(String):
+ def bind_processor(self, dialect):
+ return lambda value: value + "10"
+
+ def literal_processor(self, diaect):
+ return lambda value: value + "15"
+
+ class MyDialect(default.DefaultDialect):
+ colspecs = {}
+ if int_processor:
+ colspecs[Integer] = MyInt
+ if str_processor:
+ colspecs[String] = MyString
+
+ return MyDialect()
+
+ def test_index_bind_proc_int(self):
+ expr = self.test_table.c.test_column[5]
+
+ int_dialect = self._dialect_index_fixture(True, True)
+ non_int_dialect = self._dialect_index_fixture(False, True)
+
+ bindproc = expr.right.type._cached_bind_processor(int_dialect)
+ eq_(bindproc(expr.right.value), 15)
+
+ bindproc = expr.right.type._cached_bind_processor(non_int_dialect)
+ eq_(bindproc(expr.right.value), 5)
+
+ def test_index_literal_proc_int(self):
+ expr = self.test_table.c.test_column[5]
+
+ int_dialect = self._dialect_index_fixture(True, True)
+ non_int_dialect = self._dialect_index_fixture(False, True)
+
+ bindproc = expr.right.type._cached_literal_processor(int_dialect)
+ eq_(bindproc(expr.right.value), "20")
+
+ bindproc = expr.right.type._cached_literal_processor(non_int_dialect)
+ eq_(bindproc(expr.right.value), "5")
+
+ def test_index_bind_proc_str(self):
+ expr = self.test_table.c.test_column['five']
+
+ str_dialect = self._dialect_index_fixture(True, True)
+ non_str_dialect = self._dialect_index_fixture(False, False)
+
+ bindproc = expr.right.type._cached_bind_processor(str_dialect)
+ eq_(bindproc(expr.right.value), 'five10')
+
+ bindproc = expr.right.type._cached_bind_processor(non_str_dialect)
+ eq_(bindproc(expr.right.value), 'five')
+
+ def test_index_literal_proc_str(self):
+ expr = self.test_table.c.test_column['five']
+
+ str_dialect = self._dialect_index_fixture(True, True)
+ non_str_dialect = self._dialect_index_fixture(False, False)
+
+ bindproc = expr.right.type._cached_literal_processor(str_dialect)
+ eq_(bindproc(expr.right.value), "five15")
+
+ bindproc = expr.right.type._cached_literal_processor(non_str_dialect)
+ eq_(bindproc(expr.right.value), "'five'")
class ArrayTest(fixtures.TestBase):