summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2014-10-10 17:15:19 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2014-10-10 17:15:19 -0400
commit95be42c06ff4e5f3528de42bb04dcba228ea74c2 (patch)
tree602636e859b40fe8809ff38c50162a5c9402f85d /lib/sqlalchemy
parent3a6cdff88429e047a684c0f5d6029a30d9aaa062 (diff)
downloadsqlalchemy-95be42c06ff4e5f3528de42bb04dcba228ea74c2.tar.gz
- :meth:`.Insert.from_select` now includes Python and SQL-expression
defaults if otherwise unspecified; the limitation where non- server column defaults aren't included in an INSERT FROM SELECT is now lifted and these expressions are rendered as constants into the SELECT statement.
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/sql/compiler.py2
-rw-r--r--lib/sqlalchemy/sql/crud.py97
-rw-r--r--lib/sqlalchemy/sql/dml.py26
-rw-r--r--lib/sqlalchemy/testing/suite/test_insert.py37
4 files changed, 132 insertions, 30 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 86f00d944..a6c30b7dc 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1793,7 +1793,7 @@ class SQLCompiler(Compiled):
text += " " + returning_clause
if insert_stmt.select is not None:
- text += " %s" % self.process(insert_stmt.select, **kw)
+ text += " %s" % self.process(self._insert_from_select, **kw)
elif not crud_params and supports_default_values:
text += " DEFAULT VALUES"
elif insert_stmt._has_multi_parameters:
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 1c1f661d2..831d05be1 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -89,18 +89,15 @@ def _get_crud_params(compiler, stmt, **kw):
_col_bind_name, _getattr_col_key, values, kw)
if compiler.isinsert and stmt.select_names:
- # for an insert from select, we can only use names that
- # are given, so only select for those names.
- cols = (stmt.table.c[_column_as_key(name)]
- for name in stmt.select_names)
+ _scan_insert_from_select_cols(
+ compiler, stmt, parameters,
+ _getattr_col_key, _column_as_key,
+ _col_bind_name, check_columns, values, kw)
else:
- # iterate through all table columns to maintain
- # ordering, even for those cols that aren't included
- cols = stmt.table.columns
-
- _scan_cols(
- compiler, stmt, cols, parameters,
- _getattr_col_key, _col_bind_name, check_columns, values, kw)
+ _scan_cols(
+ compiler, stmt, parameters,
+ _getattr_col_key, _column_as_key,
+ _col_bind_name, check_columns, values, kw)
if parameters and stmt_parameters:
check = set(parameters).intersection(
@@ -118,13 +115,17 @@ def _get_crud_params(compiler, stmt, **kw):
return values
-def _create_bind_param(compiler, col, value, required=False, name=None):
+def _create_bind_param(
+ compiler, col, value, process=True, required=False, name=None):
if name is None:
name = col.key
bindparam = elements.BindParameter(name, value,
type_=col.type, required=required)
bindparam._is_crud = True
- return bindparam._compiler_dispatch(compiler)
+ if process:
+ bindparam = bindparam._compiler_dispatch(compiler)
+ return bindparam
+
def _key_getters_for_crud_column(compiler):
if compiler.isupdate and compiler.statement._extra_froms:
@@ -162,14 +163,52 @@ def _key_getters_for_crud_column(compiler):
return _column_as_key, _getattr_col_key, _col_bind_name
+def _scan_insert_from_select_cols(
+ compiler, stmt, parameters, _getattr_col_key,
+ _column_as_key, _col_bind_name, check_columns, values, kw):
+
+ need_pks, implicit_returning, \
+ implicit_return_defaults, postfetch_lastrowid = \
+ _get_returning_modifiers(compiler, stmt)
+
+ cols = [stmt.table.c[_column_as_key(name)]
+ for name in stmt.select_names]
+
+ compiler._insert_from_select = stmt.select
+
+ add_select_cols = []
+ if stmt.include_insert_from_select_defaults:
+ col_set = set(cols)
+ for col in stmt.table.columns:
+ if col not in col_set and col.default:
+ cols.append(col)
+
+ for c in cols:
+ col_key = _getattr_col_key(c)
+ if col_key in parameters and col_key not in check_columns:
+ parameters.pop(col_key)
+ values.append((c, None))
+ else:
+ _append_param_insert_select_hasdefault(
+ compiler, stmt, c, add_select_cols, kw)
+
+ if add_select_cols:
+ values.extend(add_select_cols)
+ compiler._insert_from_select = compiler._insert_from_select._generate()
+ compiler._insert_from_select._raw_columns += tuple(
+ expr for col, expr in add_select_cols)
+
+
def _scan_cols(
- compiler, stmt, cols, parameters, _getattr_col_key,
- _col_bind_name, check_columns, values, kw):
+ compiler, stmt, parameters, _getattr_col_key,
+ _column_as_key, _col_bind_name, check_columns, values, kw):
need_pks, implicit_returning, \
implicit_return_defaults, postfetch_lastrowid = \
_get_returning_modifiers(compiler, stmt)
+ cols = stmt.table.columns
+
for c in cols:
col_key = _getattr_col_key(c)
if col_key in parameters and col_key not in check_columns:
@@ -196,7 +235,8 @@ def _scan_cols(
elif c.default is not None:
_append_param_insert_hasdefault(
- compiler, stmt, c, implicit_return_defaults, values, kw)
+ compiler, stmt, c, implicit_return_defaults,
+ values, kw)
elif c.server_default is not None:
if implicit_return_defaults and \
@@ -299,10 +339,8 @@ def _append_param_insert_hasdefault(
elif not c.primary_key:
compiler.postfetch.append(c)
elif c.default.is_clause_element:
- values.append(
- (c, compiler.process(
- c.default.arg.self_group(), **kw))
- )
+ proc = compiler.process(c.default.arg.self_group(), **kw)
+ values.append((c, proc))
if implicit_return_defaults and \
c in implicit_return_defaults:
@@ -317,6 +355,25 @@ def _append_param_insert_hasdefault(
compiler.prefetch.append(c)
+def _append_param_insert_select_hasdefault(
+ compiler, stmt, c, values, kw):
+
+ if c.default.is_sequence:
+ if compiler.dialect.supports_sequences and \
+ (not c.default.optional or
+ not compiler.dialect.sequences_optional):
+ proc = c.default
+ values.append((c, proc))
+ elif c.default.is_clause_element:
+ proc = c.default.arg.self_group()
+ values.append((c, proc))
+ else:
+ values.append(
+ (c, _create_bind_param(compiler, c, None, process=False))
+ )
+ compiler.prefetch.append(c)
+
+
def _append_param_update(
compiler, stmt, c, implicit_return_defaults, values, kw):
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 1934d0776..9f2ce7ce3 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -475,6 +475,7 @@ class Insert(ValuesBase):
ValuesBase.__init__(self, table, values, prefixes)
self._bind = bind
self.select = self.select_names = None
+ self.include_insert_from_select_defaults = False
self.inline = inline
self._returning = returning
self._validate_dialect_kwargs(dialect_kw)
@@ -487,7 +488,7 @@ class Insert(ValuesBase):
return ()
@_generative
- def from_select(self, names, select):
+ def from_select(self, names, select, include_defaults=True):
"""Return a new :class:`.Insert` construct which represents
an ``INSERT...FROM SELECT`` statement.
@@ -506,6 +507,21 @@ class Insert(ValuesBase):
is not checked before passing along to the database, the database
would normally raise an exception if these column lists don't
correspond.
+ :param include_defaults: if True, non-server default values and
+ SQL expressions as specified on :class:`.Column` objects
+ (as documented in :ref:`metadata_defaults_toplevel`) not
+ otherwise specified in the list of names will be rendered
+ into the INSERT and SELECT statements, so that these values are also
+ included in the data to be inserted.
+
+ .. note:: A Python-side default that uses a Python callable function
+ will only be invoked **once** for the whole statement, and **not
+ per row**.
+
+ .. versionadded:: 1.0.0 - :meth:`.Insert.from_select` now renders
+ Python-side and SQL expression column defaults into the
+ SELECT statement for columns otherwise not included in the
+ list of column names.
.. versionchanged:: 1.0.0 an INSERT that uses FROM SELECT
implies that the :paramref:`.insert.inline` flag is set to
@@ -514,13 +530,6 @@ class Insert(ValuesBase):
deals with an arbitrary number of rows, so the
:attr:`.ResultProxy.inserted_primary_key` accessor does not apply.
- .. note::
-
- A SELECT..INSERT construct in SQL has no VALUES clause. Therefore
- :class:`.Column` objects which utilize Python-side defaults
- (e.g. as described at :ref:`metadata_defaults_toplevel`)
- will **not** take effect when using :meth:`.Insert.from_select`.
-
.. versionadded:: 0.8.3
"""
@@ -533,6 +542,7 @@ class Insert(ValuesBase):
self.select_names = names
self.inline = True
+ self.include_insert_from_select_defaults = include_defaults
self.select = _interpret_as_select(select)
def _copy_internals(self, clone=_clone, **kw):
diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py
index 92d3d93e5..c197145c7 100644
--- a/lib/sqlalchemy/testing/suite/test_insert.py
+++ b/lib/sqlalchemy/testing/suite/test_insert.py
@@ -4,7 +4,7 @@ from .. import exclusions
from ..assertions import eq_
from .. import engines
-from sqlalchemy import Integer, String, select, util
+from sqlalchemy import Integer, String, select, literal_column
from ..schema import Table, Column
@@ -90,6 +90,13 @@ class InsertBehaviorTest(fixtures.TablesTest):
Column('id', Integer, primary_key=True, autoincrement=False),
Column('data', String(50))
)
+ Table('includes_defaults', metadata,
+ Column('id', Integer, primary_key=True,
+ test_needs_autoincrement=True),
+ Column('data', String(50)),
+ Column('x', Integer, default=5),
+ Column('y', Integer,
+ default=literal_column("2", type_=Integer) + 2))
def test_autoclose_on_insert(self):
if requirements.returning.enabled:
@@ -158,6 +165,34 @@ class InsertBehaviorTest(fixtures.TablesTest):
("data3", ), ("data3", )]
)
+ @requirements.insert_from_select
+ def test_insert_from_select_with_defaults(self):
+ table = self.tables.includes_defaults
+ config.db.execute(
+ table.insert(),
+ [
+ dict(id=1, data="data1"),
+ dict(id=2, data="data2"),
+ dict(id=3, data="data3"),
+ ]
+ )
+
+ config.db.execute(
+ table.insert(inline=True).
+ from_select(("id", "data",),
+ select([table.c.id + 5, table.c.data]).
+ where(table.c.data.in_(["data2", "data3"]))
+ ),
+ )
+
+ eq_(
+ config.db.execute(
+ select([table]).order_by(table.c.data)
+ ).fetchall(),
+ [(1, 'data1', 5, 4), (2, 'data2', 5, 4),
+ (7, 'data2', 5, 4), (3, 'data3', 5, 4), (8, 'data3', 5, 4)]
+ )
+
class ReturningTest(fixtures.TablesTest):
run_create_tables = 'each'