diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/sqlite/pysqlite.py')
-rw-r--r-- | lib/sqlalchemy/dialects/sqlite/pysqlite.py | 107 |
1 files changed, 107 insertions, 0 deletions
diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index 4475ccae7..c04a3601d 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -637,3 +637,110 @@ class SQLiteDialect_pysqlite(SQLiteDialect): dialect = SQLiteDialect_pysqlite + + +class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite): + """numeric dialect for testing only + + internal use only. This dialect is **NOT** supported by SQLAlchemy + and may change at any time. + + """ + + supports_statement_cache = True + default_paramstyle = "numeric" + driver = "pysqlite_numeric" + + _first_bind = ":1" + _not_in_statement_regexp = None + + def __init__(self, *arg, **kw): + kw.setdefault("paramstyle", "numeric") + super().__init__(*arg, **kw) + + def create_connect_args(self, url): + arg, opts = super().create_connect_args(url) + opts["factory"] = self._fix_sqlite_issue_99953() + return arg, opts + + def _fix_sqlite_issue_99953(self): + import sqlite3 + + first_bind = self._first_bind + if self._not_in_statement_regexp: + nis = self._not_in_statement_regexp + + def _test_sql(sql): + m = nis.search(sql) + assert not m, f"Found {nis.pattern!r} in {sql!r}" + + else: + + def _test_sql(sql): + pass + + def _numeric_param_as_dict(parameters): + if parameters: + assert isinstance(parameters, tuple) + return { + str(idx): value for idx, value in enumerate(parameters, 1) + } + else: + return () + + class SQLiteFix99953Cursor(sqlite3.Cursor): + def execute(self, sql, parameters=()): + _test_sql(sql) + if first_bind in sql: + parameters = _numeric_param_as_dict(parameters) + return super().execute(sql, parameters) + + def executemany(self, sql, parameters): + _test_sql(sql) + if first_bind in sql: + parameters = [ + _numeric_param_as_dict(p) for p in parameters + ] + return super().executemany(sql, parameters) + + class SQLiteFix99953Connection(sqlite3.Connection): + def cursor(self, factory=None): + if factory is None: + factory = SQLiteFix99953Cursor + return super().cursor(factory=factory) + + def execute(self, sql, parameters=()): + _test_sql(sql) + if first_bind in sql: + parameters = _numeric_param_as_dict(parameters) + return super().execute(sql, parameters) + + def executemany(self, sql, parameters): + _test_sql(sql) + if first_bind in sql: + parameters = [ + _numeric_param_as_dict(p) for p in parameters + ] + return super().executemany(sql, parameters) + + return SQLiteFix99953Connection + + +class _SQLiteDialect_pysqlite_dollar(_SQLiteDialect_pysqlite_numeric): + """numeric dialect that uses $ for testing only + + internal use only. This dialect is **NOT** supported by SQLAlchemy + and may change at any time. + + """ + + supports_statement_cache = True + default_paramstyle = "numeric_dollar" + driver = "pysqlite_dollar" + + _first_bind = "$1" + _not_in_statement_regexp = re.compile(r"[^\d]:\d+") + + def __init__(self, *arg, **kw): + kw.setdefault("paramstyle", "numeric_dollar") + super().__init__(*arg, **kw) |