summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/_elements_constructors.py2
-rw-r--r--lib/sqlalchemy/sql/base.py8
-rw-r--r--lib/sqlalchemy/sql/coercions.py8
-rw-r--r--lib/sqlalchemy/sql/compiler.py243
-rw-r--r--lib/sqlalchemy/sql/crud.py18
-rw-r--r--lib/sqlalchemy/sql/ddl.py12
-rw-r--r--lib/sqlalchemy/sql/dml.py4
-rw-r--r--lib/sqlalchemy/sql/elements.py18
-rw-r--r--lib/sqlalchemy/sql/functions.py12
-rw-r--r--lib/sqlalchemy/sql/lambdas.py4
-rw-r--r--lib/sqlalchemy/sql/operators.py8
-rw-r--r--lib/sqlalchemy/sql/schema.py30
-rw-r--r--lib/sqlalchemy/sql/selectable.py46
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py42
-rw-r--r--lib/sqlalchemy/sql/traversals.py4
-rw-r--r--lib/sqlalchemy/sql/type_api.py10
-rw-r--r--lib/sqlalchemy/sql/util.py26
17 files changed, 226 insertions, 269 deletions
diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py
index 8b8f6b010..7c5281bee 100644
--- a/lib/sqlalchemy/sql/_elements_constructors.py
+++ b/lib/sqlalchemy/sql/_elements_constructors.py
@@ -1127,7 +1127,7 @@ def label(
name: str,
element: _ColumnExpressionArgument[_T],
type_: Optional[_TypeEngineArgument[_T]] = None,
-) -> "Label[_T]":
+) -> Label[_T]:
"""Return a :class:`Label` object for the
given :class:`_expression.ColumnElement`.
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 34b295113..c81891169 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -291,16 +291,14 @@ def _cloned_intersection(a, b):
"""
all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
- return set(
- elem for elem in a if all_overlap.intersection(elem._cloned_set)
- )
+ return {elem for elem in a if all_overlap.intersection(elem._cloned_set)}
def _cloned_difference(a, b):
all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
- return set(
+ return {
elem for elem in a if not all_overlap.intersection(elem._cloned_set)
- )
+ }
class _DialectArgView(MutableMapping[str, Any]):
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index 8074bcf8b..f48a3ccb0 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -782,7 +782,7 @@ class ExpressionElementImpl(_ColumnCoercions, RoleImpl):
else:
advice = None
- return super(ExpressionElementImpl, self)._raise_for_expected(
+ return super()._raise_for_expected(
element, argname=argname, resolved=resolved, advice=advice, **kw
)
@@ -1096,7 +1096,7 @@ class LabeledColumnExprImpl(ExpressionElementImpl):
if isinstance(resolved, roles.ExpressionElementRole):
return resolved.label(None)
else:
- new = super(LabeledColumnExprImpl, self)._implicit_coercions(
+ new = super()._implicit_coercions(
element, resolved, argname=argname, **kw
)
if isinstance(new, roles.ExpressionElementRole):
@@ -1123,7 +1123,7 @@ class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl):
f"{', '.join(repr(e) for e in element)})?"
)
- return super(ColumnsClauseImpl, self)._raise_for_expected(
+ return super()._raise_for_expected(
element, argname=argname, resolved=resolved, advice=advice, **kw
)
@@ -1370,7 +1370,7 @@ class CompoundElementImpl(_NoTextCoercion, RoleImpl):
)
else:
advice = None
- return super(CompoundElementImpl, self)._raise_for_expected(
+ return super()._raise_for_expected(
element, argname=argname, resolved=resolved, advice=advice, **kw
)
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 9a00afc91..17aafddad 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -115,104 +115,102 @@ if typing.TYPE_CHECKING:
_FromHintsType = Dict["FromClause", str]
-RESERVED_WORDS = set(
- [
- "all",
- "analyse",
- "analyze",
- "and",
- "any",
- "array",
- "as",
- "asc",
- "asymmetric",
- "authorization",
- "between",
- "binary",
- "both",
- "case",
- "cast",
- "check",
- "collate",
- "column",
- "constraint",
- "create",
- "cross",
- "current_date",
- "current_role",
- "current_time",
- "current_timestamp",
- "current_user",
- "default",
- "deferrable",
- "desc",
- "distinct",
- "do",
- "else",
- "end",
- "except",
- "false",
- "for",
- "foreign",
- "freeze",
- "from",
- "full",
- "grant",
- "group",
- "having",
- "ilike",
- "in",
- "initially",
- "inner",
- "intersect",
- "into",
- "is",
- "isnull",
- "join",
- "leading",
- "left",
- "like",
- "limit",
- "localtime",
- "localtimestamp",
- "natural",
- "new",
- "not",
- "notnull",
- "null",
- "off",
- "offset",
- "old",
- "on",
- "only",
- "or",
- "order",
- "outer",
- "overlaps",
- "placing",
- "primary",
- "references",
- "right",
- "select",
- "session_user",
- "set",
- "similar",
- "some",
- "symmetric",
- "table",
- "then",
- "to",
- "trailing",
- "true",
- "union",
- "unique",
- "user",
- "using",
- "verbose",
- "when",
- "where",
- ]
-)
+RESERVED_WORDS = {
+ "all",
+ "analyse",
+ "analyze",
+ "and",
+ "any",
+ "array",
+ "as",
+ "asc",
+ "asymmetric",
+ "authorization",
+ "between",
+ "binary",
+ "both",
+ "case",
+ "cast",
+ "check",
+ "collate",
+ "column",
+ "constraint",
+ "create",
+ "cross",
+ "current_date",
+ "current_role",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "default",
+ "deferrable",
+ "desc",
+ "distinct",
+ "do",
+ "else",
+ "end",
+ "except",
+ "false",
+ "for",
+ "foreign",
+ "freeze",
+ "from",
+ "full",
+ "grant",
+ "group",
+ "having",
+ "ilike",
+ "in",
+ "initially",
+ "inner",
+ "intersect",
+ "into",
+ "is",
+ "isnull",
+ "join",
+ "leading",
+ "left",
+ "like",
+ "limit",
+ "localtime",
+ "localtimestamp",
+ "natural",
+ "new",
+ "not",
+ "notnull",
+ "null",
+ "off",
+ "offset",
+ "old",
+ "on",
+ "only",
+ "or",
+ "order",
+ "outer",
+ "overlaps",
+ "placing",
+ "primary",
+ "references",
+ "right",
+ "select",
+ "session_user",
+ "set",
+ "similar",
+ "some",
+ "symmetric",
+ "table",
+ "then",
+ "to",
+ "trailing",
+ "true",
+ "union",
+ "unique",
+ "user",
+ "using",
+ "verbose",
+ "when",
+ "where",
+}
LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I)
LEGAL_CHARACTERS_PLUS_SPACE = re.compile(r"^[A-Z0-9_ $]+$", re.I)
@@ -505,8 +503,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
"between each element to resolve."
)
froms_str = ", ".join(
- '"{elem}"'.format(elem=self.froms[from_])
- for from_ in froms
+ f'"{self.froms[from_]}"' for from_ in froms
)
message = template.format(
froms=froms_str, start=self.froms[start_with]
@@ -1259,11 +1256,8 @@ class SQLCompiler(Compiled):
# mypy is not able to see the two value types as the above Union,
# it just sees "object". don't know how to resolve
- return dict(
- (
- key,
- value,
- ) # type: ignore
+ return {
+ key: value # type: ignore
for key, value in (
(
self.bind_names[bindparam],
@@ -1277,7 +1271,7 @@ class SQLCompiler(Compiled):
for bindparam in self.bind_names
)
if value is not None
- )
+ }
def is_subquery(self):
return len(self.stack) > 1
@@ -4147,17 +4141,12 @@ class SQLCompiler(Compiled):
def _setup_select_hints(
self, select: Select[Any]
) -> Tuple[str, _FromHintsType]:
- byfrom = dict(
- [
- (
- from_,
- hinttext
- % {"name": from_._compiler_dispatch(self, ashint=True)},
- )
- for (from_, dialect), hinttext in select._hints.items()
- if dialect in ("*", self.dialect.name)
- ]
- )
+ byfrom = {
+ from_: hinttext
+ % {"name": from_._compiler_dispatch(self, ashint=True)}
+ for (from_, dialect), hinttext in select._hints.items()
+ if dialect in ("*", self.dialect.name)
+ }
hint_text = self.get_select_hint_text(byfrom)
return hint_text, byfrom
@@ -4583,13 +4572,11 @@ class SQLCompiler(Compiled):
)
def _setup_crud_hints(self, stmt, table_text):
- dialect_hints = dict(
- [
- (table, hint_text)
- for (table, dialect), hint_text in stmt._hints.items()
- if dialect in ("*", self.dialect.name)
- ]
- )
+ dialect_hints = {
+ table: hint_text
+ for (table, dialect), hint_text in stmt._hints.items()
+ if dialect in ("*", self.dialect.name)
+ }
if stmt.table in dialect_hints:
table_text = self.format_from_hint_text(
table_text, stmt.table, dialect_hints[stmt.table], True
@@ -5318,9 +5305,7 @@ class StrSQLCompiler(SQLCompiler):
if not isinstance(compiler, StrSQLCompiler):
return compiler.process(element)
- return super(StrSQLCompiler, self).visit_unsupported_compilation(
- element, err
- )
+ return super().visit_unsupported_compilation(element, err)
def visit_getitem_binary(self, binary, operator, **kw):
return "%s[%s]" % (
@@ -6603,14 +6588,14 @@ class IdentifierPreparer:
@util.memoized_property
def _r_identifiers(self):
- initial, final, escaped_final = [
+ initial, final, escaped_final = (
re.escape(s)
for s in (
self.initial_quote,
self.final_quote,
self._escape_identifier(self.final_quote),
)
- ]
+ )
r = re.compile(
r"(?:"
r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s"
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 31d127c2c..017ff7baa 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -227,15 +227,15 @@ def _get_crud_params(
parameters = {}
elif stmt_parameter_tuples:
assert spd is not None
- parameters = dict(
- (_column_as_key(key), REQUIRED)
+ parameters = {
+ _column_as_key(key): REQUIRED
for key in compiler.column_keys
if key not in spd
- )
+ }
else:
- parameters = dict(
- (_column_as_key(key), REQUIRED) for key in compiler.column_keys
- )
+ parameters = {
+ _column_as_key(key): REQUIRED for key in compiler.column_keys
+ }
# create a list of column assignment clauses as tuples
values: List[_CrudParamElement] = []
@@ -1278,10 +1278,10 @@ def _get_update_multitable_params(
values,
kw,
):
- normalized_params = dict(
- (coercions.expect(roles.DMLColumnRole, c), param)
+ normalized_params = {
+ coercions.expect(roles.DMLColumnRole, c): param
for c, param in stmt_parameter_tuples
- )
+ }
include_table = compile_state.include_table_with_column_exprs
diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py
index fa0c25b1d..ecdc2eb63 100644
--- a/lib/sqlalchemy/sql/ddl.py
+++ b/lib/sqlalchemy/sql/ddl.py
@@ -176,7 +176,7 @@ class ExecutableDDLElement(roles.DDLRole, Executable, BaseDDLElement):
"""
_ddl_if: Optional[DDLIf] = None
- target: Optional["SchemaItem"] = None
+ target: Optional[SchemaItem] = None
def _execute_on_connection(
self, connection, distilled_params, execution_options
@@ -1179,12 +1179,10 @@ class SchemaDropper(InvokeDropDDLBase):
def sort_tables(
- tables: Iterable["Table"],
- skip_fn: Optional[Callable[["ForeignKeyConstraint"], bool]] = None,
- extra_dependencies: Optional[
- typing_Sequence[Tuple["Table", "Table"]]
- ] = None,
-) -> List["Table"]:
+ tables: Iterable[Table],
+ skip_fn: Optional[Callable[[ForeignKeyConstraint], bool]] = None,
+ extra_dependencies: Optional[typing_Sequence[Tuple[Table, Table]]] = None,
+) -> List[Table]:
"""Sort a collection of :class:`_schema.Table` objects based on
dependency.
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 2d3e3598b..c279e344b 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -1179,7 +1179,7 @@ class Insert(ValuesBase):
)
def __init__(self, table: _DMLTableArgument):
- super(Insert, self).__init__(table)
+ super().__init__(table)
@_generative
def inline(self: SelfInsert) -> SelfInsert:
@@ -1498,7 +1498,7 @@ class Update(DMLWhereBase, ValuesBase):
)
def __init__(self, table: _DMLTableArgument):
- super(Update, self).__init__(table)
+ super().__init__(table)
@_generative
def ordered_values(
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 044bdf585..d9a1a9358 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -3035,7 +3035,7 @@ class BooleanClauseList(ExpressionClauseList[bool]):
if not self.clauses:
return self
else:
- return super(BooleanClauseList, self).self_group(against=against)
+ return super().self_group(against=against)
and_ = BooleanClauseList.and_
@@ -3082,7 +3082,7 @@ class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]):
]
self.type = sqltypes.TupleType(*[arg.type for arg in init_clauses])
- super(Tuple, self).__init__(*init_clauses)
+ super().__init__(*init_clauses)
@property
def _select_iterable(self) -> _SelectIterable:
@@ -3753,8 +3753,8 @@ class BinaryExpression(OperatorExpression[_T]):
if typing.TYPE_CHECKING:
def __invert__(
- self: "BinaryExpression[_T]",
- ) -> "BinaryExpression[_T]":
+ self: BinaryExpression[_T],
+ ) -> BinaryExpression[_T]:
...
@util.ro_non_memoized_property
@@ -3772,7 +3772,7 @@ class BinaryExpression(OperatorExpression[_T]):
modifiers=self.modifiers,
)
else:
- return super(BinaryExpression, self)._negate()
+ return super()._negate()
class Slice(ColumnElement[Any]):
@@ -4617,7 +4617,7 @@ class ColumnClause(
if self.table is not None:
return self.table.entity_namespace
else:
- return super(ColumnClause, self).entity_namespace
+ return super().entity_namespace
def _clone(self, detect_subquery_cols=False, **kw):
if (
@@ -4630,7 +4630,7 @@ class ColumnClause(
new = table.c.corresponding_column(self)
return new
- return super(ColumnClause, self)._clone(**kw)
+ return super()._clone(**kw)
@HasMemoized_ro_memoized_attribute
def _from_objects(self) -> List[FromClause]:
@@ -4993,7 +4993,7 @@ class AnnotatedColumnElement(Annotated):
self.__dict__.pop(attr)
def _with_annotations(self, values):
- clone = super(AnnotatedColumnElement, self)._with_annotations(values)
+ clone = super()._with_annotations(values)
clone.__dict__.pop("comparator", None)
return clone
@@ -5032,7 +5032,7 @@ class _truncated_label(quoted_name):
def __new__(cls, value: str, quote: Optional[bool] = None) -> Any:
quote = getattr(value, "quote", quote)
# return super(_truncated_label, cls).__new__(cls, value, quote, True)
- return super(_truncated_label, cls).__new__(cls, value, quote)
+ return super().__new__(cls, value, quote)
def __reduce__(self) -> Any:
return self.__class__, (str(self), self.quote)
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index fad7c28eb..5ed89bc82 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -167,9 +167,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
@property
def _proxy_key(self):
- return super(FunctionElement, self)._proxy_key or getattr(
- self, "name", None
- )
+ return super()._proxy_key or getattr(self, "name", None)
def _execute_on_connection(
self,
@@ -660,7 +658,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
):
return Grouping(self)
else:
- return super(FunctionElement, self).self_group(against=against)
+ return super().self_group(against=against)
@property
def entity_namespace(self):
@@ -1198,7 +1196,7 @@ class ReturnTypeFromArgs(GenericFunction[_T]):
]
kwargs.setdefault("type_", _type_from_args(fn_args))
kwargs["_parsed_args"] = fn_args
- super(ReturnTypeFromArgs, self).__init__(*fn_args, **kwargs)
+ super().__init__(*fn_args, **kwargs)
class coalesce(ReturnTypeFromArgs[_T]):
@@ -1304,7 +1302,7 @@ class count(GenericFunction[int]):
def __init__(self, expression=None, **kwargs):
if expression is None:
expression = literal_column("*")
- super(count, self).__init__(expression, **kwargs)
+ super().__init__(expression, **kwargs)
class current_date(AnsiFunction[datetime.date]):
@@ -1411,7 +1409,7 @@ class array_agg(GenericFunction[_T]):
type_from_args, dimensions=1
)
kwargs["_parsed_args"] = fn_args
- super(array_agg, self).__init__(*fn_args, **kwargs)
+ super().__init__(*fn_args, **kwargs)
class OrderedSetAgg(GenericFunction[_T]):
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py
index bbfaf47e1..26e3a21bb 100644
--- a/lib/sqlalchemy/sql/lambdas.py
+++ b/lib/sqlalchemy/sql/lambdas.py
@@ -439,7 +439,7 @@ class DeferredLambdaElement(LambdaElement):
lambda_args: Tuple[Any, ...] = (),
):
self.lambda_args = lambda_args
- super(DeferredLambdaElement, self).__init__(fn, role, opts)
+ super().__init__(fn, role, opts)
def _invoke_user_fn(self, fn, *arg):
return fn(*self.lambda_args)
@@ -483,7 +483,7 @@ class DeferredLambdaElement(LambdaElement):
def _copy_internals(
self, clone=_clone, deferred_copy_internals=None, **kw
):
- super(DeferredLambdaElement, self)._copy_internals(
+ super()._copy_internals(
clone=clone,
deferred_copy_internals=deferred_copy_internals, # **kw
opts=kw,
diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py
index 55c275741..2d1f9caa1 100644
--- a/lib/sqlalchemy/sql/operators.py
+++ b/lib/sqlalchemy/sql/operators.py
@@ -66,11 +66,11 @@ class OperatorType(Protocol):
def __call__(
self,
- left: "Operators",
+ left: Operators,
right: Optional[Any] = None,
*other: Any,
**kwargs: Any,
- ) -> "Operators":
+ ) -> Operators:
...
@@ -184,7 +184,7 @@ class Operators:
precedence: int = 0,
is_comparison: bool = False,
return_type: Optional[
- Union[Type["TypeEngine[Any]"], "TypeEngine[Any]"]
+ Union[Type[TypeEngine[Any]], TypeEngine[Any]]
] = None,
python_impl: Optional[Callable[..., Any]] = None,
) -> Callable[[Any], Operators]:
@@ -397,7 +397,7 @@ class custom_op(OperatorType, Generic[_T]):
precedence: int = 0,
is_comparison: bool = False,
return_type: Optional[
- Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]
+ Union[Type[TypeEngine[_T]], TypeEngine[_T]]
] = None,
natural_self_precedent: bool = False,
eager_grouping: bool = False,
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index cd10d0c4a..f76fc447c 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -920,11 +920,11 @@ class Table(
:attr:`_schema.Table.indexes`
"""
- return set(
+ return {
fkc.constraint
for fkc in self.foreign_keys
if fkc.constraint is not None
- )
+ }
def _init_existing(self, *args: Any, **kwargs: Any) -> None:
autoload_with = kwargs.pop("autoload_with", None)
@@ -1895,7 +1895,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
# name = None is expected to be an interim state
# note this use case is legacy now that ORM declarative has a
# dedicated "column" construct local to the ORM
- super(Column, self).__init__(name, type_) # type: ignore
+ super().__init__(name, type_) # type: ignore
self.key = key if key is not None else name # type: ignore
self.primary_key = primary_key
@@ -3573,7 +3573,7 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator):
def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
column = parent
assert isinstance(column, Column)
- super(Sequence, self)._set_parent(column)
+ super()._set_parent(column)
column._on_table_attach(self._set_table)
def _copy(self) -> Sequence:
@@ -3712,7 +3712,7 @@ class DefaultClause(FetchedValue):
_reflected: bool = False,
) -> None:
util.assert_arg_type(arg, (str, ClauseElement, TextClause), "arg")
- super(DefaultClause, self).__init__(for_update)
+ super().__init__(for_update)
self.arg = arg
self.reflected = _reflected
@@ -3914,9 +3914,9 @@ class ColumnCollectionMixin:
# issue #3411 - don't do the per-column auto-attach if some of the
# columns are specified as strings.
- has_string_cols = set(
+ has_string_cols = {
c for c in self._pending_colargs if c is not None
- ).difference(col_objs)
+ }.difference(col_objs)
if not has_string_cols:
def _col_attached(column: Column[Any], table: Table) -> None:
@@ -4434,7 +4434,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
return self.elements[0].column.table
def _validate_dest_table(self, table: Table) -> None:
- table_keys = set([elem._table_key() for elem in self.elements])
+ table_keys = {elem._table_key() for elem in self.elements}
if None not in table_keys and len(table_keys) > 1:
elem0, elem1 = sorted(table_keys)[0:2]
raise exc.ArgumentError(
@@ -4624,7 +4624,7 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
**dialect_kw: Any,
) -> None:
self._implicit_generated = _implicit_generated
- super(PrimaryKeyConstraint, self).__init__(
+ super().__init__(
*columns,
name=name,
deferrable=deferrable,
@@ -4636,7 +4636,7 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
table = parent
assert isinstance(table, Table)
- super(PrimaryKeyConstraint, self)._set_parent(table)
+ super()._set_parent(table)
if table.primary_key is not self:
table.constraints.discard(table.primary_key)
@@ -5219,13 +5219,9 @@ class MetaData(HasSchemaAttr):
for fk in removed.foreign_keys:
fk._remove_from_metadata(self)
if self._schemas:
- self._schemas = set(
- [
- t.schema
- for t in self.tables.values()
- if t.schema is not None
- ]
- )
+ self._schemas = {
+ t.schema for t in self.tables.values() if t.schema is not None
+ }
def __getstate__(self) -> Dict[str, Any]:
return {
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 8c64dea9d..fcffc324f 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -1301,12 +1301,12 @@ class Join(roles.DMLTableRole, FromClause):
# run normal _copy_internals. the clones for
# left and right will come from the clone function's
# cache
- super(Join, self)._copy_internals(clone=clone, **kw)
+ super()._copy_internals(clone=clone, **kw)
self._reset_memoizations()
def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
- super(Join, self)._refresh_for_new_column(column)
+ super()._refresh_for_new_column(column)
self.left._refresh_for_new_column(column)
self.right._refresh_for_new_column(column)
@@ -1467,7 +1467,7 @@ class Join(roles.DMLTableRole, FromClause):
# "consider_as_foreign_keys".
if consider_as_foreign_keys:
for const in list(constraints):
- if set(f.parent for f in const.elements) != set(
+ if {f.parent for f in const.elements} != set(
consider_as_foreign_keys
):
del constraints[const]
@@ -1475,7 +1475,7 @@ class Join(roles.DMLTableRole, FromClause):
# if still multiple constraints, but
# they all refer to the exact same end result, use it.
if len(constraints) > 1:
- dedupe = set(tuple(crit) for crit in constraints.values())
+ dedupe = {tuple(crit) for crit in constraints.values()}
if len(dedupe) == 1:
key = list(constraints)[0]
constraints = {key: constraints[key]}
@@ -1621,7 +1621,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
self.name = name
def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
- super(AliasedReturnsRows, self)._refresh_for_new_column(column)
+ super()._refresh_for_new_column(column)
self.element._refresh_for_new_column(column)
def _populate_column_collection(self):
@@ -1654,7 +1654,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
) -> None:
existing_element = self.element
- super(AliasedReturnsRows, self)._copy_internals(clone=clone, **kw)
+ super()._copy_internals(clone=clone, **kw)
# the element clone is usually against a Table that returns the
# same object. don't reset exported .c. collections and other
@@ -1752,7 +1752,7 @@ class TableValuedAlias(LateralFromClause, Alias):
table_value_type=None,
joins_implicitly=False,
):
- super(TableValuedAlias, self)._init(selectable, name=name)
+ super()._init(selectable, name=name)
self.joins_implicitly = joins_implicitly
self._tableval_type = (
@@ -1959,7 +1959,7 @@ class TableSample(FromClauseAlias):
self.sampling = sampling
self.seed = seed
- super(TableSample, self)._init(selectable, name=name)
+ super()._init(selectable, name=name)
def _get_method(self):
return self.sampling
@@ -2044,7 +2044,7 @@ class CTE(
self._prefixes = _prefixes
if _suffixes:
self._suffixes = _suffixes
- super(CTE, self)._init(selectable, name=name)
+ super()._init(selectable, name=name)
def _populate_column_collection(self):
if self._cte_alias is not None:
@@ -2945,7 +2945,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
return None
def __init__(self, name: str, *columns: ColumnClause[Any], **kw: Any):
- super(TableClause, self).__init__()
+ super().__init__()
self.name = name
self._columns = DedupeColumnCollection()
self.primary_key = ColumnSet() # type: ignore
@@ -3156,7 +3156,7 @@ class Values(Generative, LateralFromClause):
name: Optional[str] = None,
literal_binds: bool = False,
):
- super(Values, self).__init__()
+ super().__init__()
self._column_args = columns
if name is None:
self._unnamed = True
@@ -4188,7 +4188,7 @@ class CompoundSelectState(CompileState):
# TODO: this is hacky and slow
hacky_subquery = self.statement.subquery()
hacky_subquery.named_with_column = False
- d = dict((c.key, c) for c in hacky_subquery.c)
+ d = {c.key: c for c in hacky_subquery.c}
return d, d, d
@@ -4369,7 +4369,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
)
def _refresh_for_new_column(self, column):
- super(CompoundSelect, self)._refresh_for_new_column(column)
+ super()._refresh_for_new_column(column)
for select in self.selects:
select._refresh_for_new_column(column)
@@ -4689,16 +4689,16 @@ class SelectState(util.MemoizedSlots, CompileState):
Dict[str, ColumnElement[Any]],
Dict[str, ColumnElement[Any]],
]:
- with_cols: Dict[str, ColumnElement[Any]] = dict(
- (c._tq_label or c.key, c) # type: ignore
+ with_cols: Dict[str, ColumnElement[Any]] = {
+ c._tq_label or c.key: c # type: ignore
for c in self.statement._all_selected_columns
if c._allow_label_resolve
- )
- only_froms: Dict[str, ColumnElement[Any]] = dict(
- (c.key, c) # type: ignore
+ }
+ only_froms: Dict[str, ColumnElement[Any]] = {
+ c.key: c # type: ignore
for c in _select_iterables(self.froms)
if c._allow_label_resolve
- )
+ }
only_cols: Dict[str, ColumnElement[Any]] = with_cols.copy()
for key, value in only_froms.items():
with_cols.setdefault(key, value)
@@ -5569,7 +5569,7 @@ class Select(
# 2. copy FROM collections, adding in joins that we've created.
existing_from_obj = [clone(f, **kw) for f in self._from_obj]
add_froms = (
- set(f for f in new_froms.values() if isinstance(f, Join))
+ {f for f in new_froms.values() if isinstance(f, Join)}
.difference(all_the_froms)
.difference(existing_from_obj)
)
@@ -5589,15 +5589,13 @@ class Select(
# correlate_except, setup_joins, these clone normally. For
# column-expression oriented things like raw_columns, where_criteria,
# order by, we get this from the new froms.
- super(Select, self)._copy_internals(
- clone=clone, omit_attrs=("_from_obj",), **kw
- )
+ super()._copy_internals(clone=clone, omit_attrs=("_from_obj",), **kw)
self._reset_memoizations()
def get_children(self, **kw: Any) -> Iterable[ClauseElement]:
return itertools.chain(
- super(Select, self).get_children(
+ super().get_children(
omit_attrs=("_from_obj", "_correlate", "_correlate_except"),
**kw,
),
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index b98a16b6f..624b7d16e 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -134,9 +134,7 @@ class Concatenable(TypeEngineMixin):
):
return operators.concat_op, self.expr.type
else:
- return super(Concatenable.Comparator, self)._adapt_expression(
- op, other_comparator
- )
+ return super()._adapt_expression(op, other_comparator)
comparator_factory: _ComparatorFactory[Any] = Comparator
@@ -319,7 +317,7 @@ class Unicode(String):
Parameters are the same as that of :class:`.String`.
"""
- super(Unicode, self).__init__(length=length, **kwargs)
+ super().__init__(length=length, **kwargs)
class UnicodeText(Text):
@@ -344,7 +342,7 @@ class UnicodeText(Text):
Parameters are the same as that of :class:`_expression.TextClause`.
"""
- super(UnicodeText, self).__init__(length=length, **kwargs)
+ super().__init__(length=length, **kwargs)
class Integer(HasExpressionLookup, TypeEngine[int]):
@@ -930,7 +928,7 @@ class _Binary(TypeEngine[bytes]):
if isinstance(value, str):
return self
else:
- return super(_Binary, self).coerce_compared_value(op, value)
+ return super().coerce_compared_value(op, value)
def get_dbapi_type(self, dbapi):
return dbapi.BINARY
@@ -1450,7 +1448,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
self._valid_lookup[None] = self._object_lookup[None] = None
- super(Enum, self).__init__(length=length)
+ super().__init__(length=length)
if self.enum_class:
kw.setdefault("name", self.enum_class.__name__.lower())
@@ -1551,9 +1549,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
op: OperatorType,
other_comparator: TypeEngine.Comparator[Any],
) -> Tuple[OperatorType, TypeEngine[Any]]:
- op, typ = super(Enum.Comparator, self)._adapt_expression(
- op, other_comparator
- )
+ op, typ = super()._adapt_expression(op, other_comparator)
if op is operators.concat_op:
typ = String(self.type.length)
return op, typ
@@ -1618,7 +1614,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
def adapt(self, impltype, **kw):
kw["_enums"] = self._enums_argument
kw["_disable_warnings"] = True
- return super(Enum, self).adapt(impltype, **kw)
+ return super().adapt(impltype, **kw)
def _should_create_constraint(self, compiler, **kw):
if not self._is_impl_for_variant(compiler.dialect, kw):
@@ -1649,7 +1645,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
assert e.table is table
def literal_processor(self, dialect):
- parent_processor = super(Enum, self).literal_processor(dialect)
+ parent_processor = super().literal_processor(dialect)
def process(value):
value = self._db_value_for_elem(value)
@@ -1660,7 +1656,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
return process
def bind_processor(self, dialect):
- parent_processor = super(Enum, self).bind_processor(dialect)
+ parent_processor = super().bind_processor(dialect)
def process(value):
value = self._db_value_for_elem(value)
@@ -1671,7 +1667,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
return process
def result_processor(self, dialect, coltype):
- parent_processor = super(Enum, self).result_processor(dialect, coltype)
+ parent_processor = super().result_processor(dialect, coltype)
def process(value):
if parent_processor:
@@ -1690,7 +1686,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
if self.enum_class:
return self.enum_class
else:
- return super(Enum, self).python_type
+ return super().python_type
class PickleType(TypeDecorator[object]):
@@ -1739,7 +1735,7 @@ class PickleType(TypeDecorator[object]):
self.protocol = protocol
self.pickler = pickler or pickle
self.comparator = comparator
- super(PickleType, self).__init__()
+ super().__init__()
if impl:
# custom impl is not necessarily a LargeBinary subclass.
@@ -2000,7 +1996,7 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]):
support a "day precision" parameter, i.e. Oracle.
"""
- super(Interval, self).__init__()
+ super().__init__()
self.native = native
self.second_precision = second_precision
self.day_precision = day_precision
@@ -3005,7 +3001,7 @@ class ARRAY(
def _set_parent_with_dispatch(self, parent):
"""Support SchemaEventTarget"""
- super(ARRAY, self)._set_parent_with_dispatch(parent, outer=True)
+ super()._set_parent_with_dispatch(parent, outer=True)
if isinstance(self.item_type, SchemaEventTarget):
self.item_type._set_parent_with_dispatch(parent)
@@ -3249,7 +3245,7 @@ class TIMESTAMP(DateTime):
"""
- super(TIMESTAMP, self).__init__(timezone=timezone)
+ super().__init__(timezone=timezone)
def get_dbapi_type(self, dbapi):
return dbapi.TIMESTAMP
@@ -3464,7 +3460,7 @@ class Uuid(TypeEngine[_UUID_RETURN]):
@overload
def __init__(
- self: "Uuid[_python_UUID]",
+ self: Uuid[_python_UUID],
as_uuid: Literal[True] = ...,
native_uuid: bool = ...,
):
@@ -3472,7 +3468,7 @@ class Uuid(TypeEngine[_UUID_RETURN]):
@overload
def __init__(
- self: "Uuid[str]",
+ self: Uuid[str],
as_uuid: Literal[False] = ...,
native_uuid: bool = ...,
):
@@ -3628,11 +3624,11 @@ class UUID(Uuid[_UUID_RETURN]):
__visit_name__ = "UUID"
@overload
- def __init__(self: "UUID[_python_UUID]", as_uuid: Literal[True] = ...):
+ def __init__(self: UUID[_python_UUID], as_uuid: Literal[True] = ...):
...
@overload
- def __init__(self: "UUID[str]", as_uuid: Literal[False] = ...):
+ def __init__(self: UUID[str], as_uuid: Literal[False] = ...):
...
def __init__(self, as_uuid: bool = True):
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 135407321..866c0ccde 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -301,9 +301,7 @@ class _CopyInternalsTraversal(HasTraversalDispatch):
def visit_string_clauseelement_dict(
self, attrname, parent, element, clone=_clone, **kw
):
- return dict(
- (key, clone(value, **kw)) for key, value in element.items()
- )
+ return {key: clone(value, **kw) for key, value in element.items()}
def visit_setup_join_tuple(
self, attrname, parent, element, clone=_clone, **kw
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index cd57ee3b6..c3768c6c6 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -1399,7 +1399,7 @@ class Emulated(TypeEngineMixin):
def _is_native_for_emulated(
typ: Type[Union[TypeEngine[Any], TypeEngineMixin]],
-) -> TypeGuard["Type[NativeForEmulated]"]:
+) -> TypeGuard[Type[NativeForEmulated]]:
return hasattr(typ, "adapt_emulated_to_native")
@@ -1673,9 +1673,7 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
if TYPE_CHECKING:
assert isinstance(self.expr.type, TypeDecorator)
kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types
- return super(TypeDecorator.Comparator, self).operate(
- op, *other, **kwargs
- )
+ return super().operate(op, *other, **kwargs)
def reverse_operate(
self, op: OperatorType, other: Any, **kwargs: Any
@@ -1683,9 +1681,7 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
if TYPE_CHECKING:
assert isinstance(self.expr.type, TypeDecorator)
kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types
- return super(TypeDecorator.Comparator, self).reverse_operate(
- op, other, **kwargs
- )
+ return super().reverse_operate(op, other, **kwargs)
@property
def comparator_factory( # type: ignore # mypy properties bug
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index ec8ea757f..14cbe2456 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -316,8 +316,7 @@ def visit_binary_product(
if isinstance(element, ColumnClause):
yield element
for elem in element.get_children():
- for e in visit(elem):
- yield e
+ yield from visit(elem)
list(visit(expr))
visit = None # type: ignore # remove gc cycles
@@ -433,12 +432,10 @@ def expand_column_list_from_order_by(collist, order_by):
in the collist.
"""
- cols_already_present = set(
- [
- col.element if col._order_by_label_element is not None else col
- for col in collist
- ]
- )
+ cols_already_present = {
+ col.element if col._order_by_label_element is not None else col
+ for col in collist
+ }
to_look_for = list(chain(*[unwrap_order_by(o) for o in order_by]))
@@ -463,13 +460,10 @@ def clause_is_present(clause, search):
def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]:
if isinstance(clause, Join):
- for t in tables_from_leftmost(clause.left):
- yield t
- for t in tables_from_leftmost(clause.right):
- yield t
+ yield from tables_from_leftmost(clause.left)
+ yield from tables_from_leftmost(clause.right)
elif isinstance(clause, FromGrouping):
- for t in tables_from_leftmost(clause.element):
- yield t
+ yield from tables_from_leftmost(clause.element)
else:
yield clause
@@ -592,7 +586,7 @@ class _repr_row(_repr_base):
__slots__ = ("row",)
- def __init__(self, row: "Row[Any]", max_chars: int = 300):
+ def __init__(self, row: Row[Any], max_chars: int = 300):
self.row = row
self.max_chars = max_chars
@@ -775,7 +769,7 @@ class _repr_params(_repr_base):
)
return text
- def _repr_param_tuple(self, params: "Sequence[Any]") -> str:
+ def _repr_param_tuple(self, params: Sequence[Any]) -> str:
trunc = self.trunc
(