summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/engine/base.py9
-rw-r--r--lib/sqlalchemy/testing/mock.py4
-rw-r--r--test/engine/test_execute.py54
-rw-r--r--test/orm/test_versioning.py2
-rw-r--r--test/sql/test_returning.py2
5 files changed, 35 insertions, 36 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 257eaa18a..735113a26 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -898,11 +898,10 @@ class Connection(Connectable):
elif not context._is_explicit_returning:
result.close(_autoclose_connection=False)
result._metadata = None
- elif context.isupdate:
- if context._is_implicit_returning:
- context._fetch_implicit_update_returning(result)
- result.close(_autoclose_connection=False)
- result._metadata = None
+ elif context.isupdate and context._is_implicit_returning:
+ context._fetch_implicit_update_returning(result)
+ result.close(_autoclose_connection=False)
+ result._metadata = None
elif result._metadata is None:
# no results, get rowcount
diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py
index 650962384..fa2d477a7 100644
--- a/lib/sqlalchemy/testing/mock.py
+++ b/lib/sqlalchemy/testing/mock.py
@@ -4,10 +4,10 @@ from __future__ import absolute_import
from ..util import py33
if py33:
- from unittest.mock import MagicMock, Mock, call
+ from unittest.mock import MagicMock, Mock, call, patch
else:
try:
- from mock import MagicMock, Mock, call
+ from mock import MagicMock, Mock, call, patch
except ImportError:
raise ImportError(
"SQLAlchemy's test suite requires the "
diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py
index 1d2aebf97..9623c080a 100644
--- a/test/engine/test_execute.py
+++ b/test/engine/test_execute.py
@@ -17,9 +17,9 @@ from sqlalchemy.testing.engines import testing_engine
import logging.handlers
from sqlalchemy.dialects.oracle.zxjdbc import ReturningParam
from sqlalchemy.engine import result as _result, default
-from sqlalchemy.engine.base import Connection, Engine
+from sqlalchemy.engine.base import Engine
from sqlalchemy.testing import fixtures
-from sqlalchemy.testing.mock import Mock, call
+from sqlalchemy.testing.mock import Mock, call, patch
users, metadata, users_autoinc = None, None, None
@@ -29,11 +29,11 @@ class ExecuteTest(fixtures.TestBase):
global users, users_autoinc, metadata
metadata = MetaData(testing.db)
users = Table('users', metadata,
- Column('user_id', INT, primary_key = True, autoincrement=False),
+ Column('user_id', INT, primary_key=True, autoincrement=False),
Column('user_name', VARCHAR(20)),
)
users_autoinc = Table('users_autoinc', metadata,
- Column('user_id', INT, primary_key = True,
+ Column('user_id', INT, primary_key=True,
test_needs_autoincrement=True),
Column('user_name', VARCHAR(20)),
)
@@ -892,42 +892,42 @@ class ResultProxyTest(fixtures.TestBase):
def test_no_rowcount_on_selects_inserts(self):
"""assert that rowcount is only called on deletes and updates.
- This because cursor.rowcount can be expensive on some dialects
- such as Firebird.
+ This because cursor.rowcount may can be expensive on some dialects
+ such as Firebird, however many dialects require it be called
+ before the cursor is closed.
"""
metadata = self.metadata
engine = engines.testing_engine()
- metadata.bind = engine
t = Table('t1', metadata,
Column('data', String(10))
)
- metadata.create_all()
+ metadata.create_all(engine)
- class BreakRowcountMixin(object):
- @property
- def rowcount(self):
- assert False
+ with patch.object(engine.dialect.execution_ctx_cls, "rowcount") as mock_rowcount:
+ mock_rowcount.__get__ = Mock()
+ engine.execute(t.insert(),
+ {'data': 'd1'},
+ {'data': 'd2'},
+ {'data': 'd3'})
- execution_ctx_cls = engine.dialect.execution_ctx_cls
- engine.dialect.execution_ctx_cls = type("FakeCtx",
- (BreakRowcountMixin,
- execution_ctx_cls),
- {})
+ eq_(len(mock_rowcount.__get__.mock_calls), 0)
- try:
- r = t.insert().execute({'data': 'd1'}, {'data': 'd2'},
- {'data': 'd3'})
- eq_(t.select().execute().fetchall(), [('d1', ), ('d2', ),
- ('d3', )])
- assert_raises(AssertionError, t.update().execute, {'data'
- : 'd4'})
- assert_raises(AssertionError, t.delete().execute)
- finally:
- engine.dialect.execution_ctx_cls = execution_ctx_cls
+ eq_(
+ engine.execute(t.select()).fetchall(),
+ [('d1', ), ('d2', ), ('d3', )]
+ )
+ eq_(len(mock_rowcount.__get__.mock_calls), 0)
+
+ engine.execute(t.update(), {'data': 'd4'})
+
+ eq_(len(mock_rowcount.__get__.mock_calls), 1)
+
+ engine.execute(t.delete())
+ eq_(len(mock_rowcount.__get__.mock_calls), 2)
@testing.requires.python26
diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py
index d8d92830f..026793c97 100644
--- a/test/orm/test_versioning.py
+++ b/test/orm/test_versioning.py
@@ -668,7 +668,7 @@ class ServerVersioningTest(fixtures.MappedTest):
if hasattr(stmt, "_counter"):
return stmt._counter
else:
- stmt._counter = str(counter.next())
+ stmt._counter = str(next(counter))
return stmt._counter
Table('version_table', metadata,
diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py
index 179d2d261..19f5d26c0 100644
--- a/test/sql/test_returning.py
+++ b/test/sql/test_returning.py
@@ -201,7 +201,7 @@ class ReturnDefaultsTest(fixtures.TablesTest):
@compiles(IncDefault)
def compile(element, compiler, **kw):
- return str(counter.next())
+ return str(next(counter))
Table("t1", metadata,
Column("id", Integer, primary_key=True, test_needs_autoincrement=True),