summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py75
1 files changed, 69 insertions, 6 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 66a294d10..596ca986f 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -37,6 +37,7 @@ import typing
from typing import Any
from typing import Callable
from typing import cast
+from typing import ClassVar
from typing import Dict
from typing import FrozenSet
from typing import Iterable
@@ -46,6 +47,7 @@ from typing import MutableMapping
from typing import NamedTuple
from typing import NoReturn
from typing import Optional
+from typing import Pattern
from typing import Sequence
from typing import Set
from typing import Tuple
@@ -238,9 +240,6 @@ BIND_TEMPLATES = {
}
-_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\] ]")
-_BIND_TRANSLATE_CHARS = dict(zip("%():[] ", "PAZC___"))
-
OPERATORS = {
# binary
operators.and_: " AND ",
@@ -714,6 +713,14 @@ class Compiled:
self._gen_time = perf_counter()
+ def __init_subclass__(cls) -> None:
+ cls._init_compiler_cls()
+ return super().__init_subclass__()
+
+ @classmethod
+ def _init_compiler_cls(cls):
+ pass
+
def _execute_on_connection(
self, connection, distilled_params, execution_options
):
@@ -866,6 +873,52 @@ class SQLCompiler(Compiled):
extract_map = EXTRACT_MAP
+ bindname_escape_characters: ClassVar[
+ Mapping[str, str]
+ ] = util.immutabledict(
+ {
+ "%": "P",
+ "(": "A",
+ ")": "Z",
+ ":": "C",
+ ".": "_",
+ "[": "_",
+ "]": "_",
+ " ": "_",
+ }
+ )
+ """A mapping (e.g. dict or similar) containing a lookup of
+ characters keyed to replacement characters which will be applied to all
+ 'bind names' used in SQL statements as a form of 'escaping'; the given
+ characters are replaced entirely with the 'replacement' character when
+ rendered in the SQL statement, and a similar translation is performed
+ on the incoming names used in parameter dictionaries passed to methods
+ like :meth:`_engine.Connection.execute`.
+
+ This allows bound parameter names used in :func:`_sql.bindparam` and
+ other constructs to have any arbitrary characters present without any
+ concern for characters that aren't allowed at all on the target database.
+
+ Third party dialects can establish their own dictionary here to replace the
+ default mapping, which will ensure that the particular characters in the
+ mapping will never appear in a bound parameter name.
+
+ The dictionary is evaluated at **class creation time**, so cannot be
+ modified at runtime; it must be present on the class when the class
+ is first declared.
+
+ Note that for dialects that have additional bound parameter rules such
+ as additional restrictions on leading characters, the
+ :meth:`_sql.SQLCompiler.bindparam_string` method may need to be augmented.
+ See the cx_Oracle compiler for an example of this.
+
+ .. versionadded:: 2.0.0b5
+
+ """
+
+ _bind_translate_re: ClassVar[Pattern[str]]
+ _bind_translate_chars: ClassVar[Mapping[str, str]]
+
is_sql = True
compound_keywords = COMPOUND_KEYWORDS
@@ -1108,6 +1161,16 @@ class SQLCompiler(Compiled):
f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}"
)
+ @classmethod
+ def _init_compiler_cls(cls):
+ cls._init_bind_translate()
+
+ @classmethod
+ def _init_bind_translate(cls):
+ reg = re.escape("".join(cls.bindname_escape_characters))
+ cls._bind_translate_re = re.compile(f"[{reg}]")
+ cls._bind_translate_chars = cls.bindname_escape_characters
+
def __init__(
self,
dialect: Dialect,
@@ -3591,12 +3654,12 @@ class SQLCompiler(Compiled):
if not escaped_from:
- if _BIND_TRANSLATE_RE.search(name):
+ if self._bind_translate_re.search(name):
# not quite the translate use case as we want to
# also get a quick boolean if we even found
# unusual characters in the name
- new_name = _BIND_TRANSLATE_RE.sub(
- lambda m: _BIND_TRANSLATE_CHARS[m.group(0)],
+ new_name = self._bind_translate_re.sub(
+ lambda m: self._bind_translate_chars[m.group(0)],
name,
)
escaped_from = name