diff options
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/mock.py | 4 | ||||
-rw-r--r-- | test/engine/test_execute.py | 54 | ||||
-rw-r--r-- | test/orm/test_versioning.py | 2 | ||||
-rw-r--r-- | test/sql/test_returning.py | 2 |
5 files changed, 35 insertions, 36 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 257eaa18a..735113a26 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -898,11 +898,10 @@ class Connection(Connectable): elif not context._is_explicit_returning: result.close(_autoclose_connection=False) result._metadata = None - elif context.isupdate: - if context._is_implicit_returning: - context._fetch_implicit_update_returning(result) - result.close(_autoclose_connection=False) - result._metadata = None + elif context.isupdate and context._is_implicit_returning: + context._fetch_implicit_update_returning(result) + result.close(_autoclose_connection=False) + result._metadata = None elif result._metadata is None: # no results, get rowcount diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py index 650962384..fa2d477a7 100644 --- a/lib/sqlalchemy/testing/mock.py +++ b/lib/sqlalchemy/testing/mock.py @@ -4,10 +4,10 @@ from __future__ import absolute_import from ..util import py33 if py33: - from unittest.mock import MagicMock, Mock, call + from unittest.mock import MagicMock, Mock, call, patch else: try: - from mock import MagicMock, Mock, call + from mock import MagicMock, Mock, call, patch except ImportError: raise ImportError( "SQLAlchemy's test suite requires the " diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 1d2aebf97..9623c080a 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -17,9 +17,9 @@ from sqlalchemy.testing.engines import testing_engine import logging.handlers from sqlalchemy.dialects.oracle.zxjdbc import ReturningParam from sqlalchemy.engine import result as _result, default -from sqlalchemy.engine.base import Connection, Engine +from sqlalchemy.engine.base import Engine from sqlalchemy.testing import fixtures -from sqlalchemy.testing.mock import Mock, call +from sqlalchemy.testing.mock import Mock, call, patch users, metadata, users_autoinc = None, None, None @@ -29,11 +29,11 @@ class ExecuteTest(fixtures.TestBase): global users, users_autoinc, metadata metadata = MetaData(testing.db) users = Table('users', metadata, - Column('user_id', INT, primary_key = True, autoincrement=False), + Column('user_id', INT, primary_key=True, autoincrement=False), Column('user_name', VARCHAR(20)), ) users_autoinc = Table('users_autoinc', metadata, - Column('user_id', INT, primary_key = True, + Column('user_id', INT, primary_key=True, test_needs_autoincrement=True), Column('user_name', VARCHAR(20)), ) @@ -892,42 +892,42 @@ class ResultProxyTest(fixtures.TestBase): def test_no_rowcount_on_selects_inserts(self): """assert that rowcount is only called on deletes and updates. - This because cursor.rowcount can be expensive on some dialects - such as Firebird. + This because cursor.rowcount may can be expensive on some dialects + such as Firebird, however many dialects require it be called + before the cursor is closed. """ metadata = self.metadata engine = engines.testing_engine() - metadata.bind = engine t = Table('t1', metadata, Column('data', String(10)) ) - metadata.create_all() + metadata.create_all(engine) - class BreakRowcountMixin(object): - @property - def rowcount(self): - assert False + with patch.object(engine.dialect.execution_ctx_cls, "rowcount") as mock_rowcount: + mock_rowcount.__get__ = Mock() + engine.execute(t.insert(), + {'data': 'd1'}, + {'data': 'd2'}, + {'data': 'd3'}) - execution_ctx_cls = engine.dialect.execution_ctx_cls - engine.dialect.execution_ctx_cls = type("FakeCtx", - (BreakRowcountMixin, - execution_ctx_cls), - {}) + eq_(len(mock_rowcount.__get__.mock_calls), 0) - try: - r = t.insert().execute({'data': 'd1'}, {'data': 'd2'}, - {'data': 'd3'}) - eq_(t.select().execute().fetchall(), [('d1', ), ('d2', ), - ('d3', )]) - assert_raises(AssertionError, t.update().execute, {'data' - : 'd4'}) - assert_raises(AssertionError, t.delete().execute) - finally: - engine.dialect.execution_ctx_cls = execution_ctx_cls + eq_( + engine.execute(t.select()).fetchall(), + [('d1', ), ('d2', ), ('d3', )] + ) + eq_(len(mock_rowcount.__get__.mock_calls), 0) + + engine.execute(t.update(), {'data': 'd4'}) + + eq_(len(mock_rowcount.__get__.mock_calls), 1) + + engine.execute(t.delete()) + eq_(len(mock_rowcount.__get__.mock_calls), 2) @testing.requires.python26 diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index d8d92830f..026793c97 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -668,7 +668,7 @@ class ServerVersioningTest(fixtures.MappedTest): if hasattr(stmt, "_counter"): return stmt._counter else: - stmt._counter = str(counter.next()) + stmt._counter = str(next(counter)) return stmt._counter Table('version_table', metadata, diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index 179d2d261..19f5d26c0 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -201,7 +201,7 @@ class ReturnDefaultsTest(fixtures.TablesTest): @compiles(IncDefault) def compile(element, compiler, **kw): - return str(counter.next()) + return str(next(counter)) Table("t1", metadata, Column("id", Integer, primary_key=True, test_needs_autoincrement=True), |