summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py125
1 files changed, 105 insertions, 20 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 8c2699879..45b5eab56 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -59,6 +59,7 @@ from . import crud
from . import elements
from . import functions
from . import operators
+from . import roles
from . import schema
from . import selectable
from . import sqltypes
@@ -686,7 +687,9 @@ class TypeCompiler(util.EnsureKWArg):
# this was a Visitable, but to allow accurate detection of
# column elements this is actually a column element
-class _CompileLabel(elements.CompilerColumnElement):
+class _CompileLabel(
+ roles.BinaryElementRole[Any], elements.CompilerColumnElement
+):
"""lightweight label object which acts as an expression.Label."""
@@ -710,6 +713,44 @@ class _CompileLabel(elements.CompilerColumnElement):
return self
+class ilike_case_insensitive(
+ roles.BinaryElementRole[Any], elements.CompilerColumnElement
+):
+ """produce a wrapping element for a case-insensitive portion of
+ an ILIKE construct.
+
+ The construct usually renders the ``lower()`` function, but on
+ PostgreSQL will pass silently with the assumption that "ILIKE"
+ is being used.
+
+ .. versionadded:: 2.0
+
+ """
+
+ __visit_name__ = "ilike_case_insensitive_operand"
+ __slots__ = "element", "comparator"
+
+ def __init__(self, element):
+ self.element = element
+ self.comparator = element.comparator
+
+ @property
+ def proxy_set(self):
+ return self.element.proxy_set
+
+ @property
+ def type(self):
+ return self.element.type
+
+ def self_group(self, **kw):
+ return self
+
+ def _with_binary_element_type(self, type_):
+ return ilike_case_insensitive(
+ self.element._with_binary_element_type(type_)
+ )
+
+
class SQLCompiler(Compiled):
"""Default implementation of :class:`.Compiled`.
@@ -2688,6 +2729,9 @@ class SQLCompiler(Compiled):
def _like_percent_literal(self):
return elements.literal_column("'%'", type_=sqltypes.STRINGTYPE)
+ def visit_ilike_case_insensitive_operand(self, element, **kw):
+ return f"lower({element.element._compiler_dispatch(self, **kw)})"
+
def visit_contains_op_binary(self, binary, operator, **kw):
binary = binary._clone()
percent = self._like_percent_literal
@@ -2700,6 +2744,24 @@ class SQLCompiler(Compiled):
binary.right = percent.concat(binary.right).concat(percent)
return self.visit_not_like_op_binary(binary, operator, **kw)
+ def visit_icontains_op_binary(self, binary, operator, **kw):
+ binary = binary._clone()
+ percent = self._like_percent_literal
+ binary.left = ilike_case_insensitive(binary.left)
+ binary.right = percent.concat(
+ ilike_case_insensitive(binary.right)
+ ).concat(percent)
+ return self.visit_ilike_op_binary(binary, operator, **kw)
+
+ def visit_not_icontains_op_binary(self, binary, operator, **kw):
+ binary = binary._clone()
+ percent = self._like_percent_literal
+ binary.left = ilike_case_insensitive(binary.left)
+ binary.right = percent.concat(
+ ilike_case_insensitive(binary.right)
+ ).concat(percent)
+ return self.visit_not_ilike_op_binary(binary, operator, **kw)
+
def visit_startswith_op_binary(self, binary, operator, **kw):
binary = binary._clone()
percent = self._like_percent_literal
@@ -2712,6 +2774,20 @@ class SQLCompiler(Compiled):
binary.right = percent._rconcat(binary.right)
return self.visit_not_like_op_binary(binary, operator, **kw)
+ def visit_istartswith_op_binary(self, binary, operator, **kw):
+ binary = binary._clone()
+ percent = self._like_percent_literal
+ binary.left = ilike_case_insensitive(binary.left)
+ binary.right = percent._rconcat(ilike_case_insensitive(binary.right))
+ return self.visit_ilike_op_binary(binary, operator, **kw)
+
+ def visit_not_istartswith_op_binary(self, binary, operator, **kw):
+ binary = binary._clone()
+ percent = self._like_percent_literal
+ binary.left = ilike_case_insensitive(binary.left)
+ binary.right = percent._rconcat(ilike_case_insensitive(binary.right))
+ return self.visit_not_ilike_op_binary(binary, operator, **kw)
+
def visit_endswith_op_binary(self, binary, operator, **kw):
binary = binary._clone()
percent = self._like_percent_literal
@@ -2724,10 +2800,23 @@ class SQLCompiler(Compiled):
binary.right = percent.concat(binary.right)
return self.visit_not_like_op_binary(binary, operator, **kw)
+ def visit_iendswith_op_binary(self, binary, operator, **kw):
+ binary = binary._clone()
+ percent = self._like_percent_literal
+ binary.left = ilike_case_insensitive(binary.left)
+ binary.right = percent.concat(ilike_case_insensitive(binary.right))
+ return self.visit_ilike_op_binary(binary, operator, **kw)
+
+ def visit_not_iendswith_op_binary(self, binary, operator, **kw):
+ binary = binary._clone()
+ percent = self._like_percent_literal
+ binary.left = ilike_case_insensitive(binary.left)
+ binary.right = percent.concat(ilike_case_insensitive(binary.right))
+ return self.visit_not_ilike_op_binary(binary, operator, **kw)
+
def visit_like_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
- # TODO: use ternary here, not "and"/ "or"
return "%s LIKE %s" % (
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw),
@@ -2749,26 +2838,22 @@ class SQLCompiler(Compiled):
)
def visit_ilike_op_binary(self, binary, operator, **kw):
- escape = binary.modifiers.get("escape", None)
- return "lower(%s) LIKE lower(%s)" % (
- binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw),
- ) + (
- " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape
- else ""
- )
+ if operator is operators.ilike_op:
+ binary = binary._clone()
+ binary.left = ilike_case_insensitive(binary.left)
+ binary.right = ilike_case_insensitive(binary.right)
+ # else we assume ilower() has been applied
+
+ return self.visit_like_op_binary(binary, operator, **kw)
def visit_not_ilike_op_binary(self, binary, operator, **kw):
- escape = binary.modifiers.get("escape", None)
- return "lower(%s) NOT LIKE lower(%s)" % (
- binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw),
- ) + (
- " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape
- else ""
- )
+ if operator is operators.not_ilike_op:
+ binary = binary._clone()
+ binary.left = ilike_case_insensitive(binary.left)
+ binary.right = ilike_case_insensitive(binary.right)
+ # else we assume ilower() has been applied
+
+ return self.visit_not_like_op_binary(binary, operator, **kw)
def visit_between_op_binary(self, binary, operator, **kw):
symmetric = binary.modifiers.get("symmetric", False)