diff options
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/json.py | 50 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/json.py | 22 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 42 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_types.py | 12 |
5 files changed, 109 insertions, 25 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 7ab9fad69..e7e533890 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -763,13 +763,13 @@ class MySQLCompiler(compiler.SQLCompiler): def visit_json_getitem_op_binary(self, binary, operator, **kw): return "JSON_EXTRACT(%s, %s)" % ( - self.process(binary.left), - self.process(binary.right)) + self.process(binary.left, **kw), + self.process(binary.right, **kw)) def visit_json_path_getitem_op_binary(self, binary, operator, **kw): return "JSON_EXTRACT(%s, %s)" % ( - self.process(binary.left), - self.process(binary.right)) + self.process(binary.left, **kw), + self.process(binary.right, **kw)) def visit_concat_op_binary(self, binary, operator, **kw): return "concat(%s, %s)" % (self.process(binary.left), diff --git a/lib/sqlalchemy/dialects/mysql/json.py b/lib/sqlalchemy/dialects/mysql/json.py index 3840a7cd6..8dd99bd45 100644 --- a/lib/sqlalchemy/dialects/mysql/json.py +++ b/lib/sqlalchemy/dialects/mysql/json.py @@ -31,25 +31,49 @@ class JSON(sqltypes.JSON): pass -class JSONIndexType(sqltypes.JSON.JSONIndexType): + +class _FormatTypeMixin(object): + def _format_value(self, value): + raise NotImplementedError() + def bind_processor(self, dialect): + super_proc = self.string_bind_processor(dialect) + def process(value): - if isinstance(value, int): - return "$[%s]" % value - else: - return '$."%s"' % value + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value return process + def literal_processor(self, dialect): + super_proc = self.string_literal_processor(dialect) -class JSONPathType(sqltypes.JSON.JSONPathType): - def bind_processor(self, dialect): def process(value): - return "$%s" % ( - "".join([ - "[%s]" % elem if isinstance(elem, int) - else '."%s"' % elem for elem in value - ]) - ) + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value return process + + +class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): + + def _format_value(self, value): + if isinstance(value, int): + value = "$[%s]" % value + else: + value = '$."%s"' % value + return value + + +class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): + def _format_value(self, value): + return "$%s" % ( + "".join([ + "[%s]" % elem if isinstance(elem, int) + else '."%s"' % elem for elem in value + ]) + ) diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index b0f0f7cf0..05c4d014d 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -49,10 +49,28 @@ CONTAINED_BY = operators.custom_op( class JSONPathType(sqltypes.JSON.JSONPathType): def bind_processor(self, dialect): + super_proc = self.string_bind_processor(dialect) + + def process(value): + assert isinstance(value, collections.Sequence) + tokens = [util.text_type(elem)for elem in value] + value = "{%s}" % (", ".join(tokens)) + if super_proc: + value = super_proc(value) + return value + + return process + + def literal_processor(self, dialect): + super_proc = self.string_literal_processor(dialect) + def process(value): assert isinstance(value, collections.Sequence) - tokens = [util.text_type(elem) for elem in value] - return "{%s}" % (", ".join(tokens)) + tokens = [util.text_type(elem)for elem in value] + value = "{%s}" % (", ".join(tokens)) + if super_proc: + value = super_proc(value) + return value return process diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 977231336..b55d435ad 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -1789,7 +1789,45 @@ class JSON(Indexable, TypeEngine): """ self.none_as_null = none_as_null - class JSONIndexType(TypeEngine): + class JSONElementType(TypeEngine): + """common function for index / path elements in a JSON expression.""" + + _integer = Integer() + _string = String() + + def string_bind_processor(self, dialect): + return self._string._cached_bind_processor(dialect) + + def string_literal_processor(self, dialect): + return self._string._cached_literal_processor(dialect) + + def bind_processor(self, dialect): + int_processor = self._integer._cached_bind_processor(dialect) + string_processor = self.string_bind_processor(dialect) + + def process(value): + if int_processor and isinstance(value, int): + value = int_processor(value) + elif string_processor and isinstance(value, util.string_types): + value = string_processor(value) + return value + + return process + + def literal_processor(self, dialect): + int_processor = self._integer._cached_literal_processor(dialect) + string_processor = self.string_literal_processor(dialect) + + def process(value): + if int_processor and isinstance(value, int): + value = int_processor(value) + elif string_processor and isinstance(value, util.string_types): + value = string_processor(value) + return value + + return process + + class JSONIndexType(JSONElementType): """Placeholder for the datatype of a JSON index value. This allows execution-time processing of JSON index values @@ -1797,7 +1835,7 @@ class JSON(Indexable, TypeEngine): """ - class JSONPathType(TypeEngine): + class JSONPathType(JSONElementType): """Placeholder type for JSON path operations. This allows execution-time processing of a path-based diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index d74ef60da..d85531396 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -736,14 +736,18 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): def _test_index_criteria(self, crit, expected): self._criteria_fixture() with config.db.connect() as conn: + stmt = select([self.tables.data_table.c.name]).where(crit) + eq_( - conn.scalar( - select([self.tables.data_table.c.name]). - where(crit) - ), + conn.scalar(stmt), expected ) + literal_sql = str(stmt.compile( + config.db, compile_kwargs={"literal_binds": True})) + + eq_(conn.scalar(literal_sql), expected) + def test_crit_spaces_in_key(self): name = self.tables.data_table.c.name col = self.tables.data_table.c['data'] |