diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r-- | lib/sqlalchemy/sql/_elements_constructors.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 243 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 18 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/ddl.py | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/dml.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 18 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/functions.py | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/lambdas.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/operators.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/schema.py | 30 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 46 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 42 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/type_api.py | 10 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 26 |
17 files changed, 226 insertions, 269 deletions
diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index 8b8f6b010..7c5281bee 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -1127,7 +1127,7 @@ def label( name: str, element: _ColumnExpressionArgument[_T], type_: Optional[_TypeEngineArgument[_T]] = None, -) -> "Label[_T]": +) -> Label[_T]: """Return a :class:`Label` object for the given :class:`_expression.ColumnElement`. diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 34b295113..c81891169 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -291,16 +291,14 @@ def _cloned_intersection(a, b): """ all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) - return set( - elem for elem in a if all_overlap.intersection(elem._cloned_set) - ) + return {elem for elem in a if all_overlap.intersection(elem._cloned_set)} def _cloned_difference(a, b): all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) - return set( + return { elem for elem in a if not all_overlap.intersection(elem._cloned_set) - ) + } class _DialectArgView(MutableMapping[str, Any]): diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 8074bcf8b..f48a3ccb0 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -782,7 +782,7 @@ class ExpressionElementImpl(_ColumnCoercions, RoleImpl): else: advice = None - return super(ExpressionElementImpl, self)._raise_for_expected( + return super()._raise_for_expected( element, argname=argname, resolved=resolved, advice=advice, **kw ) @@ -1096,7 +1096,7 @@ class LabeledColumnExprImpl(ExpressionElementImpl): if isinstance(resolved, roles.ExpressionElementRole): return resolved.label(None) else: - new = super(LabeledColumnExprImpl, self)._implicit_coercions( + new = super()._implicit_coercions( element, resolved, argname=argname, **kw ) if isinstance(new, roles.ExpressionElementRole): @@ -1123,7 +1123,7 @@ class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl): f"{', '.join(repr(e) for e in element)})?" ) - return super(ColumnsClauseImpl, self)._raise_for_expected( + return super()._raise_for_expected( element, argname=argname, resolved=resolved, advice=advice, **kw ) @@ -1370,7 +1370,7 @@ class CompoundElementImpl(_NoTextCoercion, RoleImpl): ) else: advice = None - return super(CompoundElementImpl, self)._raise_for_expected( + return super()._raise_for_expected( element, argname=argname, resolved=resolved, advice=advice, **kw ) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 9a00afc91..17aafddad 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -115,104 +115,102 @@ if typing.TYPE_CHECKING: _FromHintsType = Dict["FromClause", str] -RESERVED_WORDS = set( - [ - "all", - "analyse", - "analyze", - "and", - "any", - "array", - "as", - "asc", - "asymmetric", - "authorization", - "between", - "binary", - "both", - "case", - "cast", - "check", - "collate", - "column", - "constraint", - "create", - "cross", - "current_date", - "current_role", - "current_time", - "current_timestamp", - "current_user", - "default", - "deferrable", - "desc", - "distinct", - "do", - "else", - "end", - "except", - "false", - "for", - "foreign", - "freeze", - "from", - "full", - "grant", - "group", - "having", - "ilike", - "in", - "initially", - "inner", - "intersect", - "into", - "is", - "isnull", - "join", - "leading", - "left", - "like", - "limit", - "localtime", - "localtimestamp", - "natural", - "new", - "not", - "notnull", - "null", - "off", - "offset", - "old", - "on", - "only", - "or", - "order", - "outer", - "overlaps", - "placing", - "primary", - "references", - "right", - "select", - "session_user", - "set", - "similar", - "some", - "symmetric", - "table", - "then", - "to", - "trailing", - "true", - "union", - "unique", - "user", - "using", - "verbose", - "when", - "where", - ] -) +RESERVED_WORDS = { + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "authorization", + "between", + "binary", + "both", + "case", + "cast", + "check", + "collate", + "column", + "constraint", + "create", + "cross", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "for", + "foreign", + "freeze", + "from", + "full", + "grant", + "group", + "having", + "ilike", + "in", + "initially", + "inner", + "intersect", + "into", + "is", + "isnull", + "join", + "leading", + "left", + "like", + "limit", + "localtime", + "localtimestamp", + "natural", + "new", + "not", + "notnull", + "null", + "off", + "offset", + "old", + "on", + "only", + "or", + "order", + "outer", + "overlaps", + "placing", + "primary", + "references", + "right", + "select", + "session_user", + "set", + "similar", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "verbose", + "when", + "where", +} LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I) LEGAL_CHARACTERS_PLUS_SPACE = re.compile(r"^[A-Z0-9_ $]+$", re.I) @@ -505,8 +503,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): "between each element to resolve." ) froms_str = ", ".join( - '"{elem}"'.format(elem=self.froms[from_]) - for from_ in froms + f'"{self.froms[from_]}"' for from_ in froms ) message = template.format( froms=froms_str, start=self.froms[start_with] @@ -1259,11 +1256,8 @@ class SQLCompiler(Compiled): # mypy is not able to see the two value types as the above Union, # it just sees "object". don't know how to resolve - return dict( - ( - key, - value, - ) # type: ignore + return { + key: value # type: ignore for key, value in ( ( self.bind_names[bindparam], @@ -1277,7 +1271,7 @@ class SQLCompiler(Compiled): for bindparam in self.bind_names ) if value is not None - ) + } def is_subquery(self): return len(self.stack) > 1 @@ -4147,17 +4141,12 @@ class SQLCompiler(Compiled): def _setup_select_hints( self, select: Select[Any] ) -> Tuple[str, _FromHintsType]: - byfrom = dict( - [ - ( - from_, - hinttext - % {"name": from_._compiler_dispatch(self, ashint=True)}, - ) - for (from_, dialect), hinttext in select._hints.items() - if dialect in ("*", self.dialect.name) - ] - ) + byfrom = { + from_: hinttext + % {"name": from_._compiler_dispatch(self, ashint=True)} + for (from_, dialect), hinttext in select._hints.items() + if dialect in ("*", self.dialect.name) + } hint_text = self.get_select_hint_text(byfrom) return hint_text, byfrom @@ -4583,13 +4572,11 @@ class SQLCompiler(Compiled): ) def _setup_crud_hints(self, stmt, table_text): - dialect_hints = dict( - [ - (table, hint_text) - for (table, dialect), hint_text in stmt._hints.items() - if dialect in ("*", self.dialect.name) - ] - ) + dialect_hints = { + table: hint_text + for (table, dialect), hint_text in stmt._hints.items() + if dialect in ("*", self.dialect.name) + } if stmt.table in dialect_hints: table_text = self.format_from_hint_text( table_text, stmt.table, dialect_hints[stmt.table], True @@ -5318,9 +5305,7 @@ class StrSQLCompiler(SQLCompiler): if not isinstance(compiler, StrSQLCompiler): return compiler.process(element) - return super(StrSQLCompiler, self).visit_unsupported_compilation( - element, err - ) + return super().visit_unsupported_compilation(element, err) def visit_getitem_binary(self, binary, operator, **kw): return "%s[%s]" % ( @@ -6603,14 +6588,14 @@ class IdentifierPreparer: @util.memoized_property def _r_identifiers(self): - initial, final, escaped_final = [ + initial, final, escaped_final = ( re.escape(s) for s in ( self.initial_quote, self.final_quote, self._escape_identifier(self.final_quote), ) - ] + ) r = re.compile( r"(?:" r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s" diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 31d127c2c..017ff7baa 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -227,15 +227,15 @@ def _get_crud_params( parameters = {} elif stmt_parameter_tuples: assert spd is not None - parameters = dict( - (_column_as_key(key), REQUIRED) + parameters = { + _column_as_key(key): REQUIRED for key in compiler.column_keys if key not in spd - ) + } else: - parameters = dict( - (_column_as_key(key), REQUIRED) for key in compiler.column_keys - ) + parameters = { + _column_as_key(key): REQUIRED for key in compiler.column_keys + } # create a list of column assignment clauses as tuples values: List[_CrudParamElement] = [] @@ -1278,10 +1278,10 @@ def _get_update_multitable_params( values, kw, ): - normalized_params = dict( - (coercions.expect(roles.DMLColumnRole, c), param) + normalized_params = { + coercions.expect(roles.DMLColumnRole, c): param for c, param in stmt_parameter_tuples - ) + } include_table = compile_state.include_table_with_column_exprs diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index fa0c25b1d..ecdc2eb63 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -176,7 +176,7 @@ class ExecutableDDLElement(roles.DDLRole, Executable, BaseDDLElement): """ _ddl_if: Optional[DDLIf] = None - target: Optional["SchemaItem"] = None + target: Optional[SchemaItem] = None def _execute_on_connection( self, connection, distilled_params, execution_options @@ -1179,12 +1179,10 @@ class SchemaDropper(InvokeDropDDLBase): def sort_tables( - tables: Iterable["Table"], - skip_fn: Optional[Callable[["ForeignKeyConstraint"], bool]] = None, - extra_dependencies: Optional[ - typing_Sequence[Tuple["Table", "Table"]] - ] = None, -) -> List["Table"]: + tables: Iterable[Table], + skip_fn: Optional[Callable[[ForeignKeyConstraint], bool]] = None, + extra_dependencies: Optional[typing_Sequence[Tuple[Table, Table]]] = None, +) -> List[Table]: """Sort a collection of :class:`_schema.Table` objects based on dependency. diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 2d3e3598b..c279e344b 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -1179,7 +1179,7 @@ class Insert(ValuesBase): ) def __init__(self, table: _DMLTableArgument): - super(Insert, self).__init__(table) + super().__init__(table) @_generative def inline(self: SelfInsert) -> SelfInsert: @@ -1498,7 +1498,7 @@ class Update(DMLWhereBase, ValuesBase): ) def __init__(self, table: _DMLTableArgument): - super(Update, self).__init__(table) + super().__init__(table) @_generative def ordered_values( diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 044bdf585..d9a1a9358 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -3035,7 +3035,7 @@ class BooleanClauseList(ExpressionClauseList[bool]): if not self.clauses: return self else: - return super(BooleanClauseList, self).self_group(against=against) + return super().self_group(against=against) and_ = BooleanClauseList.and_ @@ -3082,7 +3082,7 @@ class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]): ] self.type = sqltypes.TupleType(*[arg.type for arg in init_clauses]) - super(Tuple, self).__init__(*init_clauses) + super().__init__(*init_clauses) @property def _select_iterable(self) -> _SelectIterable: @@ -3753,8 +3753,8 @@ class BinaryExpression(OperatorExpression[_T]): if typing.TYPE_CHECKING: def __invert__( - self: "BinaryExpression[_T]", - ) -> "BinaryExpression[_T]": + self: BinaryExpression[_T], + ) -> BinaryExpression[_T]: ... @util.ro_non_memoized_property @@ -3772,7 +3772,7 @@ class BinaryExpression(OperatorExpression[_T]): modifiers=self.modifiers, ) else: - return super(BinaryExpression, self)._negate() + return super()._negate() class Slice(ColumnElement[Any]): @@ -4617,7 +4617,7 @@ class ColumnClause( if self.table is not None: return self.table.entity_namespace else: - return super(ColumnClause, self).entity_namespace + return super().entity_namespace def _clone(self, detect_subquery_cols=False, **kw): if ( @@ -4630,7 +4630,7 @@ class ColumnClause( new = table.c.corresponding_column(self) return new - return super(ColumnClause, self)._clone(**kw) + return super()._clone(**kw) @HasMemoized_ro_memoized_attribute def _from_objects(self) -> List[FromClause]: @@ -4993,7 +4993,7 @@ class AnnotatedColumnElement(Annotated): self.__dict__.pop(attr) def _with_annotations(self, values): - clone = super(AnnotatedColumnElement, self)._with_annotations(values) + clone = super()._with_annotations(values) clone.__dict__.pop("comparator", None) return clone @@ -5032,7 +5032,7 @@ class _truncated_label(quoted_name): def __new__(cls, value: str, quote: Optional[bool] = None) -> Any: quote = getattr(value, "quote", quote) # return super(_truncated_label, cls).__new__(cls, value, quote, True) - return super(_truncated_label, cls).__new__(cls, value, quote) + return super().__new__(cls, value, quote) def __reduce__(self) -> Any: return self.__class__, (str(self), self.quote) diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index fad7c28eb..5ed89bc82 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -167,9 +167,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): @property def _proxy_key(self): - return super(FunctionElement, self)._proxy_key or getattr( - self, "name", None - ) + return super()._proxy_key or getattr(self, "name", None) def _execute_on_connection( self, @@ -660,7 +658,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): ): return Grouping(self) else: - return super(FunctionElement, self).self_group(against=against) + return super().self_group(against=against) @property def entity_namespace(self): @@ -1198,7 +1196,7 @@ class ReturnTypeFromArgs(GenericFunction[_T]): ] kwargs.setdefault("type_", _type_from_args(fn_args)) kwargs["_parsed_args"] = fn_args - super(ReturnTypeFromArgs, self).__init__(*fn_args, **kwargs) + super().__init__(*fn_args, **kwargs) class coalesce(ReturnTypeFromArgs[_T]): @@ -1304,7 +1302,7 @@ class count(GenericFunction[int]): def __init__(self, expression=None, **kwargs): if expression is None: expression = literal_column("*") - super(count, self).__init__(expression, **kwargs) + super().__init__(expression, **kwargs) class current_date(AnsiFunction[datetime.date]): @@ -1411,7 +1409,7 @@ class array_agg(GenericFunction[_T]): type_from_args, dimensions=1 ) kwargs["_parsed_args"] = fn_args - super(array_agg, self).__init__(*fn_args, **kwargs) + super().__init__(*fn_args, **kwargs) class OrderedSetAgg(GenericFunction[_T]): diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index bbfaf47e1..26e3a21bb 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -439,7 +439,7 @@ class DeferredLambdaElement(LambdaElement): lambda_args: Tuple[Any, ...] = (), ): self.lambda_args = lambda_args - super(DeferredLambdaElement, self).__init__(fn, role, opts) + super().__init__(fn, role, opts) def _invoke_user_fn(self, fn, *arg): return fn(*self.lambda_args) @@ -483,7 +483,7 @@ class DeferredLambdaElement(LambdaElement): def _copy_internals( self, clone=_clone, deferred_copy_internals=None, **kw ): - super(DeferredLambdaElement, self)._copy_internals( + super()._copy_internals( clone=clone, deferred_copy_internals=deferred_copy_internals, # **kw opts=kw, diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 55c275741..2d1f9caa1 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -66,11 +66,11 @@ class OperatorType(Protocol): def __call__( self, - left: "Operators", + left: Operators, right: Optional[Any] = None, *other: Any, **kwargs: Any, - ) -> "Operators": + ) -> Operators: ... @@ -184,7 +184,7 @@ class Operators: precedence: int = 0, is_comparison: bool = False, return_type: Optional[ - Union[Type["TypeEngine[Any]"], "TypeEngine[Any]"] + Union[Type[TypeEngine[Any]], TypeEngine[Any]] ] = None, python_impl: Optional[Callable[..., Any]] = None, ) -> Callable[[Any], Operators]: @@ -397,7 +397,7 @@ class custom_op(OperatorType, Generic[_T]): precedence: int = 0, is_comparison: bool = False, return_type: Optional[ - Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] + Union[Type[TypeEngine[_T]], TypeEngine[_T]] ] = None, natural_self_precedent: bool = False, eager_grouping: bool = False, diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index cd10d0c4a..f76fc447c 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -920,11 +920,11 @@ class Table( :attr:`_schema.Table.indexes` """ - return set( + return { fkc.constraint for fkc in self.foreign_keys if fkc.constraint is not None - ) + } def _init_existing(self, *args: Any, **kwargs: Any) -> None: autoload_with = kwargs.pop("autoload_with", None) @@ -1895,7 +1895,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): # name = None is expected to be an interim state # note this use case is legacy now that ORM declarative has a # dedicated "column" construct local to the ORM - super(Column, self).__init__(name, type_) # type: ignore + super().__init__(name, type_) # type: ignore self.key = key if key is not None else name # type: ignore self.primary_key = primary_key @@ -3573,7 +3573,7 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: column = parent assert isinstance(column, Column) - super(Sequence, self)._set_parent(column) + super()._set_parent(column) column._on_table_attach(self._set_table) def _copy(self) -> Sequence: @@ -3712,7 +3712,7 @@ class DefaultClause(FetchedValue): _reflected: bool = False, ) -> None: util.assert_arg_type(arg, (str, ClauseElement, TextClause), "arg") - super(DefaultClause, self).__init__(for_update) + super().__init__(for_update) self.arg = arg self.reflected = _reflected @@ -3914,9 +3914,9 @@ class ColumnCollectionMixin: # issue #3411 - don't do the per-column auto-attach if some of the # columns are specified as strings. - has_string_cols = set( + has_string_cols = { c for c in self._pending_colargs if c is not None - ).difference(col_objs) + }.difference(col_objs) if not has_string_cols: def _col_attached(column: Column[Any], table: Table) -> None: @@ -4434,7 +4434,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): return self.elements[0].column.table def _validate_dest_table(self, table: Table) -> None: - table_keys = set([elem._table_key() for elem in self.elements]) + table_keys = {elem._table_key() for elem in self.elements} if None not in table_keys and len(table_keys) > 1: elem0, elem1 = sorted(table_keys)[0:2] raise exc.ArgumentError( @@ -4624,7 +4624,7 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): **dialect_kw: Any, ) -> None: self._implicit_generated = _implicit_generated - super(PrimaryKeyConstraint, self).__init__( + super().__init__( *columns, name=name, deferrable=deferrable, @@ -4636,7 +4636,7 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: table = parent assert isinstance(table, Table) - super(PrimaryKeyConstraint, self)._set_parent(table) + super()._set_parent(table) if table.primary_key is not self: table.constraints.discard(table.primary_key) @@ -5219,13 +5219,9 @@ class MetaData(HasSchemaAttr): for fk in removed.foreign_keys: fk._remove_from_metadata(self) if self._schemas: - self._schemas = set( - [ - t.schema - for t in self.tables.values() - if t.schema is not None - ] - ) + self._schemas = { + t.schema for t in self.tables.values() if t.schema is not None + } def __getstate__(self) -> Dict[str, Any]: return { diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 8c64dea9d..fcffc324f 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1301,12 +1301,12 @@ class Join(roles.DMLTableRole, FromClause): # run normal _copy_internals. the clones for # left and right will come from the clone function's # cache - super(Join, self)._copy_internals(clone=clone, **kw) + super()._copy_internals(clone=clone, **kw) self._reset_memoizations() def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: - super(Join, self)._refresh_for_new_column(column) + super()._refresh_for_new_column(column) self.left._refresh_for_new_column(column) self.right._refresh_for_new_column(column) @@ -1467,7 +1467,7 @@ class Join(roles.DMLTableRole, FromClause): # "consider_as_foreign_keys". if consider_as_foreign_keys: for const in list(constraints): - if set(f.parent for f in const.elements) != set( + if {f.parent for f in const.elements} != set( consider_as_foreign_keys ): del constraints[const] @@ -1475,7 +1475,7 @@ class Join(roles.DMLTableRole, FromClause): # if still multiple constraints, but # they all refer to the exact same end result, use it. if len(constraints) > 1: - dedupe = set(tuple(crit) for crit in constraints.values()) + dedupe = {tuple(crit) for crit in constraints.values()} if len(dedupe) == 1: key = list(constraints)[0] constraints = {key: constraints[key]} @@ -1621,7 +1621,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause): self.name = name def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: - super(AliasedReturnsRows, self)._refresh_for_new_column(column) + super()._refresh_for_new_column(column) self.element._refresh_for_new_column(column) def _populate_column_collection(self): @@ -1654,7 +1654,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause): ) -> None: existing_element = self.element - super(AliasedReturnsRows, self)._copy_internals(clone=clone, **kw) + super()._copy_internals(clone=clone, **kw) # the element clone is usually against a Table that returns the # same object. don't reset exported .c. collections and other @@ -1752,7 +1752,7 @@ class TableValuedAlias(LateralFromClause, Alias): table_value_type=None, joins_implicitly=False, ): - super(TableValuedAlias, self)._init(selectable, name=name) + super()._init(selectable, name=name) self.joins_implicitly = joins_implicitly self._tableval_type = ( @@ -1959,7 +1959,7 @@ class TableSample(FromClauseAlias): self.sampling = sampling self.seed = seed - super(TableSample, self)._init(selectable, name=name) + super()._init(selectable, name=name) def _get_method(self): return self.sampling @@ -2044,7 +2044,7 @@ class CTE( self._prefixes = _prefixes if _suffixes: self._suffixes = _suffixes - super(CTE, self)._init(selectable, name=name) + super()._init(selectable, name=name) def _populate_column_collection(self): if self._cte_alias is not None: @@ -2945,7 +2945,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): return None def __init__(self, name: str, *columns: ColumnClause[Any], **kw: Any): - super(TableClause, self).__init__() + super().__init__() self.name = name self._columns = DedupeColumnCollection() self.primary_key = ColumnSet() # type: ignore @@ -3156,7 +3156,7 @@ class Values(Generative, LateralFromClause): name: Optional[str] = None, literal_binds: bool = False, ): - super(Values, self).__init__() + super().__init__() self._column_args = columns if name is None: self._unnamed = True @@ -4188,7 +4188,7 @@ class CompoundSelectState(CompileState): # TODO: this is hacky and slow hacky_subquery = self.statement.subquery() hacky_subquery.named_with_column = False - d = dict((c.key, c) for c in hacky_subquery.c) + d = {c.key: c for c in hacky_subquery.c} return d, d, d @@ -4369,7 +4369,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): ) def _refresh_for_new_column(self, column): - super(CompoundSelect, self)._refresh_for_new_column(column) + super()._refresh_for_new_column(column) for select in self.selects: select._refresh_for_new_column(column) @@ -4689,16 +4689,16 @@ class SelectState(util.MemoizedSlots, CompileState): Dict[str, ColumnElement[Any]], Dict[str, ColumnElement[Any]], ]: - with_cols: Dict[str, ColumnElement[Any]] = dict( - (c._tq_label or c.key, c) # type: ignore + with_cols: Dict[str, ColumnElement[Any]] = { + c._tq_label or c.key: c # type: ignore for c in self.statement._all_selected_columns if c._allow_label_resolve - ) - only_froms: Dict[str, ColumnElement[Any]] = dict( - (c.key, c) # type: ignore + } + only_froms: Dict[str, ColumnElement[Any]] = { + c.key: c # type: ignore for c in _select_iterables(self.froms) if c._allow_label_resolve - ) + } only_cols: Dict[str, ColumnElement[Any]] = with_cols.copy() for key, value in only_froms.items(): with_cols.setdefault(key, value) @@ -5569,7 +5569,7 @@ class Select( # 2. copy FROM collections, adding in joins that we've created. existing_from_obj = [clone(f, **kw) for f in self._from_obj] add_froms = ( - set(f for f in new_froms.values() if isinstance(f, Join)) + {f for f in new_froms.values() if isinstance(f, Join)} .difference(all_the_froms) .difference(existing_from_obj) ) @@ -5589,15 +5589,13 @@ class Select( # correlate_except, setup_joins, these clone normally. For # column-expression oriented things like raw_columns, where_criteria, # order by, we get this from the new froms. - super(Select, self)._copy_internals( - clone=clone, omit_attrs=("_from_obj",), **kw - ) + super()._copy_internals(clone=clone, omit_attrs=("_from_obj",), **kw) self._reset_memoizations() def get_children(self, **kw: Any) -> Iterable[ClauseElement]: return itertools.chain( - super(Select, self).get_children( + super().get_children( omit_attrs=("_from_obj", "_correlate", "_correlate_except"), **kw, ), diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index b98a16b6f..624b7d16e 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -134,9 +134,7 @@ class Concatenable(TypeEngineMixin): ): return operators.concat_op, self.expr.type else: - return super(Concatenable.Comparator, self)._adapt_expression( - op, other_comparator - ) + return super()._adapt_expression(op, other_comparator) comparator_factory: _ComparatorFactory[Any] = Comparator @@ -319,7 +317,7 @@ class Unicode(String): Parameters are the same as that of :class:`.String`. """ - super(Unicode, self).__init__(length=length, **kwargs) + super().__init__(length=length, **kwargs) class UnicodeText(Text): @@ -344,7 +342,7 @@ class UnicodeText(Text): Parameters are the same as that of :class:`_expression.TextClause`. """ - super(UnicodeText, self).__init__(length=length, **kwargs) + super().__init__(length=length, **kwargs) class Integer(HasExpressionLookup, TypeEngine[int]): @@ -930,7 +928,7 @@ class _Binary(TypeEngine[bytes]): if isinstance(value, str): return self else: - return super(_Binary, self).coerce_compared_value(op, value) + return super().coerce_compared_value(op, value) def get_dbapi_type(self, dbapi): return dbapi.BINARY @@ -1450,7 +1448,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): self._valid_lookup[None] = self._object_lookup[None] = None - super(Enum, self).__init__(length=length) + super().__init__(length=length) if self.enum_class: kw.setdefault("name", self.enum_class.__name__.lower()) @@ -1551,9 +1549,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): op: OperatorType, other_comparator: TypeEngine.Comparator[Any], ) -> Tuple[OperatorType, TypeEngine[Any]]: - op, typ = super(Enum.Comparator, self)._adapt_expression( - op, other_comparator - ) + op, typ = super()._adapt_expression(op, other_comparator) if op is operators.concat_op: typ = String(self.type.length) return op, typ @@ -1618,7 +1614,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): def adapt(self, impltype, **kw): kw["_enums"] = self._enums_argument kw["_disable_warnings"] = True - return super(Enum, self).adapt(impltype, **kw) + return super().adapt(impltype, **kw) def _should_create_constraint(self, compiler, **kw): if not self._is_impl_for_variant(compiler.dialect, kw): @@ -1649,7 +1645,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): assert e.table is table def literal_processor(self, dialect): - parent_processor = super(Enum, self).literal_processor(dialect) + parent_processor = super().literal_processor(dialect) def process(value): value = self._db_value_for_elem(value) @@ -1660,7 +1656,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): return process def bind_processor(self, dialect): - parent_processor = super(Enum, self).bind_processor(dialect) + parent_processor = super().bind_processor(dialect) def process(value): value = self._db_value_for_elem(value) @@ -1671,7 +1667,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): return process def result_processor(self, dialect, coltype): - parent_processor = super(Enum, self).result_processor(dialect, coltype) + parent_processor = super().result_processor(dialect, coltype) def process(value): if parent_processor: @@ -1690,7 +1686,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): if self.enum_class: return self.enum_class else: - return super(Enum, self).python_type + return super().python_type class PickleType(TypeDecorator[object]): @@ -1739,7 +1735,7 @@ class PickleType(TypeDecorator[object]): self.protocol = protocol self.pickler = pickler or pickle self.comparator = comparator - super(PickleType, self).__init__() + super().__init__() if impl: # custom impl is not necessarily a LargeBinary subclass. @@ -2000,7 +1996,7 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]): support a "day precision" parameter, i.e. Oracle. """ - super(Interval, self).__init__() + super().__init__() self.native = native self.second_precision = second_precision self.day_precision = day_precision @@ -3005,7 +3001,7 @@ class ARRAY( def _set_parent_with_dispatch(self, parent): """Support SchemaEventTarget""" - super(ARRAY, self)._set_parent_with_dispatch(parent, outer=True) + super()._set_parent_with_dispatch(parent, outer=True) if isinstance(self.item_type, SchemaEventTarget): self.item_type._set_parent_with_dispatch(parent) @@ -3249,7 +3245,7 @@ class TIMESTAMP(DateTime): """ - super(TIMESTAMP, self).__init__(timezone=timezone) + super().__init__(timezone=timezone) def get_dbapi_type(self, dbapi): return dbapi.TIMESTAMP @@ -3464,7 +3460,7 @@ class Uuid(TypeEngine[_UUID_RETURN]): @overload def __init__( - self: "Uuid[_python_UUID]", + self: Uuid[_python_UUID], as_uuid: Literal[True] = ..., native_uuid: bool = ..., ): @@ -3472,7 +3468,7 @@ class Uuid(TypeEngine[_UUID_RETURN]): @overload def __init__( - self: "Uuid[str]", + self: Uuid[str], as_uuid: Literal[False] = ..., native_uuid: bool = ..., ): @@ -3628,11 +3624,11 @@ class UUID(Uuid[_UUID_RETURN]): __visit_name__ = "UUID" @overload - def __init__(self: "UUID[_python_UUID]", as_uuid: Literal[True] = ...): + def __init__(self: UUID[_python_UUID], as_uuid: Literal[True] = ...): ... @overload - def __init__(self: "UUID[str]", as_uuid: Literal[False] = ...): + def __init__(self: UUID[str], as_uuid: Literal[False] = ...): ... def __init__(self, as_uuid: bool = True): diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 135407321..866c0ccde 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -301,9 +301,7 @@ class _CopyInternalsTraversal(HasTraversalDispatch): def visit_string_clauseelement_dict( self, attrname, parent, element, clone=_clone, **kw ): - return dict( - (key, clone(value, **kw)) for key, value in element.items() - ) + return {key: clone(value, **kw) for key, value in element.items()} def visit_setup_join_tuple( self, attrname, parent, element, clone=_clone, **kw diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index cd57ee3b6..c3768c6c6 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -1399,7 +1399,7 @@ class Emulated(TypeEngineMixin): def _is_native_for_emulated( typ: Type[Union[TypeEngine[Any], TypeEngineMixin]], -) -> TypeGuard["Type[NativeForEmulated]"]: +) -> TypeGuard[Type[NativeForEmulated]]: return hasattr(typ, "adapt_emulated_to_native") @@ -1673,9 +1673,7 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]): if TYPE_CHECKING: assert isinstance(self.expr.type, TypeDecorator) kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types - return super(TypeDecorator.Comparator, self).operate( - op, *other, **kwargs - ) + return super().operate(op, *other, **kwargs) def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any @@ -1683,9 +1681,7 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]): if TYPE_CHECKING: assert isinstance(self.expr.type, TypeDecorator) kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types - return super(TypeDecorator.Comparator, self).reverse_operate( - op, other, **kwargs - ) + return super().reverse_operate(op, other, **kwargs) @property def comparator_factory( # type: ignore # mypy properties bug diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index ec8ea757f..14cbe2456 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -316,8 +316,7 @@ def visit_binary_product( if isinstance(element, ColumnClause): yield element for elem in element.get_children(): - for e in visit(elem): - yield e + yield from visit(elem) list(visit(expr)) visit = None # type: ignore # remove gc cycles @@ -433,12 +432,10 @@ def expand_column_list_from_order_by(collist, order_by): in the collist. """ - cols_already_present = set( - [ - col.element if col._order_by_label_element is not None else col - for col in collist - ] - ) + cols_already_present = { + col.element if col._order_by_label_element is not None else col + for col in collist + } to_look_for = list(chain(*[unwrap_order_by(o) for o in order_by])) @@ -463,13 +460,10 @@ def clause_is_present(clause, search): def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]: if isinstance(clause, Join): - for t in tables_from_leftmost(clause.left): - yield t - for t in tables_from_leftmost(clause.right): - yield t + yield from tables_from_leftmost(clause.left) + yield from tables_from_leftmost(clause.right) elif isinstance(clause, FromGrouping): - for t in tables_from_leftmost(clause.element): - yield t + yield from tables_from_leftmost(clause.element) else: yield clause @@ -592,7 +586,7 @@ class _repr_row(_repr_base): __slots__ = ("row",) - def __init__(self, row: "Row[Any]", max_chars: int = 300): + def __init__(self, row: Row[Any], max_chars: int = 300): self.row = row self.max_chars = max_chars @@ -775,7 +769,7 @@ class _repr_params(_repr_base): ) return text - def _repr_param_tuple(self, params: "Sequence[Any]") -> str: + def _repr_param_tuple(self, params: Sequence[Any]) -> str: trunc = self.trunc ( |