diff options
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 344 |
1 files changed, 190 insertions, 154 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 6c708aa52..ebc001821 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -317,106 +317,190 @@ class DefaultExecutionContext(base.ExecutionContext): compiled = None statement = None - def __init__(self, - dialect, - connection, - compiled_sql=None, - compiled_ddl=None, - statement=None, - parameters=None): + @classmethod + def _init_ddl(cls, dialect, connection, compiled_ddl): + """Initialize execution context for a DDLElement construct.""" + self = cls.__new__(cls) self.dialect = dialect self._connection = self.root_connection = connection self.engine = connection.engine + + self.compiled = compiled = compiled_ddl + self.isddl = True + + if compiled.statement._execution_options: + self.execution_options = compiled.statement._execution_options + if connection._execution_options: + self.execution_options = self.execution_options.union( + connection._execution_options + ) + + if not dialect.supports_unicode_statements: + self.unicode_statement = unicode(compiled) + self.statement = self.unicode_statement.encode(self.dialect.encoding) + else: + self.statement = self.unicode_statement = unicode(compiled) + + self.cursor = self.create_cursor() + self.compiled_parameters = [] + + if dialect.positional: + self.parameters = [dialect.execute_sequence_format()] + else: + self.parameters = [{}] + + return self - if compiled_ddl is not None: - self.compiled = compiled = compiled_ddl - self.isddl = True - - if compiled.statement._execution_options: - self.execution_options = compiled.statement._execution_options - if connection._execution_options: - self.execution_options = self.execution_options.union( + @classmethod + def _init_compiled(cls, dialect, connection, compiled, parameters): + """Initialize execution context for a Compiled construct.""" + + self = cls.__new__(cls) + self.dialect = dialect + self._connection = self.root_connection = connection + self.engine = connection.engine + + self.compiled = compiled + + if not compiled.can_execute: + raise exc.ArgumentError("Not an executable clause: %s" % compiled) + + if compiled.statement._execution_options: + self.execution_options = compiled.statement._execution_options + if connection._execution_options: + self.execution_options = self.execution_options.union( connection._execution_options ) - if not dialect.supports_unicode_statements: - self.unicode_statement = unicode(compiled) - self.statement = self.unicode_statement.encode(self.dialect.encoding) - else: - self.statement = self.unicode_statement = unicode(compiled) - - self.cursor = self.create_cursor() - self.compiled_parameters = [] - self.parameters = [self._default_params] - - elif compiled_sql is not None: - self.compiled = compiled = compiled_sql - - if not compiled.can_execute: - raise exc.ArgumentError("Not an executable clause: %s" % compiled) - - if compiled.statement._execution_options: - self.execution_options = compiled.statement._execution_options - if connection._execution_options: - self.execution_options = self.execution_options.union( - connection._execution_options - ) - - # compiled clauseelement. process bind params, process table defaults, - # track collections used by ResultProxy to target and process results - - self.processors = dict( - (key, value) for key, value in - ( (compiled.bind_names[bindparam], - bindparam.bind_processor(self.dialect)) - for bindparam in compiled.bind_names ) - if value is not None) - - self.result_map = compiled.result_map - - if not dialect.supports_unicode_statements: - self.unicode_statement = unicode(compiled) - self.statement = self.unicode_statement.encode(self.dialect.encoding) - else: - self.statement = self.unicode_statement = unicode(compiled) + # compiled clauseelement. process bind params, process table defaults, + # track collections used by ResultProxy to target and process results - self.isinsert = compiled.isinsert - self.isupdate = compiled.isupdate - self.isdelete = compiled.isdelete + self.result_map = compiled.result_map - if not parameters: - self.compiled_parameters = [compiled.construct_params()] - else: - self.compiled_parameters = [compiled.construct_params(m, _group_number=grp) for - grp,m in enumerate(parameters)] - - self.executemany = len(parameters) > 1 - - self.cursor = self.create_cursor() - if self.isinsert or self.isupdate: - self.__process_defaults() - self.parameters = self.__convert_compiled_params(self.compiled_parameters) - - elif statement is not None: - # plain text statement - if connection._execution_options: - self.execution_options = self.execution_options.union(connection._execution_options) - self.parameters = self.__encode_param_keys(parameters) + self.unicode_statement = unicode(compiled) + if not dialect.supports_unicode_statements: + self.statement = self.unicode_statement.encode(self.dialect.encoding) + else: + self.statement = self.unicode_statement + + self.isinsert = compiled.isinsert + self.isupdate = compiled.isupdate + self.isdelete = compiled.isdelete + + if not parameters: + self.compiled_parameters = [compiled.construct_params()] + else: + self.compiled_parameters = \ + [compiled.construct_params(m, _group_number=grp) for + grp,m in enumerate(parameters)] + self.executemany = len(parameters) > 1 + + self.cursor = self.create_cursor() + if self.isinsert or self.isupdate: + self.__process_defaults() - if not dialect.supports_unicode_statements and isinstance(statement, unicode): - self.unicode_statement = statement - self.statement = statement.encode(self.dialect.encoding) + processors = dict( + (key, value) for key, value in + ( (compiled.bind_names[bindparam], + bindparam.bind_processor(dialect)) + for bindparam in compiled.bind_names ) + if value is not None) + + # Convert the dictionary of bind parameter values + # into a dict or list to be sent to the DBAPI's + # execute() or executemany() method. + parameters = [] + if dialect.positional: + for compiled_params in self.compiled_parameters: + param = [] + for key in self.compiled.positiontup: + if key in processors: + param.append(processors[key](compiled_params[key])) + else: + param.append(compiled_params[key]) + parameters.append(dialect.execute_sequence_format(param)) + else: + encode = not dialect.supports_unicode_statements + for compiled_params in self.compiled_parameters: + param = {} + if encode: + encoding = dialect.encoding + for key in compiled_params: + if key in processors: + param[key.encode(encoding)] = \ + processors[key](compiled_params[key]) + else: + param[key.encode(encoding)] = compiled_params[key] + else: + for key in compiled_params: + if key in processors: + param[key] = processors[key](compiled_params[key]) + else: + param[key] = compiled_params[key] + parameters.append(param) + self.parameters = dialect.execute_sequence_format(parameters) + + return self + + @classmethod + def _init_statement(cls, dialect, connection, statement, parameters): + """Initialize execution context for a string SQL statement.""" + + self = cls.__new__(cls) + self.dialect = dialect + self._connection = self.root_connection = connection + self.engine = connection.engine + + # plain text statement + if connection._execution_options: + self.execution_options = self.execution_options.\ + union(connection._execution_options) + + if not parameters: + if self.dialect.positional: + self.parameters = [dialect.execute_sequence_format()] else: - self.statement = self.unicode_statement = statement - - self.cursor = self.create_cursor() + self.parameters = [{}] + elif isinstance(parameters[0], dialect.execute_sequence_format): + self.parameters = parameters + elif isinstance(parameters[0], dict): + if dialect.supports_unicode_statements: + self.parameters = parameters + else: + self.parameters= [ + dict((k.encode(dialect.encoding), d[k]) for k in d) + for d in parameters + ] or [{}] + else: + self.parameters = [dialect.execute_sequence_format(p) + for p in parameters] + + self.executemany = len(parameters) > 1 + + if not dialect.supports_unicode_statements and isinstance(statement, unicode): + self.unicode_statement = statement + self.statement = statement.encode(self.dialect.encoding) else: - # no statement. used for standalone ColumnDefault execution. - if connection._execution_options: - self.execution_options = self.execution_options.union(connection._execution_options) - self.cursor = self.create_cursor() + self.statement = self.unicode_statement = statement + + self.cursor = self.create_cursor() + return self + + @classmethod + def _init_default(cls, dialect, connection): + """Initialize execution context for a ColumnDefault construct.""" + + self = cls.__new__(cls) + self.dialect = dialect + self._connection = self.root_connection = connection + self.engine = connection.engine + if connection._execution_options: + self.execution_options = self.execution_options.\ + union(connection._execution_options) + self.cursor = self.create_cursor() + return self @util.memoized_property def is_crud(self): @@ -446,13 +530,6 @@ class DefaultExecutionContext(base.ExecutionContext): bool(self.compiled.returning) and \ not self.compiled.statement._returning - @util.memoized_property - def _default_params(self): - if self.dialect.positional: - return self.dialect.execute_sequence_format() - else: - return {} - def _execute_scalar(self, stmt): """Execute a string statement on the current cursor, returning a scalar result. @@ -466,70 +543,20 @@ class DefaultExecutionContext(base.ExecutionContext): conn = self._connection if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements: stmt = stmt.encode(self.dialect.encoding) - conn._cursor_execute(self.cursor, stmt, self._default_params) + + if self.dialect.positional: + default_params = self.dialect.execute_sequence_format() + else: + default_params = {} + + conn._cursor_execute(self.cursor, stmt, default_params) return self.cursor.fetchone()[0] @property def connection(self): return self._connection._branch() - def __encode_param_keys(self, params): - """Apply string encoding to the keys of dictionary-based bind parameters. - This is only used executing textual, non-compiled SQL expressions. - - """ - - if not params: - return [self._default_params] - elif isinstance(params[0], self.dialect.execute_sequence_format): - return params - elif isinstance(params[0], dict): - if self.dialect.supports_unicode_statements: - return params - else: - def proc(d): - return dict((k.encode(self.dialect.encoding), d[k]) for k in d) - return [proc(d) for d in params] or [{}] - else: - return [self.dialect.execute_sequence_format(p) for p in params] - - - def __convert_compiled_params(self, compiled_parameters): - """Convert the dictionary of bind parameter values into a dict or list - to be sent to the DBAPI's execute() or executemany() method. - """ - - processors = self.processors - parameters = [] - if self.dialect.positional: - for compiled_params in compiled_parameters: - param = [] - for key in self.compiled.positiontup: - if key in processors: - param.append(processors[key](compiled_params[key])) - else: - param.append(compiled_params[key]) - parameters.append(self.dialect.execute_sequence_format(param)) - else: - encode = not self.dialect.supports_unicode_statements - for compiled_params in compiled_parameters: - param = {} - if encode: - encoding = self.dialect.encoding - for key in compiled_params: - if key in processors: - param[key.encode(encoding)] = processors[key](compiled_params[key]) - else: - param[key.encode(encoding)] = compiled_params[key] - else: - for key in compiled_params: - if key in processors: - param[key] = processors[key](compiled_params[key]) - else: - param[key] = compiled_params[key] - parameters.append(param) - return self.dialect.execute_sequence_format(parameters) def should_autocommit_text(self, statement): return AUTOCOMMIT_REGEXP.match(statement) @@ -624,6 +651,10 @@ class DefaultExecutionContext(base.ExecutionContext): """Given a cursor and ClauseParameters, call the appropriate style of ``setinputsizes()`` on the cursor, using DB-API types from the bind parameter's ``TypeEngine`` objects. + + This method only called by those dialects which require it, + currently cx_oracle. + """ if not hasattr(self.compiled, 'bind_names'): @@ -696,7 +727,8 @@ class DefaultExecutionContext(base.ExecutionContext): scalar_defaults = {} # pre-determine scalar Python-side defaults - # to avoid many calls of get_insert_default()/get_update_default() + # to avoid many calls of get_insert_default()/ + # get_update_default() for c in self.compiled.prefetch: if self.isinsert and c.default and c.default.is_scalar: scalar_defaults[c] = c.default.arg @@ -717,7 +749,8 @@ class DefaultExecutionContext(base.ExecutionContext): del self.current_parameters else: - self.current_parameters = compiled_parameters = self.compiled_parameters[0] + self.current_parameters = compiled_parameters = \ + self.compiled_parameters[0] for c in self.compiled.prefetch: if self.isinsert: @@ -730,8 +763,11 @@ class DefaultExecutionContext(base.ExecutionContext): del self.current_parameters if self.isinsert: - self._inserted_primary_key = [compiled_parameters.get(c.key, None) - for c in self.compiled.statement.table.primary_key] + self._inserted_primary_key = [ + compiled_parameters.get(c.key, None) + for c in self.compiled.\ + statement.table.primary_key + ] self._last_inserted_params = compiled_parameters else: self._last_updated_params = compiled_parameters |