summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2011-01-15 16:42:29 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2011-01-15 16:42:29 -0500
commitdff5a404e489d5215da5aa30870b78aca8423de5 (patch)
tree39f6ae1949762d269f6a73e0315f0fa6910ceacd
parentfc0ffac24155931c2db10d1a469e1f7898268e45 (diff)
downloadsqlalchemy-dff5a404e489d5215da5aa30870b78aca8423de5.tar.gz
- getting slightly more consistent behavior for the edge case of pk columns
with server default - autoincrement is now false with any server_default, so these all return None, applies consistency to [ticket:2020], [ticket:2021]. if prefetch is desired a "default" should be used instead of server_default.
-rw-r--r--doc/build/core/expression_api.rst2
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py4
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py19
-rw-r--r--lib/sqlalchemy/engine/base.py20
-rw-r--r--lib/sqlalchemy/engine/default.py5
-rw-r--r--lib/sqlalchemy/engine/reflection.py6
-rw-r--r--lib/sqlalchemy/schema.py13
-rw-r--r--lib/sqlalchemy/sql/compiler.py14
-rw-r--r--lib/sqlalchemy/sql/expression.py15
-rw-r--r--test/sql/test_defaults.py145
10 files changed, 197 insertions, 46 deletions
diff --git a/doc/build/core/expression_api.rst b/doc/build/core/expression_api.rst
index e907b2535..88c0840ac 100644
--- a/doc/build/core/expression_api.rst
+++ b/doc/build/core/expression_api.rst
@@ -185,7 +185,7 @@ Classes
:show-inheritance:
.. autoclass:: Insert
- :members: prefix_with, values
+ :members: prefix_with, values, returning
:show-inheritance:
.. autoclass:: Join
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index 7ed4ca07e..63ad37ce9 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -594,9 +594,9 @@ class OracleIdentifierPreparer(compiler.IdentifierPreparer):
class OracleExecutionContext(default.DefaultExecutionContext):
def fire_sequence(self, seq, type_):
- return int(self._execute_scalar("SELECT " +
+ return self._execute_scalar("SELECT " +
self.dialect.identifier_preparer.format_sequence(seq) +
- ".nextval FROM DUAL"), type_)
+ ".nextval FROM DUAL", type_)
class OracleDialect(default.DefaultDialect):
name = 'oracle'
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 7c712e8aa..a8fb4e51a 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -508,13 +508,15 @@ class PGDDLCompiler(compiler.DDLCompiler):
colspec = self.preparer.format_column(column)
type_affinity = column.type._type_affinity
if column.primary_key and \
- len(column.foreign_keys)==0 and \
- column.autoincrement and \
- issubclass(type_affinity, sqltypes.Integer) and \
+ column is column.table._autoincrement_column and \
not issubclass(type_affinity, sqltypes.SmallInteger) and \
- (column.default is None or
- (isinstance(column.default, schema.Sequence) and
- column.default.optional)):
+ (
+ column.default is None or
+ (
+ isinstance(column.default, schema.Sequence) and
+ column.default.optional
+ )
+ ):
if issubclass(type_affinity, sqltypes.BigInteger):
colspec += " BIGSERIAL"
else:
@@ -689,7 +691,7 @@ class PGExecutionContext(default.DefaultExecutionContext):
return None
def get_insert_default(self, column):
- if column.primary_key:
+ if column.primary_key and column is column.table._autoincrement_column:
if (isinstance(column.server_default, schema.DefaultClause) and
column.server_default.arg is not None):
@@ -697,8 +699,7 @@ class PGExecutionContext(default.DefaultExecutionContext):
return self._execute_scalar("select %s" %
column.server_default.arg, column.type)
- elif column is column.table._autoincrement_column \
- and (column.default is None or
+ elif (column.default is None or
(isinstance(column.default, schema.Sequence) and
column.default.optional)):
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 3bdcad2ac..9eb1b8b40 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -2460,9 +2460,23 @@ class ResultProxy(object):
@util.memoized_property
def inserted_primary_key(self):
"""Return the primary key for the row just inserted.
-
- This only applies to single row insert() constructs which
- did not explicitly specify returning().
+
+ The return value is a list of scalar values
+ corresponding to the list of primary key columns
+ in the target table.
+
+ This only applies to single row :func:`.insert`
+ constructs which did not explicitly specify
+ :meth:`.Insert.returning`.
+
+ Note that primary key columns which specify a
+ server_default clause,
+ or otherwise do not qualify as "autoincrement"
+ columns (see the notes at :class:`.Column`), and were
+ generated using the database-side default, will
+ appear in this list as ``None`` unless the backend
+ supports "returning" and the insert statement executed
+ with the "implicit returning" enabled.
"""
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index e21ec1c40..da6ed12a6 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -101,7 +101,7 @@ class DefaultDialect(base.Dialect):
if not getattr(self, 'ported_sqla_06', True):
util.warn(
- "The %s dialect is not yet ported to SQLAlchemy 0.6" %
+ "The %s dialect is not yet ported to SQLAlchemy 0.6/0.7" %
self.name)
self.convert_unicode = convert_unicode
@@ -625,7 +625,8 @@ class DefaultExecutionContext(base.ExecutionContext):
return self.dialect.supports_sane_multi_rowcount
def post_insert(self):
- if self.dialect.postfetch_lastrowid and \
+ if not self._is_implicit_returning and \
+ self.dialect.postfetch_lastrowid and \
(not self.inserted_primary_key or \
None in self.inserted_primary_key):
diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py
index cf254cba6..00b2fd1bf 100644
--- a/lib/sqlalchemy/engine/reflection.py
+++ b/lib/sqlalchemy/engine/reflection.py
@@ -398,7 +398,11 @@ class Inspector(object):
if col_d.get('default') is not None:
# the "default" value is assumed to be a literal SQL expression,
# so is wrapped in text() so that no quoting occurs on re-issuance.
- colargs.append(sa_schema.DefaultClause(sql.text(col_d['default'])))
+ colargs.append(
+ sa_schema.DefaultClause(
+ sql.text(col_d['default']), _reflected=True
+ )
+ )
if 'sequence' in col_d:
# TODO: mssql, maxdb and sybase are using this.
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index a530a1a7a..26f607512 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -325,11 +325,8 @@ class Table(SchemaItem, expression.TableClause):
if col.autoincrement and \
issubclass(col.type._type_affinity, types.Integer) and \
not col.foreign_keys and \
- isinstance(col.default, (type(None), Sequence)):
- # don't look at server_default here since different backends may
- # or may not have a server_default, e.g. postgresql reflected
- # SERIAL cols will have a DefaultClause here but are still
- # autoincrement.
+ isinstance(col.default, (type(None), Sequence)) and \
+ (col.server_default is None or col.server_default.reflected):
return col
@property
@@ -1231,6 +1228,7 @@ class DefaultGenerator(SchemaItem):
__visit_name__ = 'default_generator'
is_sequence = False
+ is_server_default = False
def __init__(self, for_update=False):
self.for_update = for_update
@@ -1423,6 +1421,8 @@ class FetchedValue(object):
INSERT.
"""
+ is_server_default = True
+ reflected = False
def __init__(self, for_update=False):
self.for_update = for_update
@@ -1460,12 +1460,13 @@ class DefaultClause(FetchedValue):
"""
- def __init__(self, arg, for_update=False):
+ def __init__(self, arg, for_update=False, _reflected=False):
util.assert_arg_type(arg, (basestring,
expression.ClauseElement,
expression._TextClause), 'arg')
super(DefaultClause, self).__init__(for_update)
self.arg = arg
+ self.reflected = _reflected
def __repr__(self):
return "DefaultClause(%r, for_update=%r)" % \
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index ce98dfb83..d906bf5d4 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1102,18 +1102,16 @@ class SQLCompiler(engine.Compiled):
else:
self.returning.append(c)
else:
- if (
- c.default is not None and \
- (
- self.dialect.supports_sequences or
- not c.default.is_sequence
- )
- ) or \
- self.dialect.preexecute_autoincrement_sequences:
+ if c.default is not None or \
+ c is stmt.table._autoincrement_column and (
+ self.dialect.supports_sequences or
+ self.dialect.preexecute_autoincrement_sequences
+ ):
values.append(
(c, self._create_crud_bind_param(c, None))
)
+
self.prefetch.append(c)
elif c.default is not None:
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index ede194f7c..6a368b8c0 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -4436,7 +4436,7 @@ class Select(_SelectBase):
self._bind = bind
bind = property(bind, _set_bind)
-class _UpdateBase(Executable, ClauseElement):
+class UpdateBase(Executable, ClauseElement):
"""Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements."""
__visit_name__ = 'update_base'
@@ -4513,7 +4513,8 @@ class _UpdateBase(Executable, ClauseElement):
"""
self._returning = cols
-class _ValuesBase(_UpdateBase):
+class ValuesBase(UpdateBase):
+ """Supplies support for :meth:`.ValuesBase.values` to INSERT and UPDATE constructs."""
__visit_name__ = 'values_base'
@@ -4548,7 +4549,7 @@ class _ValuesBase(_UpdateBase):
self.parameters.update(self._process_colparams(v))
self.parameters.update(kwargs)
-class Insert(_ValuesBase):
+class Insert(ValuesBase):
"""Represent an INSERT construct.
The :class:`Insert` object is created using the :func:`insert()` function.
@@ -4566,7 +4567,7 @@ class Insert(_ValuesBase):
prefixes=None,
returning=None,
**kwargs):
- _ValuesBase.__init__(self, table, values)
+ ValuesBase.__init__(self, table, values)
self._bind = bind
self.select = None
self.inline = inline
@@ -4598,7 +4599,7 @@ class Insert(_ValuesBase):
clause = _literal_as_text(clause)
self._prefixes = self._prefixes + (clause,)
-class Update(_ValuesBase):
+class Update(ValuesBase):
"""Represent an Update construct.
The :class:`Update` object is created using the :func:`update()` function.
@@ -4614,7 +4615,7 @@ class Update(_ValuesBase):
bind=None,
returning=None,
**kwargs):
- _ValuesBase.__init__(self, table, values)
+ ValuesBase.__init__(self, table, values)
self._bind = bind
self._returning = returning
if whereclause is not None:
@@ -4650,7 +4651,7 @@ class Update(_ValuesBase):
self._whereclause = _literal_as_text(whereclause)
-class Delete(_UpdateBase):
+class Delete(UpdateBase):
"""Represent a DELETE construct.
The :class:`Delete` object is created using the :func:`delete()` function.
diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py
index 7822e487c..0d099a786 100644
--- a/test/sql/test_defaults.py
+++ b/test/sql/test_defaults.py
@@ -5,9 +5,9 @@ from sqlalchemy.sql import select, text, literal_column
import sqlalchemy as sa
from test.lib import testing, engines
from sqlalchemy import MetaData, Integer, String, ForeignKey, Boolean, exc,\
- Sequence, Column, func, literal
+ Sequence, func, literal
from sqlalchemy.types import TypeDecorator
-from test.lib.schema import Table
+from test.lib.schema import Table, Column
from test.lib.testing import eq_
from test.sql import _base
@@ -704,9 +704,13 @@ class SpecialTypePKTest(testing.TestBase):
class MyInteger(TypeDecorator):
impl = Integer
def process_bind_param(self, value, dialect):
+ if value is None:
+ return None
return int(value[4:])
def process_result_value(self, value, dialect):
+ if value is None:
+ return None
return "INT_%d" % value
cls.MyInteger = MyInteger
@@ -715,6 +719,8 @@ class SpecialTypePKTest(testing.TestBase):
def _run_test(self, *arg, **kw):
implicit_returning = kw.pop('implicit_returning', True)
kw['primary_key'] = True
+ if kw.get('autoincrement', True):
+ kw['test_needs_autoincrement'] = True
t = Table('x', metadata,
Column('y', self.MyInteger, *arg, **kw),
Column('data', Integer),
@@ -723,7 +729,12 @@ class SpecialTypePKTest(testing.TestBase):
t.create()
r = t.insert().values(data=5).execute()
- eq_(r.inserted_primary_key, ['INT_1'])
+
+ # we don't pre-fetch 'server_default'.
+ if 'server_default' in kw and (not testing.db.dialect.implicit_returning or not implicit_returning):
+ eq_(r.inserted_primary_key, [None])
+ else:
+ eq_(r.inserted_primary_key, ['INT_1'])
r.close()
eq_(
@@ -745,13 +756,9 @@ class SpecialTypePKTest(testing.TestBase):
def test_sequence(self):
self._run_test(Sequence('foo_seq'))
- @testing.fails_on('mysql', "Pending [ticket:2021]")
def test_server_default(self):
- # note that the MySQL dialect has to not render AUTOINCREMENT on this one
self._run_test(server_default='1',)
- @testing.fails_on('mysql', "Pending [ticket:2021]")
- @testing.fails_on('sqlite', "Pending [ticket:2021]")
def test_server_default_no_autoincrement(self):
self._run_test(server_default='1', autoincrement=False)
@@ -767,4 +774,128 @@ class SpecialTypePKTest(testing.TestBase):
def test_server_default_no_implicit_returning(self):
self._run_test(server_default='1', autoincrement=False)
+class ServerDefaultsOnPKTest(testing.TestBase):
+ @testing.provide_metadata
+ def test_string_default_none_on_insert(self):
+ """Test that without implicit returning, we return None for
+ a string server default.
+
+ That is, we don't want to attempt to pre-execute "server_default"
+ generically - the user should use a Python side-default for a case
+ like this. Testing that all backends do the same thing here.
+
+ """
+ t = Table('x', metadata,
+ Column('y', String(10), server_default='key_one', primary_key=True),
+ Column('data', String(10)),
+ implicit_returning=False
+ )
+ metadata.create_all()
+ r = t.insert().execute(data='data')
+ eq_(r.inserted_primary_key, [None])
+ eq_(
+ t.select().execute().fetchall(),
+ [('key_one', 'data')]
+ )
+
+ @testing.requires.returning
+ @testing.provide_metadata
+ def test_string_default_on_insert_with_returning(self):
+ """With implicit_returning, we get a string PK default back no problem."""
+ t = Table('x', metadata,
+ Column('y', String(10), server_default='key_one', primary_key=True),
+ Column('data', String(10))
+ )
+ metadata.create_all()
+ r = t.insert().execute(data='data')
+ eq_(r.inserted_primary_key, ['key_one'])
+ eq_(
+ t.select().execute().fetchall(),
+ [('key_one', 'data')]
+ )
+
+ @testing.provide_metadata
+ def test_int_default_none_on_insert(self):
+ t = Table('x', metadata,
+ Column('y', Integer,
+ server_default='5', primary_key=True),
+ Column('data', String(10)),
+ implicit_returning=False
+ )
+ assert t._autoincrement_column is None
+ metadata.create_all()
+ r = t.insert().execute(data='data')
+ eq_(r.inserted_primary_key, [None])
+ if testing.against('sqlite'):
+ eq_(
+ t.select().execute().fetchall(),
+ [(1, 'data')]
+ )
+ else:
+ eq_(
+ t.select().execute().fetchall(),
+ [(5, 'data')]
+ )
+ @testing.fails_on('firebird', "col comes back as autoincrement")
+ @testing.fails_on('sqlite', "col comes back as autoincrement")
+ @testing.fails_on('oracle', "col comes back as autoincrement")
+ @testing.provide_metadata
+ def test_autoincrement_reflected_from_server_default(self):
+ t = Table('x', metadata,
+ Column('y', Integer,
+ server_default='5', primary_key=True),
+ Column('data', String(10)),
+ implicit_returning=False
+ )
+ assert t._autoincrement_column is None
+ metadata.create_all()
+
+ m2 = MetaData(metadata.bind)
+ t2 = Table('x', m2, autoload=True, implicit_returning=False)
+ assert t2._autoincrement_column is None
+
+ @testing.fails_on('firebird', "attempts to insert None")
+ @testing.fails_on('sqlite', "returns a value")
+ @testing.provide_metadata
+ def test_int_default_none_on_insert_reflected(self):
+ t = Table('x', metadata,
+ Column('y', Integer,
+ server_default='5', primary_key=True),
+ Column('data', String(10)),
+ implicit_returning=False
+ )
+ metadata.create_all()
+
+ m2 = MetaData(metadata.bind)
+ t2 = Table('x', m2, autoload=True, implicit_returning=False)
+
+ r = t2.insert().execute(data='data')
+ eq_(r.inserted_primary_key, [None])
+ if testing.against('sqlite'):
+ eq_(
+ t2.select().execute().fetchall(),
+ [(1, 'data')]
+ )
+ else:
+ eq_(
+ t2.select().execute().fetchall(),
+ [(5, 'data')]
+ )
+
+ @testing.requires.returning
+ @testing.provide_metadata
+ def test_int_default_on_insert_with_returning(self):
+ t = Table('x', metadata,
+ Column('y', Integer,
+ server_default='5', primary_key=True),
+ Column('data', String(10))
+ )
+
+ metadata.create_all()
+ r = t.insert().execute(data='data')
+ eq_(r.inserted_primary_key, [5])
+ eq_(
+ t.select().execute().fetchall(),
+ [(5, 'data')]
+ )