summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/engine/default.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
-rw-r--r--lib/sqlalchemy/engine/default.py416
1 files changed, 237 insertions, 179 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 028abc4c2..d7c2518fe 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -24,13 +24,11 @@ import weakref
from .. import event
AUTOCOMMIT_REGEXP = re.compile(
- r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
- re.I | re.UNICODE)
+ r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)", re.I | re.UNICODE
+)
# When we're handed literal SQL, ensure it's a SELECT query
-SERVER_SIDE_CURSOR_RE = re.compile(
- r'\s*SELECT',
- re.I | re.UNICODE)
+SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE)
class DefaultDialect(interfaces.Dialect):
@@ -68,16 +66,18 @@ class DefaultDialect(interfaces.Dialect):
supports_simple_order_by_label = True
- engine_config_types = util.immutabledict([
- ('convert_unicode', util.bool_or_str('force')),
- ('pool_timeout', util.asint),
- ('echo', util.bool_or_str('debug')),
- ('echo_pool', util.bool_or_str('debug')),
- ('pool_recycle', util.asint),
- ('pool_size', util.asint),
- ('max_overflow', util.asint),
- ('pool_threadlocal', util.asbool),
- ])
+ engine_config_types = util.immutabledict(
+ [
+ ("convert_unicode", util.bool_or_str("force")),
+ ("pool_timeout", util.asint),
+ ("echo", util.bool_or_str("debug")),
+ ("echo_pool", util.bool_or_str("debug")),
+ ("pool_recycle", util.asint),
+ ("pool_size", util.asint),
+ ("max_overflow", util.asint),
+ ("pool_threadlocal", util.asbool),
+ ]
+ )
# if the NUMERIC type
# returns decimal.Decimal.
@@ -93,9 +93,9 @@ class DefaultDialect(interfaces.Dialect):
supports_unicode_statements = False
supports_unicode_binds = False
returns_unicode_strings = False
- description_encoding = 'use_encoding'
+ description_encoding = "use_encoding"
- name = 'default'
+ name = "default"
# length at which to truncate
# any identifier.
@@ -111,7 +111,7 @@ class DefaultDialect(interfaces.Dialect):
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
colspecs = {}
- default_paramstyle = 'named'
+ default_paramstyle = "named"
supports_default_values = False
supports_empty_insert = True
supports_multivalues_insert = False
@@ -175,19 +175,26 @@ class DefaultDialect(interfaces.Dialect):
"""
- def __init__(self, convert_unicode=False,
- encoding='utf-8', paramstyle=None, dbapi=None,
- implicit_returning=None,
- supports_right_nested_joins=None,
- case_sensitive=True,
- supports_native_boolean=None,
- empty_in_strategy='static',
- label_length=None, **kwargs):
-
- if not getattr(self, 'ported_sqla_06', True):
+ def __init__(
+ self,
+ convert_unicode=False,
+ encoding="utf-8",
+ paramstyle=None,
+ dbapi=None,
+ implicit_returning=None,
+ supports_right_nested_joins=None,
+ case_sensitive=True,
+ supports_native_boolean=None,
+ empty_in_strategy="static",
+ label_length=None,
+ **kwargs
+ ):
+
+ if not getattr(self, "ported_sqla_06", True):
util.warn(
- "The %s dialect is not yet ported to the 0.6 format" %
- self.name)
+ "The %s dialect is not yet ported to the 0.6 format"
+ % self.name
+ )
self.convert_unicode = convert_unicode
self.encoding = encoding
@@ -202,7 +209,7 @@ class DefaultDialect(interfaces.Dialect):
self.paramstyle = self.default_paramstyle
if implicit_returning is not None:
self.implicit_returning = implicit_returning
- self.positional = self.paramstyle in ('qmark', 'format', 'numeric')
+ self.positional = self.paramstyle in ("qmark", "format", "numeric")
self.identifier_preparer = self.preparer(self)
self.type_compiler = self.type_compiler(self)
if supports_right_nested_joins is not None:
@@ -212,33 +219,33 @@ class DefaultDialect(interfaces.Dialect):
self.case_sensitive = case_sensitive
self.empty_in_strategy = empty_in_strategy
- if empty_in_strategy == 'static':
+ if empty_in_strategy == "static":
self._use_static_in = True
- elif empty_in_strategy in ('dynamic', 'dynamic_warn'):
+ elif empty_in_strategy in ("dynamic", "dynamic_warn"):
self._use_static_in = False
- self._warn_on_empty_in = empty_in_strategy == 'dynamic_warn'
+ self._warn_on_empty_in = empty_in_strategy == "dynamic_warn"
else:
raise exc.ArgumentError(
"empty_in_strategy may be 'static', "
- "'dynamic', or 'dynamic_warn'")
+ "'dynamic', or 'dynamic_warn'"
+ )
if label_length and label_length > self.max_identifier_length:
raise exc.ArgumentError(
"Label length of %d is greater than this dialect's"
- " maximum identifier length of %d" %
- (label_length, self.max_identifier_length))
+ " maximum identifier length of %d"
+ % (label_length, self.max_identifier_length)
+ )
self.label_length = label_length
- if self.description_encoding == 'use_encoding':
- self._description_decoder = \
- processors.to_unicode_processor_factory(
- encoding
- )
+ if self.description_encoding == "use_encoding":
+ self._description_decoder = processors.to_unicode_processor_factory(
+ encoding
+ )
elif self.description_encoding is not None:
- self._description_decoder = \
- processors.to_unicode_processor_factory(
- self.description_encoding
- )
+ self._description_decoder = processors.to_unicode_processor_factory(
+ self.description_encoding
+ )
self._encoder = codecs.getencoder(self.encoding)
self._decoder = processors.to_unicode_processor_factory(self.encoding)
@@ -256,30 +263,35 @@ class DefaultDialect(interfaces.Dialect):
@classmethod
def get_pool_class(cls, url):
- return getattr(cls, 'poolclass', pool.QueuePool)
+ return getattr(cls, "poolclass", pool.QueuePool)
def initialize(self, connection):
try:
- self.server_version_info = \
- self._get_server_version_info(connection)
+ self.server_version_info = self._get_server_version_info(
+ connection
+ )
except NotImplementedError:
self.server_version_info = None
try:
- self.default_schema_name = \
- self._get_default_schema_name(connection)
+ self.default_schema_name = self._get_default_schema_name(
+ connection
+ )
except NotImplementedError:
self.default_schema_name = None
try:
- self.default_isolation_level = \
- self.get_isolation_level(connection.connection)
+ self.default_isolation_level = self.get_isolation_level(
+ connection.connection
+ )
except NotImplementedError:
self.default_isolation_level = None
self.returns_unicode_strings = self._check_unicode_returns(connection)
- if self.description_encoding is not None and \
- self._check_unicode_description(connection):
+ if (
+ self.description_encoding is not None
+ and self._check_unicode_description(connection)
+ ):
self._description_decoder = self.description_encoding = None
self.do_rollback(connection.connection)
@@ -311,7 +323,8 @@ class DefaultDialect(interfaces.Dialect):
def check_unicode(test):
statement = cast_to(
- expression.select([test]).compile(dialect=self))
+ expression.select([test]).compile(dialect=self)
+ )
try:
cursor = connection.connection.cursor()
connection._cursor_execute(cursor, statement, parameters)
@@ -320,8 +333,10 @@ class DefaultDialect(interfaces.Dialect):
except exc.DBAPIError as de:
# note that _cursor_execute() will have closed the cursor
# if an exception is thrown.
- util.warn("Exception attempting to "
- "detect unicode returns: %r" % de)
+ util.warn(
+ "Exception attempting to "
+ "detect unicode returns: %r" % de
+ )
return False
else:
return isinstance(row[0], util.text_type)
@@ -330,13 +345,13 @@ class DefaultDialect(interfaces.Dialect):
# detect plain VARCHAR
expression.cast(
expression.literal_column("'test plain returns'"),
- sqltypes.VARCHAR(60)
+ sqltypes.VARCHAR(60),
),
# detect if there's an NVARCHAR type with different behavior
# available
expression.cast(
expression.literal_column("'test unicode returns'"),
- sqltypes.Unicode(60)
+ sqltypes.Unicode(60),
),
]
@@ -364,9 +379,9 @@ class DefaultDialect(interfaces.Dialect):
try:
cursor.execute(
cast_to(
- expression.select([
- expression.literal_column("'x'").label("some_label")
- ]).compile(dialect=self)
+ expression.select(
+ [expression.literal_column("'x'").label("some_label")]
+ ).compile(dialect=self)
)
)
return isinstance(cursor.description[0][0], util.text_type)
@@ -385,10 +400,12 @@ class DefaultDialect(interfaces.Dialect):
return sqltypes.adapt_type(typeobj, self.colspecs)
def reflecttable(
- self, connection, table, include_columns, exclude_columns, **opts):
+ self, connection, table, include_columns, exclude_columns, **opts
+ ):
insp = reflection.Inspector.from_engine(connection)
return insp.reflecttable(
- table, include_columns, exclude_columns, **opts)
+ table, include_columns, exclude_columns, **opts
+ )
def get_pk_constraint(self, conn, table_name, schema=None, **kw):
"""Compatibility method, adapts the result of get_primary_keys()
@@ -396,16 +413,16 @@ class DefaultDialect(interfaces.Dialect):
"""
return {
- 'constrained_columns':
- self.get_primary_keys(conn, table_name,
- schema=schema, **kw)
+ "constrained_columns": self.get_primary_keys(
+ conn, table_name, schema=schema, **kw
+ )
}
def validate_identifier(self, ident):
if len(ident) > self.max_identifier_length:
raise exc.IdentifierError(
- "Identifier '%s' exceeds maximum length of %d characters" %
- (ident, self.max_identifier_length)
+ "Identifier '%s' exceeds maximum length of %d characters"
+ % (ident, self.max_identifier_length)
)
def connect(self, *cargs, **cparams):
@@ -417,16 +434,16 @@ class DefaultDialect(interfaces.Dialect):
return [[], opts]
def set_engine_execution_options(self, engine, opts):
- if 'isolation_level' in opts:
- isolation_level = opts['isolation_level']
+ if "isolation_level" in opts:
+ isolation_level = opts["isolation_level"]
@event.listens_for(engine, "engine_connect")
def set_isolation(connection, branch):
if not branch:
self._set_connection_isolation(connection, isolation_level)
- if 'schema_translate_map' in opts:
- getter = schema._schema_getter(opts['schema_translate_map'])
+ if "schema_translate_map" in opts:
+ getter = schema._schema_getter(opts["schema_translate_map"])
engine.schema_for_object = getter
@event.listens_for(engine, "engine_connect")
@@ -434,11 +451,11 @@ class DefaultDialect(interfaces.Dialect):
connection.schema_for_object = getter
def set_connection_execution_options(self, connection, opts):
- if 'isolation_level' in opts:
- self._set_connection_isolation(connection, opts['isolation_level'])
+ if "isolation_level" in opts:
+ self._set_connection_isolation(connection, opts["isolation_level"])
- if 'schema_translate_map' in opts:
- getter = schema._schema_getter(opts['schema_translate_map'])
+ if "schema_translate_map" in opts:
+ getter = schema._schema_getter(opts["schema_translate_map"])
connection.schema_for_object = getter
def _set_connection_isolation(self, connection, level):
@@ -447,10 +464,12 @@ class DefaultDialect(interfaces.Dialect):
"Connection is already established with a Transaction; "
"setting isolation_level may implicitly rollback or commit "
"the existing transaction, or have no effect until "
- "next transaction")
+ "next transaction"
+ )
self.set_isolation_level(connection.connection, level)
- connection.connection._connection_record.\
- finalize_callback.append(self.reset_isolation_level)
+ connection.connection._connection_record.finalize_callback.append(
+ self.reset_isolation_level
+ )
def do_begin(self, dbapi_connection):
pass
@@ -593,8 +612,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
return self
@classmethod
- def _init_compiled(cls, dialect, connection, dbapi_connection,
- compiled, parameters):
+ def _init_compiled(
+ cls, dialect, connection, dbapi_connection, compiled, parameters
+ ):
"""Initialize execution context for a Compiled construct."""
self = cls.__new__(cls)
@@ -609,16 +629,20 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
assert compiled.can_execute
self.execution_options = compiled.execution_options.union(
- connection._execution_options)
+ connection._execution_options
+ )
self.result_column_struct = (
- compiled._result_columns, compiled._ordered_columns,
- compiled._textual_ordered_columns)
+ compiled._result_columns,
+ compiled._ordered_columns,
+ compiled._textual_ordered_columns,
+ )
self.unicode_statement = util.text_type(compiled)
if not dialect.supports_unicode_statements:
self.statement = self.unicode_statement.encode(
- self.dialect.encoding)
+ self.dialect.encoding
+ )
else:
self.statement = self.unicode_statement
@@ -630,9 +654,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if not parameters:
self.compiled_parameters = [compiled.construct_params()]
else:
- self.compiled_parameters = \
- [compiled.construct_params(m, _group_number=grp) for
- grp, m in enumerate(parameters)]
+ self.compiled_parameters = [
+ compiled.construct_params(m, _group_number=grp)
+ for grp, m in enumerate(parameters)
+ ]
self.executemany = len(parameters) > 1
@@ -642,7 +667,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self.is_crud = True
self._is_explicit_returning = bool(compiled.statement._returning)
self._is_implicit_returning = bool(
- compiled.returning and not compiled.statement._returning)
+ compiled.returning and not compiled.statement._returning
+ )
if self.compiled.insert_prefetch or self.compiled.update_prefetch:
if self.executemany:
@@ -680,7 +706,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
dialect._encoder(key)[0],
processors[key](compiled_params[key])
if key in processors
- else compiled_params[key]
+ else compiled_params[key],
)
for key in compiled_params
)
@@ -690,7 +716,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
key,
processors[key](compiled_params[key])
if key in processors
- else compiled_params[key]
+ else compiled_params[key],
)
for key in compiled_params
)
@@ -708,14 +734,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
"""
if self.executemany:
raise exc.InvalidRequestError(
- "'expanding' parameters can't be used with "
- "executemany()")
+ "'expanding' parameters can't be used with " "executemany()"
+ )
if self.compiled.positional and self.compiled._numeric_binds:
# I'm not familiar with any DBAPI that uses 'numeric'
raise NotImplementedError(
"'expanding' bind parameters not supported with "
- "'numeric' paramstyle at this time.")
+ "'numeric' paramstyle at this time."
+ )
self._expanded_parameters = {}
@@ -729,7 +756,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
to_update_sets = {}
for name in (
- self.compiled.positiontup if compiled.positional
+ self.compiled.positiontup
+ if compiled.positional
else self.compiled.binds
):
parameter = self.compiled.binds[name]
@@ -748,12 +776,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if not values:
to_update = to_update_sets[name] = []
- replacement_expressions[name] = (
- self.compiled.visit_empty_set_expr(
- parameter._expanding_in_types
- if parameter._expanding_in_types
- else [parameter.type]
- )
+ replacement_expressions[
+ name
+ ] = self.compiled.visit_empty_set_expr(
+ parameter._expanding_in_types
+ if parameter._expanding_in_types
+ else [parameter.type]
)
elif isinstance(values[0], (tuple, list)):
@@ -763,15 +791,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
for j, value in enumerate(tuple_element, 1)
]
replacement_expressions[name] = ", ".join(
- "(%s)" % ", ".join(
- self.compiled.bindtemplate % {
- "name":
- to_update[i * len(tuple_element) + j][0]
+ "(%s)"
+ % ", ".join(
+ self.compiled.bindtemplate
+ % {
+ "name": to_update[
+ i * len(tuple_element) + j
+ ][0]
}
for j, value in enumerate(tuple_element)
)
for i, tuple_element in enumerate(values)
-
)
else:
to_update = to_update_sets[name] = [
@@ -779,20 +809,21 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
for i, value in enumerate(values, 1)
]
replacement_expressions[name] = ", ".join(
- self.compiled.bindtemplate % {
- "name": key}
+ self.compiled.bindtemplate % {"name": key}
for key, value in to_update
)
compiled_params.update(to_update)
processors.update(
(key, processors[name])
- for key, value in to_update if name in processors
+ for key, value in to_update
+ if name in processors
)
if compiled.positional:
positiontup.extend(name for name, value in to_update)
self._expanded_parameters[name] = [
- expand_key for expand_key, value in to_update]
+ expand_key for expand_key, value in to_update
+ ]
elif compiled.positional:
positiontup.append(name)
@@ -800,15 +831,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
return replacement_expressions[m.group(1)]
self.statement = re.sub(
- r"\[EXPANDING_(\S+)\]",
- process_expanding,
- self.statement
+ r"\[EXPANDING_(\S+)\]", process_expanding, self.statement
)
return positiontup
@classmethod
- def _init_statement(cls, dialect, connection, dbapi_connection,
- statement, parameters):
+ def _init_statement(
+ cls, dialect, connection, dbapi_connection, statement, parameters
+ ):
"""Initialize execution context for a string SQL statement."""
self = cls.__new__(cls)
@@ -836,13 +866,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
for d in parameters
] or [{}]
else:
- self.parameters = [dialect.execute_sequence_format(p)
- for p in parameters]
+ self.parameters = [
+ dialect.execute_sequence_format(p) for p in parameters
+ ]
self.executemany = len(parameters) > 1
- if not dialect.supports_unicode_statements and \
- isinstance(statement, util.text_type):
+ if not dialect.supports_unicode_statements and isinstance(
+ statement, util.text_type
+ ):
self.unicode_statement = statement
self.statement = dialect._encoder(statement)[0]
else:
@@ -890,11 +922,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
@util.memoized_property
def should_autocommit(self):
- autocommit = self.execution_options.get('autocommit',
- not self.compiled and
- self.statement and
- expression.PARSE_AUTOCOMMIT
- or False)
+ autocommit = self.execution_options.get(
+ "autocommit",
+ not self.compiled
+ and self.statement
+ and expression.PARSE_AUTOCOMMIT
+ or False,
+ )
if autocommit is expression.PARSE_AUTOCOMMIT:
return self.should_autocommit_text(self.unicode_statement)
@@ -912,8 +946,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
"""
conn = self.root_connection
- if isinstance(stmt, util.text_type) and \
- not self.dialect.supports_unicode_statements:
+ if (
+ isinstance(stmt, util.text_type)
+ and not self.dialect.supports_unicode_statements
+ ):
stmt = self.dialect._encoder(stmt)[0]
if self.dialect.positional:
@@ -926,8 +962,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if type_ is not None:
# apply type post processors to the result
proc = type_._cached_result_processor(
- self.dialect,
- self.cursor.description[0][1]
+ self.dialect, self.cursor.description[0][1]
)
if proc:
return proc(r)
@@ -945,22 +980,30 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
return False
if self.dialect.server_side_cursors:
- use_server_side = \
- self.execution_options.get('stream_results', True) and (
- (self.compiled and isinstance(self.compiled.statement,
- expression.Selectable)
- or
- (
- (not self.compiled or
- isinstance(self.compiled.statement,
- expression.TextClause))
- and self.statement and SERVER_SIDE_CURSOR_RE.match(
- self.statement))
- )
+ use_server_side = self.execution_options.get(
+ "stream_results", True
+ ) and (
+ (
+ self.compiled
+ and isinstance(
+ self.compiled.statement, expression.Selectable
+ )
+ or (
+ (
+ not self.compiled
+ or isinstance(
+ self.compiled.statement, expression.TextClause
+ )
+ )
+ and self.statement
+ and SERVER_SIDE_CURSOR_RE.match(self.statement)
+ )
)
+ )
else:
- use_server_side = \
- self.execution_options.get('stream_results', False)
+ use_server_side = self.execution_options.get(
+ "stream_results", False
+ )
return use_server_side
@@ -1039,11 +1082,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
return self.dialect.supports_sane_multi_rowcount
def _setup_crud_result_proxy(self):
- if self.isinsert and \
- not self.executemany:
- if not self._is_implicit_returning and \
- not self.compiled.inline and \
- self.dialect.postfetch_lastrowid:
+ if self.isinsert and not self.executemany:
+ if (
+ not self._is_implicit_returning
+ and not self.compiled.inline
+ and self.dialect.postfetch_lastrowid
+ ):
self._setup_ins_pk_from_lastrowid()
@@ -1087,12 +1131,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if autoinc_col is not None:
# apply type post processors to the lastrowid
proc = autoinc_col.type._cached_result_processor(
- self.dialect, None)
+ self.dialect, None
+ )
if proc is not None:
lastrowid = proc(lastrowid)
self.inserted_primary_key = [
- lastrowid if c is autoinc_col else
- compiled_params.get(key_getter(c), None)
+ lastrowid
+ if c is autoinc_col
+ else compiled_params.get(key_getter(c), None)
for c in table.primary_key
]
else:
@@ -1108,8 +1154,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
table = self.compiled.statement.table
compiled_params = self.compiled_parameters[0]
self.inserted_primary_key = [
- compiled_params.get(key_getter(c), None)
- for c in table.primary_key
+ compiled_params.get(key_getter(c), None) for c in table.primary_key
]
def _setup_ins_pk_from_implicit_returning(self, row):
@@ -1129,11 +1174,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
]
def lastrow_has_defaults(self):
- return (self.isinsert or self.isupdate) and \
- bool(self.compiled.postfetch)
+ return (self.isinsert or self.isupdate) and bool(
+ self.compiled.postfetch
+ )
def set_input_sizes(
- self, translate=None, include_types=None, exclude_types=None):
+ self, translate=None, include_types=None, exclude_types=None
+ ):
"""Given a cursor and ClauseParameters, call the appropriate
style of ``setinputsizes()`` on the cursor, using DB-API types
from the bind parameter's ``TypeEngine`` objects.
@@ -1143,7 +1190,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
"""
- if not hasattr(self.compiled, 'bind_names'):
+ if not hasattr(self.compiled, "bind_names"):
return
inputsizes = {}
@@ -1153,12 +1200,18 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
dialect_impl_cls = type(dialect_impl)
dbtype = dialect_impl.get_dbapi_type(self.dialect.dbapi)
- if dbtype is not None and (
- not exclude_types or dbtype not in exclude_types and
- dialect_impl_cls not in exclude_types
- ) and (
- not include_types or dbtype in include_types or
- dialect_impl_cls in include_types
+ if (
+ dbtype is not None
+ and (
+ not exclude_types
+ or dbtype not in exclude_types
+ and dialect_impl_cls not in exclude_types
+ )
+ and (
+ not include_types
+ or dbtype in include_types
+ or dialect_impl_cls in include_types
+ )
):
inputsizes[bindparam] = dbtype
else:
@@ -1177,14 +1230,16 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if dbtype is not None:
if key in self._expanded_parameters:
positional_inputsizes.extend(
- [dbtype] * len(self._expanded_parameters[key]))
+ [dbtype] * len(self._expanded_parameters[key])
+ )
else:
positional_inputsizes.append(dbtype)
try:
self.cursor.setinputsizes(*positional_inputsizes)
except BaseException as e:
self.root_connection._handle_dbapi_exception(
- e, None, None, None, self)
+ e, None, None, None, self
+ )
else:
keyword_inputsizes = {}
for bindparam, key in self.compiled.bind_names.items():
@@ -1199,8 +1254,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
key = self.dialect._encoder(key)[0]
if key in self._expanded_parameters:
keyword_inputsizes.update(
- (expand_key, dbtype) for expand_key
- in self._expanded_parameters[key]
+ (expand_key, dbtype)
+ for expand_key in self._expanded_parameters[key]
)
else:
keyword_inputsizes[key] = dbtype
@@ -1208,7 +1263,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self.cursor.setinputsizes(**keyword_inputsizes)
except BaseException as e:
self.root_connection._handle_dbapi_exception(
- e, None, None, None, self)
+ e, None, None, None, self
+ )
def _exec_default(self, column, default, type_):
if default.is_sequence:
@@ -1290,10 +1346,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
except AttributeError:
raise exc.InvalidRequestError(
"get_current_parameters() can only be invoked in the "
- "context of a Python side column default function")
- if isolate_multiinsert_groups and \
- self.isinsert and \
- self.compiled.statement._has_multi_parameters:
+ "context of a Python side column default function"
+ )
+ if (
+ isolate_multiinsert_groups
+ and self.isinsert
+ and self.compiled.statement._has_multi_parameters
+ ):
if column._is_multiparam_column:
index = column.index + 1
d = {column.original.key: parameters[column.key]}
@@ -1302,8 +1361,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
index = 0
keys = self.compiled.statement.parameters[0].keys()
d.update(
- (key, parameters["%s_m%d" % (key, index)])
- for key in keys
+ (key, parameters["%s_m%d" % (key, index)]) for key in keys
)
return d
else:
@@ -1360,12 +1418,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
def _process_executesingle_defaults(self):
key_getter = self.compiled._key_getters_for_crud_column[2]
- self.current_parameters = compiled_parameters = \
- self.compiled_parameters[0]
+ self.current_parameters = (
+ compiled_parameters
+ ) = self.compiled_parameters[0]
for c in self.compiled.insert_prefetch:
- if c.default and \
- not c.default.is_sequence and c.default.is_scalar:
+ if c.default and not c.default.is_sequence and c.default.is_scalar:
val = c.default.arg
else:
val = self.get_insert_default(c)