summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py8
-rw-r--r--lib/sqlalchemy/dialects/mysql/json.py50
-rw-r--r--lib/sqlalchemy/dialects/postgresql/json.py22
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py42
-rw-r--r--lib/sqlalchemy/testing/suite/test_types.py12
5 files changed, 109 insertions, 25 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 7ab9fad69..e7e533890 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -763,13 +763,13 @@ class MySQLCompiler(compiler.SQLCompiler):
def visit_json_getitem_op_binary(self, binary, operator, **kw):
return "JSON_EXTRACT(%s, %s)" % (
- self.process(binary.left),
- self.process(binary.right))
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw))
def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
return "JSON_EXTRACT(%s, %s)" % (
- self.process(binary.left),
- self.process(binary.right))
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw))
def visit_concat_op_binary(self, binary, operator, **kw):
return "concat(%s, %s)" % (self.process(binary.left),
diff --git a/lib/sqlalchemy/dialects/mysql/json.py b/lib/sqlalchemy/dialects/mysql/json.py
index 3840a7cd6..8dd99bd45 100644
--- a/lib/sqlalchemy/dialects/mysql/json.py
+++ b/lib/sqlalchemy/dialects/mysql/json.py
@@ -31,25 +31,49 @@ class JSON(sqltypes.JSON):
pass
-class JSONIndexType(sqltypes.JSON.JSONIndexType):
+
+class _FormatTypeMixin(object):
+ def _format_value(self, value):
+ raise NotImplementedError()
+
def bind_processor(self, dialect):
+ super_proc = self.string_bind_processor(dialect)
+
def process(value):
- if isinstance(value, int):
- return "$[%s]" % value
- else:
- return '$."%s"' % value
+ value = self._format_value(value)
+ if super_proc:
+ value = super_proc(value)
+ return value
return process
+ def literal_processor(self, dialect):
+ super_proc = self.string_literal_processor(dialect)
-class JSONPathType(sqltypes.JSON.JSONPathType):
- def bind_processor(self, dialect):
def process(value):
- return "$%s" % (
- "".join([
- "[%s]" % elem if isinstance(elem, int)
- else '."%s"' % elem for elem in value
- ])
- )
+ value = self._format_value(value)
+ if super_proc:
+ value = super_proc(value)
+ return value
return process
+
+
+class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
+
+ def _format_value(self, value):
+ if isinstance(value, int):
+ value = "$[%s]" % value
+ else:
+ value = '$."%s"' % value
+ return value
+
+
+class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
+ def _format_value(self, value):
+ return "$%s" % (
+ "".join([
+ "[%s]" % elem if isinstance(elem, int)
+ else '."%s"' % elem for elem in value
+ ])
+ )
diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py
index b0f0f7cf0..05c4d014d 100644
--- a/lib/sqlalchemy/dialects/postgresql/json.py
+++ b/lib/sqlalchemy/dialects/postgresql/json.py
@@ -49,10 +49,28 @@ CONTAINED_BY = operators.custom_op(
class JSONPathType(sqltypes.JSON.JSONPathType):
def bind_processor(self, dialect):
+ super_proc = self.string_bind_processor(dialect)
+
+ def process(value):
+ assert isinstance(value, collections.Sequence)
+ tokens = [util.text_type(elem)for elem in value]
+ value = "{%s}" % (", ".join(tokens))
+ if super_proc:
+ value = super_proc(value)
+ return value
+
+ return process
+
+ def literal_processor(self, dialect):
+ super_proc = self.string_literal_processor(dialect)
+
def process(value):
assert isinstance(value, collections.Sequence)
- tokens = [util.text_type(elem) for elem in value]
- return "{%s}" % (", ".join(tokens))
+ tokens = [util.text_type(elem)for elem in value]
+ value = "{%s}" % (", ".join(tokens))
+ if super_proc:
+ value = super_proc(value)
+ return value
return process
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 977231336..b55d435ad 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -1789,7 +1789,45 @@ class JSON(Indexable, TypeEngine):
"""
self.none_as_null = none_as_null
- class JSONIndexType(TypeEngine):
+ class JSONElementType(TypeEngine):
+ """common function for index / path elements in a JSON expression."""
+
+ _integer = Integer()
+ _string = String()
+
+ def string_bind_processor(self, dialect):
+ return self._string._cached_bind_processor(dialect)
+
+ def string_literal_processor(self, dialect):
+ return self._string._cached_literal_processor(dialect)
+
+ def bind_processor(self, dialect):
+ int_processor = self._integer._cached_bind_processor(dialect)
+ string_processor = self.string_bind_processor(dialect)
+
+ def process(value):
+ if int_processor and isinstance(value, int):
+ value = int_processor(value)
+ elif string_processor and isinstance(value, util.string_types):
+ value = string_processor(value)
+ return value
+
+ return process
+
+ def literal_processor(self, dialect):
+ int_processor = self._integer._cached_literal_processor(dialect)
+ string_processor = self.string_literal_processor(dialect)
+
+ def process(value):
+ if int_processor and isinstance(value, int):
+ value = int_processor(value)
+ elif string_processor and isinstance(value, util.string_types):
+ value = string_processor(value)
+ return value
+
+ return process
+
+ class JSONIndexType(JSONElementType):
"""Placeholder for the datatype of a JSON index value.
This allows execution-time processing of JSON index values
@@ -1797,7 +1835,7 @@ class JSON(Indexable, TypeEngine):
"""
- class JSONPathType(TypeEngine):
+ class JSONPathType(JSONElementType):
"""Placeholder type for JSON path operations.
This allows execution-time processing of a path-based
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
index d74ef60da..d85531396 100644
--- a/lib/sqlalchemy/testing/suite/test_types.py
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -736,14 +736,18 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
def _test_index_criteria(self, crit, expected):
self._criteria_fixture()
with config.db.connect() as conn:
+ stmt = select([self.tables.data_table.c.name]).where(crit)
+
eq_(
- conn.scalar(
- select([self.tables.data_table.c.name]).
- where(crit)
- ),
+ conn.scalar(stmt),
expected
)
+ literal_sql = str(stmt.compile(
+ config.db, compile_kwargs={"literal_binds": True}))
+
+ eq_(conn.scalar(literal_sql), expected)
+
def test_crit_spaces_in_key(self):
name = self.tables.data_table.c.name
col = self.tables.data_table.c['data']