diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2022-12-16 19:37:51 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-12-16 19:37:51 +0000 |
commit | e84cc158c469f17c90f2e058ed72595bc3be5cdb (patch) | |
tree | 07449f8ce0583844bd9a18dd8083b32ae7113490 /lib/sqlalchemy/sql/compiler.py | |
parent | bd5a4611c34d25cf21607544c01ce7fcb886e0a9 (diff) | |
parent | d7107641c309e0b7db9b0876ac048dbb38316ba6 (diff) | |
download | sqlalchemy-e84cc158c469f17c90f2e058ed72595bc3be5cdb.tar.gz |
Merge "make bind escape lookup extensible" into main
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 75 |
1 files changed, 69 insertions, 6 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 66a294d10..596ca986f 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -37,6 +37,7 @@ import typing from typing import Any from typing import Callable from typing import cast +from typing import ClassVar from typing import Dict from typing import FrozenSet from typing import Iterable @@ -46,6 +47,7 @@ from typing import MutableMapping from typing import NamedTuple from typing import NoReturn from typing import Optional +from typing import Pattern from typing import Sequence from typing import Set from typing import Tuple @@ -238,9 +240,6 @@ BIND_TEMPLATES = { } -_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\] ]") -_BIND_TRANSLATE_CHARS = dict(zip("%():[] ", "PAZC___")) - OPERATORS = { # binary operators.and_: " AND ", @@ -714,6 +713,14 @@ class Compiled: self._gen_time = perf_counter() + def __init_subclass__(cls) -> None: + cls._init_compiler_cls() + return super().__init_subclass__() + + @classmethod + def _init_compiler_cls(cls): + pass + def _execute_on_connection( self, connection, distilled_params, execution_options ): @@ -866,6 +873,52 @@ class SQLCompiler(Compiled): extract_map = EXTRACT_MAP + bindname_escape_characters: ClassVar[ + Mapping[str, str] + ] = util.immutabledict( + { + "%": "P", + "(": "A", + ")": "Z", + ":": "C", + ".": "_", + "[": "_", + "]": "_", + " ": "_", + } + ) + """A mapping (e.g. dict or similar) containing a lookup of + characters keyed to replacement characters which will be applied to all + 'bind names' used in SQL statements as a form of 'escaping'; the given + characters are replaced entirely with the 'replacement' character when + rendered in the SQL statement, and a similar translation is performed + on the incoming names used in parameter dictionaries passed to methods + like :meth:`_engine.Connection.execute`. + + This allows bound parameter names used in :func:`_sql.bindparam` and + other constructs to have any arbitrary characters present without any + concern for characters that aren't allowed at all on the target database. + + Third party dialects can establish their own dictionary here to replace the + default mapping, which will ensure that the particular characters in the + mapping will never appear in a bound parameter name. + + The dictionary is evaluated at **class creation time**, so cannot be + modified at runtime; it must be present on the class when the class + is first declared. + + Note that for dialects that have additional bound parameter rules such + as additional restrictions on leading characters, the + :meth:`_sql.SQLCompiler.bindparam_string` method may need to be augmented. + See the cx_Oracle compiler for an example of this. + + .. versionadded:: 2.0.0b5 + + """ + + _bind_translate_re: ClassVar[Pattern[str]] + _bind_translate_chars: ClassVar[Mapping[str, str]] + is_sql = True compound_keywords = COMPOUND_KEYWORDS @@ -1108,6 +1161,16 @@ class SQLCompiler(Compiled): f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}" ) + @classmethod + def _init_compiler_cls(cls): + cls._init_bind_translate() + + @classmethod + def _init_bind_translate(cls): + reg = re.escape("".join(cls.bindname_escape_characters)) + cls._bind_translate_re = re.compile(f"[{reg}]") + cls._bind_translate_chars = cls.bindname_escape_characters + def __init__( self, dialect: Dialect, @@ -3591,12 +3654,12 @@ class SQLCompiler(Compiled): if not escaped_from: - if _BIND_TRANSLATE_RE.search(name): + if self._bind_translate_re.search(name): # not quite the translate use case as we want to # also get a quick boolean if we even found # unusual characters in the name - new_name = _BIND_TRANSLATE_RE.sub( - lambda m: _BIND_TRANSLATE_CHARS[m.group(0)], + new_name = self._bind_translate_re.sub( + lambda m: self._bind_translate_chars[m.group(0)], name, ) escaped_from = name |