1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
|
from __future__ import annotations
import contextlib
import re
from typing import Any
from typing import Iterable
from typing import Iterator
from typing import Mapping
from typing import Optional
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from sqlalchemy import __version__
from sqlalchemy import inspect
from sqlalchemy import schema
from sqlalchemy import sql
from sqlalchemy import types as sqltypes
from sqlalchemy.engine import url
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import CheckConstraint
from sqlalchemy.schema import Column
from sqlalchemy.schema import ForeignKeyConstraint
from sqlalchemy.sql import visitors
from sqlalchemy.sql.elements import BindParameter
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.sql.visitors import traverse
if TYPE_CHECKING:
from sqlalchemy import Index
from sqlalchemy import Table
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
from sqlalchemy.engine import Transaction
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.base import ColumnCollection
from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.dml import Insert
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import SchemaItem
from sqlalchemy.sql.selectable import Select
from sqlalchemy.sql.selectable import TableClause
_CE = TypeVar("_CE", bound=Union["ColumnElement", "SchemaItem"])
def _safe_int(value: str) -> Union[int, str]:
try:
return int(value)
except:
return value
_vers = tuple(
[_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)]
)
sqla_13 = _vers >= (1, 3)
sqla_14 = _vers >= (1, 4)
sqla_14_26 = _vers >= (1, 4, 26)
sqla_2 = _vers >= (2,)
sqlalchemy_version = __version__
if sqla_14:
# when future engine merges, this can be again based on version string
from sqlalchemy.engine import Connection as legacy_connection
sqla_1x = not hasattr(legacy_connection, "commit")
else:
sqla_1x = True
try:
from sqlalchemy import Computed # noqa
except ImportError:
Computed = type(None) # type: ignore
has_computed = False
has_computed_reflection = False
else:
has_computed = True
has_computed_reflection = _vers >= (1, 3, 16)
try:
from sqlalchemy import Identity # noqa
except ImportError:
Identity = type(None) # type: ignore
has_identity = False
else:
# attributes common to Indentity and Sequence
_identity_options_attrs = (
"start",
"increment",
"minvalue",
"maxvalue",
"nominvalue",
"nomaxvalue",
"cycle",
"cache",
"order",
)
# attributes of Indentity
_identity_attrs = _identity_options_attrs + ("on_null",)
has_identity = True
AUTOINCREMENT_DEFAULT = "auto"
@contextlib.contextmanager
def _ensure_scope_for_ddl(
connection: Optional[Connection],
) -> Iterator[None]:
try:
in_transaction = connection.in_transaction # type: ignore[union-attr]
except AttributeError:
# catch for MockConnection, None
in_transaction = None
pass
# yield outside the catch
if in_transaction is None:
yield
else:
if not in_transaction():
assert connection is not None
with connection.begin():
yield
else:
yield
def url_render_as_string(url, hide_password=True):
if sqla_14:
return url.render_as_string(hide_password=hide_password)
else:
return url.__to_string__(hide_password=hide_password)
def _safe_begin_connection_transaction(
connection: Connection,
) -> Transaction:
transaction = _get_connection_transaction(connection)
if transaction:
return transaction
else:
return connection.begin()
def _safe_commit_connection_transaction(
connection: Connection,
) -> None:
transaction = _get_connection_transaction(connection)
if transaction:
transaction.commit()
def _safe_rollback_connection_transaction(
connection: Connection,
) -> None:
transaction = _get_connection_transaction(connection)
if transaction:
transaction.rollback()
def _get_connection_in_transaction(connection: Optional[Connection]) -> bool:
try:
in_transaction = connection.in_transaction # type: ignore
except AttributeError:
# catch for MockConnection
return False
else:
return in_transaction()
def _idx_table_bound_expressions(idx: Index) -> Iterable[ColumnElement[Any]]:
return idx.expressions # type: ignore
def _copy(schema_item: _CE, **kw) -> _CE:
if hasattr(schema_item, "_copy"):
return schema_item._copy(**kw) # type: ignore[union-attr]
else:
return schema_item.copy(**kw) # type: ignore[union-attr]
def _get_connection_transaction(
connection: Connection,
) -> Optional[Transaction]:
if sqla_14:
return connection.get_transaction()
else:
r = connection._root # type: ignore[attr-defined]
return r._Connection__transaction
def _create_url(*arg, **kw) -> url.URL:
if hasattr(url.URL, "create"):
return url.URL.create(*arg, **kw)
else:
return url.URL(*arg, **kw)
def _connectable_has_table(
connectable: Connection, tablename: str, schemaname: Union[str, None]
) -> bool:
if sqla_14:
return inspect(connectable).has_table(tablename, schemaname)
else:
return connectable.dialect.has_table(
connectable, tablename, schemaname
)
def _exec_on_inspector(inspector, statement, **params):
if sqla_14:
with inspector._operation_context() as conn:
return conn.execute(statement, params)
else:
return inspector.bind.execute(statement, params)
def _nullability_might_be_unset(metadata_column):
if not sqla_14:
return metadata_column.nullable
else:
from sqlalchemy.sql import schema
return (
metadata_column._user_defined_nullable is schema.NULL_UNSPECIFIED
)
def _server_default_is_computed(*server_default) -> bool:
if not has_computed:
return False
else:
return any(isinstance(sd, Computed) for sd in server_default)
def _server_default_is_identity(*server_default) -> bool:
if not sqla_14:
return False
else:
return any(isinstance(sd, Identity) for sd in server_default)
def _table_for_constraint(constraint: Constraint) -> Table:
if isinstance(constraint, ForeignKeyConstraint):
table = constraint.parent
assert table is not None
return table # type: ignore[return-value]
else:
return constraint.table
def _columns_for_constraint(constraint):
if isinstance(constraint, ForeignKeyConstraint):
return [fk.parent for fk in constraint.elements]
elif isinstance(constraint, CheckConstraint):
return _find_columns(constraint.sqltext)
else:
return list(constraint.columns)
def _reflect_table(
inspector: Inspector, table: Table, include_cols: None
) -> None:
if sqla_14:
return inspector.reflect_table(table, None)
else:
return inspector.reflecttable( # type: ignore[attr-defined]
table, None
)
def _resolve_for_variant(type_, dialect):
if _type_has_variants(type_):
base_type, mapping = _get_variant_mapping(type_)
return mapping.get(dialect.name, base_type)
else:
return type_
if hasattr(sqltypes.TypeEngine, "_variant_mapping"):
def _type_has_variants(type_):
return bool(type_._variant_mapping)
def _get_variant_mapping(type_):
return type_, type_._variant_mapping
else:
def _type_has_variants(type_):
return type(type_) is sqltypes.Variant
def _get_variant_mapping(type_):
return type_.impl, type_.mapping
def _fk_spec(constraint):
source_columns = [
constraint.columns[key].name for key in constraint.column_keys
]
source_table = constraint.parent.name
source_schema = constraint.parent.schema
target_schema = constraint.elements[0].column.table.schema
target_table = constraint.elements[0].column.table.name
target_columns = [element.column.name for element in constraint.elements]
ondelete = constraint.ondelete
onupdate = constraint.onupdate
deferrable = constraint.deferrable
initially = constraint.initially
return (
source_schema,
source_table,
source_columns,
target_schema,
target_table,
target_columns,
onupdate,
ondelete,
deferrable,
initially,
)
def _fk_is_self_referential(constraint: ForeignKeyConstraint) -> bool:
spec = constraint.elements[0]._get_colspec() # type: ignore[attr-defined]
tokens = spec.split(".")
tokens.pop(-1) # colname
tablekey = ".".join(tokens)
assert constraint.parent is not None
return tablekey == constraint.parent.key
def _is_type_bound(constraint: Constraint) -> bool:
# this deals with SQLAlchemy #3260, don't copy CHECK constraints
# that will be generated by the type.
# new feature added for #3260
return constraint._type_bound # type: ignore[attr-defined]
def _find_columns(clause):
"""locate Column objects within the given expression."""
cols = set()
traverse(clause, {}, {"column": cols.add})
return cols
def _remove_column_from_collection(
collection: ColumnCollection, column: Union[Column, ColumnClause]
) -> None:
"""remove a column from a ColumnCollection."""
# workaround for older SQLAlchemy, remove the
# same object that's present
assert column.key is not None
to_remove = collection[column.key]
# SQLAlchemy 2.0 will use more ReadOnlyColumnCollection
# (renamed from ImmutableColumnCollection)
if hasattr(collection, "_immutable") or hasattr(collection, "_readonly"):
collection._parent.remove(to_remove)
else:
collection.remove(to_remove)
def _textual_index_column(
table: Table, text_: Union[str, TextClause, ColumnElement]
) -> Union[ColumnElement, Column]:
"""a workaround for the Index construct's severe lack of flexibility"""
if isinstance(text_, str):
c = Column(text_, sqltypes.NULLTYPE)
table.append_column(c)
return c
elif isinstance(text_, TextClause):
return _textual_index_element(table, text_)
elif isinstance(text_, sql.ColumnElement):
return _copy_expression(text_, table)
else:
raise ValueError("String or text() construct expected")
def _copy_expression(expression: _CE, target_table: Table) -> _CE:
def replace(col):
if (
isinstance(col, Column)
and col.table is not None
and col.table is not target_table
):
if col.name in target_table.c:
return target_table.c[col.name]
else:
c = _copy(col)
target_table.append_column(c)
return c
else:
return None
return visitors.replacement_traverse( # type: ignore[call-overload]
expression, {}, replace
)
class _textual_index_element(sql.ColumnElement):
"""Wrap around a sqlalchemy text() construct in such a way that
we appear like a column-oriented SQL expression to an Index
construct.
The issue here is that currently the Postgresql dialect, the biggest
recipient of functional indexes, keys all the index expressions to
the corresponding column expressions when rendering CREATE INDEX,
so the Index we create here needs to have a .columns collection that
is the same length as the .expressions collection. Ultimately
SQLAlchemy should support text() expressions in indexes.
See SQLAlchemy issue 3174.
"""
__visit_name__ = "_textual_idx_element"
def __init__(self, table: Table, text: TextClause) -> None:
self.table = table
self.text = text
self.key = text.text
self.fake_column = schema.Column(self.text.text, sqltypes.NULLTYPE)
table.append_column(self.fake_column)
def get_children(self):
return [self.fake_column]
@compiles(_textual_index_element)
def _render_textual_index_column(
element: _textual_index_element, compiler: SQLCompiler, **kw
) -> str:
return compiler.process(element.text, **kw)
class _literal_bindparam(BindParameter):
pass
@compiles(_literal_bindparam)
def _render_literal_bindparam(
element: _literal_bindparam, compiler: SQLCompiler, **kw
) -> str:
return compiler.render_literal_bindparam(element, **kw)
def _get_index_expressions(idx):
return list(idx.expressions)
def _get_index_column_names(idx):
return [getattr(exp, "name", None) for exp in _get_index_expressions(idx)]
def _column_kwargs(col: Column) -> Mapping:
if sqla_13:
return col.kwargs
else:
return {}
def _get_constraint_final_name(
constraint: Union[Index, Constraint], dialect: Optional[Dialect]
) -> Optional[str]:
if constraint.name is None:
return None
assert dialect is not None
if sqla_14:
# for SQLAlchemy 1.4 we would like to have the option to expand
# the use of "deferred" names for constraints as well as to have
# some flexibility with "None" name and similar; make use of new
# SQLAlchemy API to return what would be the final compiled form of
# the name for this dialect.
return dialect.identifier_preparer.format_constraint(
constraint, _alembic_quote=False
)
else:
# prior to SQLAlchemy 1.4, work around quoting logic to get at the
# final compiled name without quotes.
if hasattr(constraint.name, "quote"):
# might be quoted_name, might be truncated_name, keep it the
# same
quoted_name_cls: type = type(constraint.name)
else:
quoted_name_cls = quoted_name
new_name = quoted_name_cls(str(constraint.name), quote=False)
constraint = constraint.__class__(name=new_name)
if isinstance(constraint, schema.Index):
# name should not be quoted.
d = dialect.ddl_compiler(dialect, None) # type: ignore[arg-type]
return d._prepared_index_name( # type: ignore[attr-defined]
constraint
)
else:
# name should not be quoted.
return dialect.identifier_preparer.format_constraint(constraint)
def _constraint_is_named(
constraint: Union[Constraint, Index], dialect: Optional[Dialect]
) -> bool:
if sqla_14:
if constraint.name is None:
return False
assert dialect is not None
name = dialect.identifier_preparer.format_constraint(
constraint, _alembic_quote=False
)
return name is not None
else:
return constraint.name is not None
def _is_mariadb(mysql_dialect: Dialect) -> bool:
if sqla_14:
return mysql_dialect.is_mariadb # type: ignore[attr-defined]
else:
return bool(
mysql_dialect.server_version_info
and mysql_dialect._is_mariadb # type: ignore[attr-defined]
)
def _mariadb_normalized_version_info(mysql_dialect):
return mysql_dialect._mariadb_normalized_version_info
def _insert_inline(table: Union[TableClause, Table]) -> Insert:
if sqla_14:
return table.insert().inline()
else:
return table.insert(inline=True) # type: ignore[call-arg]
if sqla_14:
from sqlalchemy import create_mock_engine
from sqlalchemy import select as _select
else:
from sqlalchemy import create_engine
def create_mock_engine(url, executor, **kw): # type: ignore[misc]
return create_engine(
"postgresql://", strategy="mock", executor=executor
)
def _select(*columns, **kw) -> Select: # type: ignore[no-redef]
return sql.select(list(columns), **kw) # type: ignore[call-overload]
def is_expression_index(index: Index) -> bool:
expr: Any
for expr in index.expressions:
while isinstance(expr, UnaryExpression):
expr = expr.element
if not isinstance(expr, ColumnClause) or expr.is_literal:
return True
return False
|