diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-04-21 17:18:49 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-04-21 17:18:49 -0400 |
commit | 3aff498e4a96eda06f09f09f98e73e135719b388 (patch) | |
tree | f1ca2029cfd147478447d3cb98bae587a8ccb3c2 /test | |
parent | 1f6528ed8581ba63721bdc2a0593a5d39b9c27e0 (diff) | |
parent | fbcdba12f88d88c509fc34eb8aab3f501d1b705b (diff) | |
download | sqlalchemy-3aff498e4a96eda06f09f09f98e73e135719b388.tar.gz |
merge into cymysql branch...
Diffstat (limited to 'test')
41 files changed, 2768 insertions, 1012 deletions
diff --git a/test/aaa_profiling/test_compiler.py b/test/aaa_profiling/test_compiler.py index 2776f05ab..1b7798d06 100644 --- a/test/aaa_profiling/test_compiler.py +++ b/test/aaa_profiling/test_compiler.py @@ -60,4 +60,16 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults): def go(): s = select([t1], t1.c.c2 == t2.c.c1) s.compile(dialect=self.dialect) + go() + + def test_select_labels(self): + # give some of the cached type values + # a chance to warm up + s = select([t1], t1.c.c2 == t2.c.c1).apply_labels() + s.compile(dialect=self.dialect) + + @profiling.function_call_count() + def go(): + s = select([t1], t1.c.c2 == t2.c.c1).apply_labels() + s.compile(dialect=self.dialect) go()
\ No newline at end of file diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index aabc0a2bc..57bddc859 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -14,7 +14,7 @@ from sqlalchemy.sql import column from sqlalchemy.processors import to_decimal_processor_factory, \ to_unicode_processor_factory from sqlalchemy.testing.util import gc_collect -from sqlalchemy.util.compat import decimal +import decimal import gc from sqlalchemy.testing import fixtures import weakref @@ -307,6 +307,7 @@ class MemUsageTest(EnsureZeroed): finally: metadata.drop_all() + @testing.crashes('mysql+cymysql', 'blocking with cymysql >= 0.6') def test_unicode_warnings(self): metadata = MetaData(testing.db) table1 = Table('mytable', metadata, Column('col1', Integer, diff --git a/test/dialect/test_mssql.py b/test/dialect/test_mssql.py index 52ba77310..f1cd3fe85 100644 --- a/test/dialect/test_mssql.py +++ b/test/dialect/test_mssql.py @@ -13,10 +13,10 @@ from sqlalchemy.engine import url from sqlalchemy.testing import fixtures, AssertsCompiledSQL, \ AssertsExecutionResults, ComparesTables from sqlalchemy import testing -from sqlalchemy.testing import eq_, emits_warning_on, \ - assert_raises_message -from sqlalchemy.util.compat import decimal +from sqlalchemy.testing import emits_warning_on, assert_raises_message +import decimal from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.util.compat import b class CompileTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = mssql.dialect() @@ -1210,28 +1210,28 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): eq_([1, 3, 5], [r.id for r in results]) -class ParseConnectTest(fixtures.TestBase, AssertsCompiledSQL): - @classmethod - def setup_class(cls): - global dialect - dialect = pyodbc.dialect() +class ParseConnectTest(fixtures.TestBase): def test_pyodbc_connect_dsn_trusted(self): + dialect = pyodbc.dialect() u = url.make_url('mssql://mydsn') connection = dialect.create_connect_args(u) eq_([['dsn=mydsn;Trusted_Connection=Yes'], {}], connection) def test_pyodbc_connect_old_style_dsn_trusted(self): + dialect = pyodbc.dialect() u = url.make_url('mssql:///?dsn=mydsn') connection = dialect.create_connect_args(u) eq_([['dsn=mydsn;Trusted_Connection=Yes'], {}], connection) def test_pyodbc_connect_dsn_non_trusted(self): + dialect = pyodbc.dialect() u = url.make_url('mssql://username:password@mydsn') connection = dialect.create_connect_args(u) eq_([['dsn=mydsn;UID=username;PWD=password'], {}], connection) def test_pyodbc_connect_dsn_extra(self): + dialect = pyodbc.dialect() u = \ url.make_url('mssql://username:password@mydsn/?LANGUAGE=us_' 'english&foo=bar') @@ -1241,12 +1241,14 @@ class ParseConnectTest(fixtures.TestBase, AssertsCompiledSQL): assert ";foo=bar" in dsn_string def test_pyodbc_connect(self): + dialect = pyodbc.dialect() u = url.make_url('mssql://username:password@hostspec/database') connection = dialect.create_connect_args(u) eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UI' 'D=username;PWD=password'], {}], connection) def test_pyodbc_connect_comma_port(self): + dialect = pyodbc.dialect() u = \ url.make_url('mssql://username:password@hostspec:12345/data' 'base') @@ -1255,6 +1257,7 @@ class ParseConnectTest(fixtures.TestBase, AssertsCompiledSQL): 'ase;UID=username;PWD=password'], {}], connection) def test_pyodbc_connect_config_port(self): + dialect = pyodbc.dialect() u = \ url.make_url('mssql://username:password@hostspec/database?p' 'ort=12345') @@ -1263,6 +1266,7 @@ class ParseConnectTest(fixtures.TestBase, AssertsCompiledSQL): 'D=username;PWD=password;port=12345'], {}], connection) def test_pyodbc_extra_connect(self): + dialect = pyodbc.dialect() u = \ url.make_url('mssql://username:password@hostspec/database?L' 'ANGUAGE=us_english&foo=bar') @@ -1275,6 +1279,7 @@ class ParseConnectTest(fixtures.TestBase, AssertsCompiledSQL): 'username;PWD=password;LANGUAGE=us_english;foo=bar'), True) def test_pyodbc_odbc_connect(self): + dialect = pyodbc.dialect() u = \ url.make_url('mssql:///?odbc_connect=DRIVER%3D%7BSQL+Server' '%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase' @@ -1284,6 +1289,7 @@ class ParseConnectTest(fixtures.TestBase, AssertsCompiledSQL): 'D=username;PWD=password'], {}], connection) def test_pyodbc_odbc_connect_with_dsn(self): + dialect = pyodbc.dialect() u = \ url.make_url('mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase' '%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword' @@ -1293,6 +1299,7 @@ class ParseConnectTest(fixtures.TestBase, AssertsCompiledSQL): {}], connection) def test_pyodbc_odbc_connect_ignores_other_values(self): + dialect = pyodbc.dialect() u = \ url.make_url('mssql://userdiff:passdiff@localhost/dbdiff?od' 'bc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer' @@ -1321,7 +1328,22 @@ class ParseConnectTest(fixtures.TestBase, AssertsCompiledSQL): 'user': 'scott', 'database': 'test'}], connection ) - @testing.only_on(['mssql+pyodbc', 'mssql+pymssql'], "FreeTDS specific test") + def test_pymssql_disconnect(self): + dialect = pymssql.dialect() + + for error in [ + 'Adaptive Server connection timed out', + 'message 20003', + "Error 10054", + "Not connected to any MS SQL server", + "Connection is closed" + ]: + eq_(dialect.is_disconnect(error, None, None), True) + + eq_(dialect.is_disconnect("not an error", None, None), False) + + @testing.only_on(['mssql+pyodbc', 'mssql+pymssql'], + "FreeTDS specific test") def test_bad_freetds_warning(self): engine = engines.testing_engine() @@ -1926,6 +1948,21 @@ class TypeRoundTripTest(fixtures.TestBase, AssertsExecutionResults, ComparesTabl not in list(engine.execute(tbl.select()).first()) engine.execute(tbl.delete()) +class MonkeyPatchedBinaryTest(fixtures.TestBase): + __only_on__ = 'mssql+pymssql' + + def test_unicode(self): + module = __import__('pymssql') + result = module.Binary(u'foo') + eq_(result, u'foo') + + def test_bytes(self): + module = __import__('pymssql') + input = b('\x80\x03]q\x00X\x03\x00\x00\x00oneq\x01a.') + expected_result = input + result = module.Binary(input) + eq_(result, expected_result) + class BinaryTest(fixtures.TestBase, AssertsExecutionResults): """Test the Binary and VarBinary types""" diff --git a/test/dialect/test_oracle.py b/test/dialect/test_oracle.py index 7384d7bb4..861b28c5f 100644 --- a/test/dialect/test_oracle.py +++ b/test/dialect/test_oracle.py @@ -7,12 +7,11 @@ from sqlalchemy import types as sqltypes, exc, schema from sqlalchemy.sql import table, column from sqlalchemy.testing import fixtures, AssertsExecutionResults, AssertsCompiledSQL from sqlalchemy import testing -from sqlalchemy.testing import eq_, assert_raises, assert_raises_message +from sqlalchemy.testing import assert_raises, assert_raises_message from sqlalchemy.testing.engines import testing_engine from sqlalchemy.dialects.oracle import cx_oracle, base as oracle from sqlalchemy.engine import default -from sqlalchemy.util import jython -from sqlalchemy.util.compat import decimal +import decimal from sqlalchemy.testing.schema import Table, Column import datetime import os diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py index 3337fa6ab..005aed1ce 100644 --- a/test/dialect/test_postgresql.py +++ b/test/dialect/test_postgresql.py @@ -17,8 +17,8 @@ from sqlalchemy import Table, Column, select, MetaData, text, Integer, \ from sqlalchemy.orm import Session, mapper, aliased from sqlalchemy import exc, schema, types from sqlalchemy.dialects.postgresql import base as postgresql -from sqlalchemy.dialects.postgresql import HSTORE, hstore, array, ARRAY -from sqlalchemy.util.compat import decimal +from sqlalchemy.dialects.postgresql import HSTORE, hstore, array +import decimal from sqlalchemy.testing.util import round_decimal from sqlalchemy.sql import table, column, operators import logging @@ -180,6 +180,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): 'USING hash (data)', dialect=postgresql.dialect()) + def test_substring(self): + self.assert_compile(func.substring('abc', 1, 2), + 'SUBSTRING(%(substring_1)s FROM %(substring_2)s ' + 'FOR %(substring_3)s)') + self.assert_compile(func.substring('abc', 1), + 'SUBSTRING(%(substring_1)s FROM %(substring_2)s)') + + def test_extract(self): t = table('t', column('col1', DateTime), column('col2', Date), @@ -734,7 +742,6 @@ class NumericInterpretationTest(fixtures.TestBase): def test_numeric_codes(self): from sqlalchemy.dialects.postgresql import pg8000, psycopg2, base - from sqlalchemy.util.compat import decimal for dialect in (pg8000.dialect(), psycopg2.dialect()): @@ -3094,6 +3101,12 @@ class HStoreRoundTripTest(fixtures.TablesTest): engine.connect() return engine + def test_reflect(self): + from sqlalchemy import inspect + insp = inspect(testing.db) + cols = insp.get_columns('data_table') + assert isinstance(cols[2]['type'], HSTORE) + @testing.only_on("postgresql+psycopg2") def test_insert_native(self): engine = testing.db diff --git a/test/engine/test_ddlemit.py b/test/engine/test_ddlemit.py new file mode 100644 index 000000000..3dbd5756a --- /dev/null +++ b/test/engine/test_ddlemit.py @@ -0,0 +1,184 @@ +from sqlalchemy.testing import fixtures +from sqlalchemy.engine.ddl import SchemaGenerator, SchemaDropper +from sqlalchemy.engine import default +from sqlalchemy import MetaData, Table, Column, Integer, Sequence +from sqlalchemy import schema + +class EmitDDLTest(fixtures.TestBase): + def _mock_connection(self, item_exists): + _canary = [] + + class MockDialect(default.DefaultDialect): + supports_sequences = True + + def has_table(self, connection, name, schema): + return item_exists(name) + + def has_sequence(self, connection, name, schema): + return item_exists(name) + + class MockConnection(object): + dialect = MockDialect() + canary = _canary + + def execute(self, item): + _canary.append(item) + + return MockConnection() + + def _mock_create_fixture(self, checkfirst, tables, + item_exists=lambda item: False): + connection = self._mock_connection(item_exists) + + return SchemaGenerator(connection.dialect, connection, + checkfirst=checkfirst, + tables=tables) + + def _mock_drop_fixture(self, checkfirst, tables, + item_exists=lambda item: True): + connection = self._mock_connection(item_exists) + + return SchemaDropper(connection.dialect, connection, + checkfirst=checkfirst, + tables=tables) + + def _table_fixture(self): + m = MetaData() + + return (m, ) + tuple( + Table('t%d' % i, m, Column('x', Integer)) + for i in xrange(1, 6) + ) + + def _table_seq_fixture(self): + m = MetaData() + + s1 = Sequence('s1') + s2 = Sequence('s2') + t1 = Table('t1', m, Column("x", Integer, s1, primary_key=True)) + t2 = Table('t2', m, Column("x", Integer, s2, primary_key=True)) + + return m, t1, t2, s1, s2 + + + def test_create_seq_checkfirst(self): + m, t1, t2, s1, s2 = self._table_seq_fixture() + generator = self._mock_create_fixture(True, [t1, t2], + item_exists=lambda t: t not in ("t1", "s1") + ) + + self._assert_create([t1, s1], generator, m) + + + def test_drop_seq_checkfirst(self): + m, t1, t2, s1, s2 = self._table_seq_fixture() + generator = self._mock_drop_fixture(True, [t1, t2], + item_exists=lambda t: t in ("t1", "s1") + ) + + self._assert_drop([t1, s1], generator, m) + + def test_create_collection_checkfirst(self): + m, t1, t2, t3, t4, t5 = self._table_fixture() + generator = self._mock_create_fixture(True, [t2, t3, t4], + item_exists=lambda t: t not in ("t2", "t4") + ) + + self._assert_create_tables([t2, t4], generator, m) + + def test_drop_collection_checkfirst(self): + m, t1, t2, t3, t4, t5 = self._table_fixture() + generator = self._mock_drop_fixture(True, [t2, t3, t4], + item_exists=lambda t: t in ("t2", "t4") + ) + + self._assert_drop_tables([t2, t4], generator, m) + + def test_create_collection_nocheck(self): + m, t1, t2, t3, t4, t5 = self._table_fixture() + generator = self._mock_create_fixture(False, [t2, t3, t4], + item_exists=lambda t: t not in ("t2", "t4") + ) + + self._assert_create_tables([t2, t3, t4], generator, m) + + def test_create_empty_collection(self): + m, t1, t2, t3, t4, t5 = self._table_fixture() + generator = self._mock_create_fixture(True, [], + item_exists=lambda t: t not in ("t2", "t4") + ) + + self._assert_create_tables([], generator, m) + + def test_drop_empty_collection(self): + m, t1, t2, t3, t4, t5 = self._table_fixture() + generator = self._mock_drop_fixture(True, [], + item_exists=lambda t: t in ("t2", "t4") + ) + + self._assert_drop_tables([], generator, m) + + def test_drop_collection_nocheck(self): + m, t1, t2, t3, t4, t5 = self._table_fixture() + generator = self._mock_drop_fixture(False, [t2, t3, t4], + item_exists=lambda t: t in ("t2", "t4") + ) + + self._assert_drop_tables([t2, t3, t4], generator, m) + + def test_create_metadata_checkfirst(self): + m, t1, t2, t3, t4, t5 = self._table_fixture() + generator = self._mock_create_fixture(True, None, + item_exists=lambda t: t not in ("t2", "t4") + ) + + self._assert_create_tables([t2, t4], generator, m) + + def test_drop_metadata_checkfirst(self): + m, t1, t2, t3, t4, t5 = self._table_fixture() + generator = self._mock_drop_fixture(True, None, + item_exists=lambda t: t in ("t2", "t4") + ) + + self._assert_drop_tables([t2, t4], generator, m) + + def test_create_metadata_nocheck(self): + m, t1, t2, t3, t4, t5 = self._table_fixture() + generator = self._mock_create_fixture(False, None, + item_exists=lambda t: t not in ("t2", "t4") + ) + + self._assert_create_tables([t1, t2, t3, t4, t5], generator, m) + + def test_drop_metadata_nocheck(self): + m, t1, t2, t3, t4, t5 = self._table_fixture() + generator = self._mock_drop_fixture(False, None, + item_exists=lambda t: t in ("t2", "t4") + ) + + self._assert_drop_tables([t1, t2, t3, t4, t5], generator, m) + + def _assert_create_tables(self, elements, generator, argument): + self._assert_ddl(schema.CreateTable, elements, generator, argument) + + def _assert_drop_tables(self, elements, generator, argument): + self._assert_ddl(schema.DropTable, elements, generator, argument) + + def _assert_create(self, elements, generator, argument): + self._assert_ddl( + (schema.CreateTable, schema.CreateSequence), + elements, generator, argument) + + def _assert_drop(self, elements, generator, argument): + self._assert_ddl( + (schema.DropTable, schema.DropSequence), + elements, generator, argument) + + def _assert_ddl(self, ddl_cls, elements, generator, argument): + generator.traverse_single(argument) + for c in generator.connection.canary: + assert isinstance(c, ddl_cls) + assert c.element in elements, "element %r was not expected"\ + % c.element + elements.remove(c.element) + assert not elements, "elements remain in list: %r" % elements diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index d14cde245..203d7bd71 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -13,7 +13,7 @@ import sqlalchemy as tsa from sqlalchemy import testing from sqlalchemy.testing import engines from sqlalchemy.testing.engines import testing_engine -import logging +import logging.handlers from sqlalchemy.dialects.oracle.zxjdbc import ReturningParam from sqlalchemy.engine import result as _result, default from sqlalchemy.engine.base import Connection, Engine diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index 6b283654b..9aecb81a9 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -11,7 +11,10 @@ from sqlalchemy import exc from sqlalchemy.testing import fixtures from sqlalchemy.testing.engines import testing_engine -class MockDisconnect(Exception): +class MockError(Exception): + pass + +class MockDisconnect(MockError): pass class MockDBAPI(object): @@ -20,17 +23,23 @@ class MockDBAPI(object): self.connections = weakref.WeakKeyDictionary() def connect(self, *args, **kwargs): return MockConnection(self) - def shutdown(self): + def shutdown(self, explode='execute'): for c in self.connections: - c.explode[0] = True - Error = MockDisconnect + c.explode = explode + Error = MockError class MockConnection(object): def __init__(self, dbapi): dbapi.connections[self] = True - self.explode = [False] + self.explode = "" def rollback(self): - pass + if self.explode == 'rollback': + raise MockDisconnect("Lost the DB connection on rollback") + if self.explode == 'rollback_no_disconnect': + raise MockError( + "something broke on rollback but we didn't lose the connection") + else: + return def commit(self): pass def cursor(self): @@ -42,13 +51,30 @@ class MockCursor(object): def __init__(self, parent): self.explode = parent.explode self.description = () + self.closed = False def execute(self, *args, **kwargs): - if self.explode[0]: - raise MockDisconnect("Lost the DB connection") + if self.explode == 'execute': + raise MockDisconnect("Lost the DB connection on execute") + elif self.explode in ('execute_no_disconnect', ): + raise MockError( + "something broke on execute but we didn't lose the connection") + elif self.explode in ('rollback', 'rollback_no_disconnect'): + raise MockError( + "something broke on execute but we didn't lose the connection") + elif args and "select" in args[0]: + self.description = [('foo', None, None, None, None, None)] else: return + def fetchall(self): + if self.closed: + raise MockError("cursor closed") + return [] + def fetchone(self): + if self.closed: + raise MockError("cursor closed") + return None def close(self): - pass + self.closed = True db, dbapi = None, None class MockReconnectTest(fixtures.TestBase): @@ -167,12 +193,10 @@ class MockReconnectTest(fixtures.TestBase): dbapi.shutdown() - # raises error - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError: - pass + assert_raises( + tsa.exc.DBAPIError, + conn.execute, select([1]) + ) assert not conn.closed assert conn.invalidated @@ -186,6 +210,112 @@ class MockReconnectTest(fixtures.TestBase): assert not conn.invalidated assert len(dbapi.connections) == 1 + def test_invalidated_close(self): + conn = db.connect() + + dbapi.shutdown() + + assert_raises( + tsa.exc.DBAPIError, + conn.execute, select([1]) + ) + + conn.close() + assert conn.closed + assert conn.invalidated + assert_raises_message( + tsa.exc.StatementError, + "This Connection is closed", + conn.execute, select([1]) + ) + + def test_noreconnect_execute_plus_closewresult(self): + conn = db.connect(close_with_result=True) + + dbapi.shutdown("execute_no_disconnect") + + # raises error + assert_raises_message( + tsa.exc.DBAPIError, + "something broke on execute but we didn't lose the connection", + conn.execute, select([1]) + ) + + assert conn.closed + assert not conn.invalidated + + def test_noreconnect_rollback_plus_closewresult(self): + conn = db.connect(close_with_result=True) + + dbapi.shutdown("rollback_no_disconnect") + + # raises error + assert_raises_message( + tsa.exc.DBAPIError, + "something broke on rollback but we didn't lose the connection", + conn.execute, select([1]) + ) + + assert conn.closed + assert not conn.invalidated + + assert_raises_message( + tsa.exc.StatementError, + "This Connection is closed", + conn.execute, select([1]) + ) + + def test_reconnect_on_reentrant(self): + conn = db.connect() + + conn.execute(select([1])) + + assert len(dbapi.connections) == 1 + + dbapi.shutdown("rollback") + + # raises error + assert_raises_message( + tsa.exc.DBAPIError, + "Lost the DB connection on rollback", + conn.execute, select([1]) + ) + + assert not conn.closed + assert conn.invalidated + + def test_reconnect_on_reentrant_plus_closewresult(self): + conn = db.connect(close_with_result=True) + + dbapi.shutdown("rollback") + + # raises error + assert_raises_message( + tsa.exc.DBAPIError, + "Lost the DB connection on rollback", + conn.execute, select([1]) + ) + + assert conn.closed + assert conn.invalidated + + assert_raises_message( + tsa.exc.StatementError, + "This Connection is closed", + conn.execute, select([1]) + ) + + def test_check_disconnect_no_cursor(self): + conn = db.connect() + result = conn.execute("select 1") + result.cursor.close() + conn.close() + assert_raises_message( + tsa.exc.DBAPIError, + "cursor closed", + list, result + ) + class CursorErrTest(fixtures.TestBase): def setup(self): diff --git a/test/ext/declarative/test_inheritance.py b/test/ext/declarative/test_inheritance.py index ab78cc3e2..f0372e8ee 100644 --- a/test/ext/declarative/test_inheritance.py +++ b/test/ext/declarative/test_inheritance.py @@ -14,7 +14,8 @@ from sqlalchemy.orm import relationship, create_session, class_mapper, \ Session from sqlalchemy.testing import eq_ from sqlalchemy.util import classproperty -from sqlalchemy.ext.declarative import declared_attr, AbstractConcreteBase, ConcreteBase +from sqlalchemy.ext.declarative import declared_attr, AbstractConcreteBase, \ + ConcreteBase, has_inherited_table from sqlalchemy.testing import fixtures Base = None @@ -1112,6 +1113,46 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): 'concrete':True} self._roundtrip(Employee, Manager, Engineer, Boss) + + def test_has_inherited_table_doesnt_consider_base(self): + class A(Base): + __tablename__ = 'a' + id = Column(Integer, primary_key=True) + + assert not has_inherited_table(A) + + class B(A): + __tablename__ = 'b' + id = Column(Integer, ForeignKey('a.id'), primary_key=True) + + assert has_inherited_table(B) + + def test_has_inherited_table_in_mapper_args(self): + class Test(Base): + __tablename__ = 'test' + id = Column(Integer, primary_key=True) + type = Column(String(20)) + + @declared_attr + def __mapper_args__(cls): + if not has_inherited_table(cls): + ret = { + 'polymorphic_identity': 'default', + 'polymorphic_on': cls.type, + } + else: + ret = {'polymorphic_identity': cls.__name__} + return ret + + class PolyTest(Test): + __tablename__ = 'poly_test' + id = Column(Integer, ForeignKey(Test.id), primary_key=True) + + configure_mappers() + + assert Test.__mapper__.polymorphic_on is Test.__table__.c.type + assert PolyTest.__mapper__.polymorphic_on is Test.__table__.c.type + def test_ok_to_override_type_from_abstract(self): class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity): pass diff --git a/test/ext/test_serializer.py b/test/ext/test_serializer.py index bf268fbbb..34d7d45e0 100644 --- a/test/ext/test_serializer.py +++ b/test/ext/test_serializer.py @@ -92,6 +92,8 @@ class SerializeTest(fixtures.MappedTest): @testing.requires.python26 # namedtuple workaround not serializable in 2.5 @testing.skip_if(lambda: util.pypy, "pickle sometimes has " "problems here, sometimes not") + @testing.skip_if("postgresql", "Having intermittent problems on jenkins " + "with this test, it's really not that important") def test_query(self): q = Session.query(User).filter(User.name == 'ed' ).options(joinedload(User.addresses)) diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index 66991e922..f883a07a7 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -1055,6 +1055,73 @@ class FlushTest(fixtures.MappedTest): sess.flush() assert user_roles.count().scalar() == 1 +class JoinedNoFKSortingTest(fixtures.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table("a", metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True) + ) + Table("b", metadata, + Column('id', Integer, primary_key=True) + ) + Table("c", metadata, + Column('id', Integer, primary_key=True) + ) + + @classmethod + def setup_classes(cls): + class A(cls.Basic): + pass + class B(A): + pass + class C(A): + pass + + @classmethod + def setup_mappers(cls): + A, B, C = cls.classes.A, cls.classes.B, cls.classes.C + mapper(A, cls.tables.a) + mapper(B, cls.tables.b, inherits=A, + inherit_condition=cls.tables.a.c.id == cls.tables.b.c.id) + mapper(C, cls.tables.c, inherits=A, + inherit_condition=cls.tables.a.c.id == cls.tables.c.c.id) + + def test_ordering(self): + B, C = self.classes.B, self.classes.C + sess = Session() + sess.add_all([B(), C(), B(), C()]) + self.assert_sql_execution( + testing.db, + sess.flush, + CompiledSQL( + "INSERT INTO a () VALUES ()", + {} + ), + CompiledSQL( + "INSERT INTO a () VALUES ()", + {} + ), + CompiledSQL( + "INSERT INTO a () VALUES ()", + {} + ), + CompiledSQL( + "INSERT INTO a () VALUES ()", + {} + ), + AllOf( + CompiledSQL( + "INSERT INTO b (id) VALUES (:id)", + [{"id": 1}, {"id": 3}] + ), + CompiledSQL( + "INSERT INTO c (id) VALUES (:id)", + [{"id": 2}, {"id": 4}] + ) + ) + ) + class VersioningTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): @@ -1570,6 +1637,53 @@ class OptimizedLoadTest(fixtures.MappedTest): Column('b', String(10)) ) + def test_no_optimize_on_map_to_join(self): + base, sub = self.tables.base, self.tables.sub + + class Base(fixtures.ComparableEntity): + pass + + class JoinBase(fixtures.ComparableEntity): + pass + class SubJoinBase(JoinBase): + pass + + mapper(Base, base) + mapper(JoinBase, base.outerjoin(sub), properties={ + 'id': [base.c.id, sub.c.id], + 'counter': [base.c.counter, sub.c.counter] + }) + mapper(SubJoinBase, inherits=JoinBase) + + sess = Session() + sess.add(Base(data='data')) + sess.commit() + + sjb = sess.query(SubJoinBase).one() + sjb_id = sjb.id + sess.expire(sjb) + + # this should not use the optimized load, + # which assumes discrete tables + def go(): + eq_(sjb.data, 'data') + + self.assert_sql_execution( + testing.db, + go, + CompiledSQL( + "SELECT base.counter AS base_counter, " + "sub.counter AS sub_counter, base.id AS base_id, " + "sub.id AS sub_id, base.data AS base_data, " + "base.type AS base_type, sub.sub AS sub_sub, " + "sub.counter2 AS sub_counter2 FROM base " + "LEFT OUTER JOIN sub ON base.id = sub.id " + "WHERE base.id = :param_1", + {'param_1': sjb_id} + ), + ) + + def test_optimized_passes(self): """"test that the 'optimized load' routine doesn't crash when a column in the join condition is not available.""" @@ -1611,7 +1725,7 @@ class OptimizedLoadTest(fixtures.MappedTest): pass mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base') mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={ - 'concat':column_property(sub.c.sub + "|" + sub.c.sub) + 'concat': column_property(sub.c.sub + "|" + sub.c.sub) }) sess = sessionmaker()() s1 = Sub(data='s1data', sub='s1sub') @@ -1630,7 +1744,7 @@ class OptimizedLoadTest(fixtures.MappedTest): pass mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base') mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={ - 'concat':column_property(base.c.data + "|" + sub.c.sub) + 'concat': column_property(base.c.data + "|" + sub.c.sub) }) sess = sessionmaker()() s1 = Sub(data='s1data', sub='s1sub') diff --git a/test/orm/inheritance/test_polymorphic_rel.py b/test/orm/inheritance/test_polymorphic_rel.py index e22848912..1b9acb787 100644 --- a/test/orm/inheritance/test_polymorphic_rel.py +++ b/test/orm/inheritance/test_polymorphic_rel.py @@ -650,6 +650,7 @@ class _PolymorphicTestBase(object): count = 5 self.assert_sql_count(testing.db, go, count) + def test_joinedload_on_subclass(self): sess = create_session() expected = [ diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py index 1fc70fd77..d60c55edd 100644 --- a/test/orm/test_attributes.py +++ b/test/orm/test_attributes.py @@ -1170,7 +1170,10 @@ class CyclicBackrefAssertionTest(fixtures.TestBase): b1 = B() assert_raises_message( ValueError, - "Object <B at .*> not associated with attribute of type C.a", + 'Bidirectional attribute conflict detected: ' + 'Passing object <B at .*> to attribute "C.a" ' + 'triggers a modify event on attribute "C.b" ' + 'via the backref "B.c".', setattr, c1, 'a', b1 ) @@ -1180,10 +1183,14 @@ class CyclicBackrefAssertionTest(fixtures.TestBase): b1 = B() assert_raises_message( ValueError, - "Object <B at .*> not associated with attribute of type C.a", + 'Bidirectional attribute conflict detected: ' + 'Passing object <B at .*> to attribute "C.a" ' + 'triggers a modify event on attribute "C.b" ' + 'via the backref "B.c".', c1.a.append, b1 ) + def _scalar_fixture(self): class A(object): pass @@ -1225,6 +1232,36 @@ class CyclicBackrefAssertionTest(fixtures.TestBase): return A, B, C + def _broken_collection_fixture(self): + class A(object): + pass + class B(object): + pass + instrumentation.register_class(A) + instrumentation.register_class(B) + + attributes.register_attribute(A, 'b', backref='a1', useobject=True) + attributes.register_attribute(B, 'a1', backref='b', useobject=True, + uselist=True) + + attributes.register_attribute(B, 'a2', backref='b', useobject=True, + uselist=True) + + return A, B + + def test_broken_collection_assertion(self): + A, B = self._broken_collection_fixture() + b1 = B() + a1 = A() + assert_raises_message( + ValueError, + 'Bidirectional attribute conflict detected: ' + 'Passing object <A at .*> to attribute "B.a2" ' + 'triggers a modify event on attribute "B.a1" ' + 'via the backref "A.b".', + b1.a2.append, a1 + ) + class PendingBackrefTest(fixtures.ORMTest): def setup(self): global Post, Blog, called, lazy_load diff --git a/test/orm/test_cascade.py b/test/orm/test_cascade.py index 00d19e792..12196b4e7 100644 --- a/test/orm/test_cascade.py +++ b/test/orm/test_cascade.py @@ -37,6 +37,22 @@ class CascadeArgTest(fixtures.MappedTest): class Address(cls.Basic): pass + def test_delete_with_passive_deletes_all(self): + User, Address = self.classes.User, self.classes.Address + users, addresses = self.tables.users, self.tables.addresses + + mapper(User, users, properties={ + 'addresses': relationship(Address, + passive_deletes="all", cascade="all, delete-orphan") + }) + mapper(Address, addresses) + assert_raises_message( + sa_exc.ArgumentError, + "On User.addresses, can't set passive_deletes='all' " + "in conjunction with 'delete' or 'delete-orphan' cascade", + configure_mappers + ) + def test_delete_orphan_without_delete(self): User, Address = self.classes.User, self.classes.Address users, addresses = self.tables.users, self.tables.addresses @@ -69,6 +85,33 @@ class CascadeArgTest(fixtures.MappedTest): orm_util.CascadeOptions("all, delete-orphan"), frozenset) + def test_cascade_assignable(self): + User, Address = self.classes.User, self.classes.Address + users, addresses = self.tables.users, self.tables.addresses + + rel = relationship(Address) + eq_(rel.cascade, set(['save-update', 'merge'])) + rel.cascade = "save-update, merge, expunge" + eq_(rel.cascade, set(['save-update', 'merge', 'expunge'])) + + mapper(User, users, properties={ + 'addresses': rel + }) + am = mapper(Address, addresses) + configure_mappers() + + eq_(rel.cascade, set(['save-update', 'merge', 'expunge'])) + + assert ("addresses", User) not in am._delete_orphans + rel.cascade = "all, delete, delete-orphan" + assert ("addresses", User) in am._delete_orphans + + eq_(rel.cascade, + set(['delete', 'delete-orphan', 'expunge', 'merge', + 'refresh-expire', 'save-update']) + ) + + class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): run_inserts = None diff --git a/test/orm/test_compile.py b/test/orm/test_compile.py index fb32fb0b9..ad6778e97 100644 --- a/test/orm/test_compile.py +++ b/test/orm/test_compile.py @@ -167,8 +167,10 @@ class CompileTest(fixtures.ORMTest): b = Table('b', meta, Column('id', Integer, primary_key=True), Column('a_id', Integer, ForeignKey('a.id'))) - class A(object):pass - class B(object):pass + class A(object): + pass + class B(object): + pass mapper(A, a, properties={ 'b':relationship(B, backref='a') @@ -183,3 +185,29 @@ class CompileTest(fixtures.ORMTest): configure_mappers ) + def test_conflicting_backref_subclass(self): + meta = MetaData() + + a = Table('a', meta, Column('id', Integer, primary_key=True)) + b = Table('b', meta, Column('id', Integer, primary_key=True), + Column('a_id', Integer, ForeignKey('a.id'))) + + class A(object): + pass + class B(object): + pass + class C(B): + pass + + mapper(A, a, properties={ + 'b': relationship(B, backref='a'), + 'c': relationship(C, backref='a') + }) + mapper(B, b) + mapper(C, None, inherits=B) + + assert_raises_message( + sa_exc.ArgumentError, + "Error creating backref", + configure_mappers + ) diff --git a/test/orm/test_default_strategies.py b/test/orm/test_default_strategies.py index b986ac568..c1668cdd4 100644 --- a/test/orm/test_default_strategies.py +++ b/test/orm/test_default_strategies.py @@ -2,7 +2,6 @@ from test.orm import _fixtures from sqlalchemy import testing from sqlalchemy.orm import mapper, relationship, create_session from sqlalchemy import util -from sqlalchemy.util import any import sqlalchemy as sa from sqlalchemy.testing import eq_, assert_raises_message diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index 4c566948a..c701a7076 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -174,9 +174,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): ) # a little tedious here, adding labels to work around Query's - # auto-labelling. TODO: can we detect only one table in the - # "froms" and then turn off use_labels ? note: this query is - # incorrect SQL with the correlate of users in the FROM list. + # auto-labelling. s = sess.query(addresses.c.id.label('id'), addresses.c.email_address.label('email')).\ filter(addresses.c.user_id == users.c.id).correlate(users).\ @@ -188,7 +186,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): "SELECT users.id AS users_id, users.name AS users_name, " "anon_1.email AS anon_1_email " "FROM users JOIN (SELECT addresses.id AS id, " - "addresses.email_address AS email FROM addresses " + "addresses.email_address AS email FROM addresses, users " "WHERE addresses.user_id = users.id) AS anon_1 " "ON anon_1.id = users.id", ) @@ -2322,3 +2320,64 @@ class TestOverlyEagerEquivalentCols(fixtures.MappedTest): filter(Sub1.id==1).one(), b1 ) + +class LabelCollideTest(fixtures.MappedTest): + """Test handling for a label collision. This collision + is handled by core, see ticket:2702 as well as + test/sql/test_selectable->WithLabelsTest. here we want + to make sure the end result is as we expect. + + """ + + @classmethod + def define_tables(cls, metadata): + Table('foo', metadata, + Column('id', Integer, primary_key=True), + Column('bar_id', Integer) + ) + Table('foo_bar', metadata, + Column('id', Integer, primary_key=True), + ) + + @classmethod + def setup_classes(cls): + class Foo(cls.Basic): + pass + class Bar(cls.Basic): + pass + + @classmethod + def setup_mappers(cls): + mapper(cls.classes.Foo, cls.tables.foo) + mapper(cls.classes.Bar, cls.tables.foo_bar) + + @classmethod + def insert_data(cls): + s = Session() + s.add_all([ + cls.classes.Foo(id=1, bar_id=2), + cls.classes.Bar(id=3) + ]) + s.commit() + + def test_overlap_plain(self): + s = Session() + row = s.query(self.classes.Foo, self.classes.Bar).all()[0] + def go(): + eq_(row.Foo.id, 1) + eq_(row.Foo.bar_id, 2) + eq_(row.Bar.id, 3) + # all three columns are loaded independently without + # overlap, no additional SQL to load all attributes + self.assert_sql_count(testing.db, go, 0) + + def test_overlap_subquery(self): + s = Session() + row = s.query(self.classes.Foo, self.classes.Bar).from_self().all()[0] + def go(): + eq_(row.Foo.id, 1) + eq_(row.Foo.bar_id, 2) + eq_(row.Bar.id, 3) + # all three columns are loaded independently without + # overlap, no additional SQL to load all attributes + self.assert_sql_count(testing.db, go, 0)
\ No newline at end of file diff --git a/test/orm/test_instrumentation.py b/test/orm/test_instrumentation.py index 3b548f0cd..3f8fc67b6 100644 --- a/test/orm/test_instrumentation.py +++ b/test/orm/test_instrumentation.py @@ -445,6 +445,20 @@ class MapperInitTest(fixtures.ORMTest): # C is not mapped in the current implementation assert_raises(sa.orm.exc.UnmappedClassError, class_mapper, C) + def test_del_warning(self): + class A(object): + def __del__(self): + pass + + assert_raises_message( + sa.exc.SAWarning, + r"__del__\(\) method on class " + "<class 'test.orm.test_instrumentation.A'> will cause " + "unreachable cycles and memory leaks, as SQLAlchemy " + "instrumentation often creates reference cycles. " + "Please remove this method.", + mapper, A, self.fixture() + ) class OnLoadTest(fixtures.ORMTest): """Check that Events.load is not hit in regular attributes operations.""" diff --git a/test/orm/test_joins.py b/test/orm/test_joins.py index 8fd38a680..4c0a193a0 100644 --- a/test/orm/test_joins.py +++ b/test/orm/test_joins.py @@ -2105,7 +2105,7 @@ class SelfReferentialM2MTest(fixtures.MappedTest): sess = create_session() eq_(sess.query(Node).filter(Node.children.any(Node.data == 'n3' - )).all(), [Node(data='n1'), Node(data='n2')]) + )).order_by(Node.data).all(), [Node(data='n1'), Node(data='n2')]) def test_contains(self): Node = self.classes.Node diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index 8c5b9cd84..6b97fb135 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -407,6 +407,37 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): obj.info["q"] = "p" eq_(obj.info, {"q": "p"}) + def test_info_via_instrumented(self): + m = MetaData() + # create specific tables here as we don't want + # users.c.id.info to be pre-initialized + users = Table('u', m, Column('id', Integer, primary_key=True), + Column('name', String)) + addresses = Table('a', m, Column('id', Integer, primary_key=True), + Column('name', String), + Column('user_id', Integer, ForeignKey('u.id'))) + Address = self.classes.Address + User = self.classes.User + + mapper(User, users, properties={ + "name_lower": column_property(func.lower(users.c.name)), + "addresses": relationship(Address) + }) + mapper(Address, addresses) + + # attr.info goes down to the original Column object + # for the dictionary. The annotated element needs to pass + # this on. + assert 'info' not in users.c.id.__dict__ + is_(User.id.info, users.c.id.info) + assert 'info' in users.c.id.__dict__ + + # for SQL expressions, ORM-level .info + is_(User.name_lower.info, User.name_lower.property.info) + + # same for relationships + is_(User.addresses.info, User.addresses.property.info) + def test_add_property(self): users, addresses, Address = (self.tables.users, @@ -488,7 +519,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): assert hasattr(User, 'addresses') assert "addresses" in [p.key for p in m1._polymorphic_properties] - def test_replace_property(self): + def test_replace_col_prop_w_syn(self): users, User = self.tables.users, self.classes.User m = mapper(User, users) @@ -514,6 +545,24 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): u.name = 'jacko' assert m._columntoproperty[users.c.name] is m.get_property('_name') + def test_replace_rel_prop_with_rel_warns(self): + users, User = self.tables.users, self.classes.User + addresses, Address = self.tables.addresses, self.classes.Address + + m = mapper(User, users, properties={ + "addresses": relationship(Address) + }) + mapper(Address, addresses) + + assert_raises_message( + sa.exc.SAWarning, + "Property User.addresses on Mapper|User|users being replaced " + "with new property User.addresses; the old property will " + "be discarded", + m.add_property, + "addresses", relationship(Address) + ) + def test_add_column_prop_deannotate(self): User, users = self.classes.User, self.tables.users Address, addresses = self.classes.Address, self.tables.addresses diff --git a/test/orm/test_query.py b/test/orm/test_query.py index f418d2581..ac9c95f41 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -194,22 +194,33 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): Address = self.classes.Address self.assert_compile( - select([User]).where(User.id == Address.user_id). - correlate(Address), - "SELECT users.id, users.name FROM users " - "WHERE users.id = addresses.user_id" + select([User.name, Address.id, + select([func.count(Address.id)]).\ + where(User.id == Address.user_id).\ + correlate(User).as_scalar() + ]), + "SELECT users.name, addresses.id, " + "(SELECT count(addresses.id) AS count_1 " + "FROM addresses WHERE users.id = addresses.user_id) AS anon_1 " + "FROM users, addresses" ) def test_correlate_aliased_entity(self): User = self.classes.User Address = self.classes.Address - aa = aliased(Address, name="aa") + uu = aliased(User, name="uu") self.assert_compile( - select([User]).where(User.id == aa.user_id). - correlate(aa), - "SELECT users.id, users.name FROM users " - "WHERE users.id = aa.user_id" + select([uu.name, Address.id, + select([func.count(Address.id)]).\ + where(uu.id == Address.user_id).\ + correlate(uu).as_scalar() + ]), + # curious, "address.user_id = uu.id" is reversed here + "SELECT uu.name, addresses.id, " + "(SELECT count(addresses.id) AS count_1 " + "FROM addresses WHERE addresses.user_id = uu.id) AS anon_1 " + "FROM users AS uu, addresses" ) def test_columns_clause_entity(self): diff --git a/test/orm/test_rel_fn.py b/test/orm/test_rel_fn.py index bad3a0dd7..10ba41429 100644 --- a/test/orm/test_rel_fn.py +++ b/test/orm/test_rel_fn.py @@ -1,4 +1,4 @@ -from sqlalchemy.testing import assert_raises, assert_raises_message, eq_, \ +from sqlalchemy.testing import assert_raises_message, eq_, \ AssertsCompiledSQL, is_ from sqlalchemy.testing import fixtures from sqlalchemy.orm import relationships, foreign, remote @@ -119,9 +119,9 @@ class _JoinFixtures(object): support_sync=False, can_be_synced_fn=_can_sync, primaryjoin=and_( - self.three_tab_a.c.id==self.three_tab_b.c.aid, - self.three_tab_c.c.bid==self.three_tab_b.c.id, - self.three_tab_c.c.aid==self.three_tab_a.c.id + self.three_tab_a.c.id == self.three_tab_b.c.aid, + self.three_tab_c.c.bid == self.three_tab_b.c.id, + self.three_tab_c.c.aid == self.three_tab_a.c.id ) ) @@ -215,9 +215,9 @@ class _JoinFixtures(object): self.composite_selfref, self.composite_selfref, primaryjoin=and_( - self.composite_selfref.c.group_id== + self.composite_selfref.c.group_id == func.foo(self.composite_selfref.c.group_id), - self.composite_selfref.c.parent_id== + self.composite_selfref.c.parent_id == self.composite_selfref.c.id ), **kw @@ -230,9 +230,9 @@ class _JoinFixtures(object): self.composite_selfref, self.composite_selfref, primaryjoin=and_( - remote(self.composite_selfref.c.group_id)== + remote(self.composite_selfref.c.group_id) == func.foo(self.composite_selfref.c.group_id), - remote(self.composite_selfref.c.parent_id)== + remote(self.composite_selfref.c.parent_id) == self.composite_selfref.c.id ), **kw @@ -281,58 +281,60 @@ class _JoinFixtures(object): # see test/orm/inheritance/test_abc_inheritance:TestaTobM2O # and others there right = self.base_w_sub_rel.join(self.rel_sub, - self.base_w_sub_rel.c.id==self.rel_sub.c.id + self.base_w_sub_rel.c.id == self.rel_sub.c.id ) return relationships.JoinCondition( self.base_w_sub_rel, right, self.base_w_sub_rel, self.rel_sub, - primaryjoin=self.base_w_sub_rel.c.sub_id==\ + primaryjoin=self.base_w_sub_rel.c.sub_id == \ self.rel_sub.c.id, **kw ) def _join_fixture_o2m_joined_sub_to_base(self, **kw): left = self.base.join(self.sub_w_base_rel, - self.base.c.id==self.sub_w_base_rel.c.id) + self.base.c.id == self.sub_w_base_rel.c.id) return relationships.JoinCondition( left, self.base, self.sub_w_base_rel, self.base, - primaryjoin=self.sub_w_base_rel.c.base_id==self.base.c.id + primaryjoin=self.sub_w_base_rel.c.base_id == self.base.c.id ) def _join_fixture_m2o_joined_sub_to_sub_on_base(self, **kw): # this is a late add - a variant of the test case # in #2491 where we join on the base cols instead. only # m2o has a problem at the time of this test. - left = self.base.join(self.sub, self.base.c.id==self.sub.c.id) - right = self.base.join(self.sub_w_base_rel, self.base.c.id==self.sub_w_base_rel.c.id) + left = self.base.join(self.sub, self.base.c.id == self.sub.c.id) + right = self.base.join(self.sub_w_base_rel, + self.base.c.id == self.sub_w_base_rel.c.id) return relationships.JoinCondition( left, right, self.sub, self.sub_w_base_rel, - primaryjoin=self.sub_w_base_rel.c.base_id==self.base.c.id, + primaryjoin=self.sub_w_base_rel.c.base_id == self.base.c.id, ) def _join_fixture_o2m_joined_sub_to_sub(self, **kw): - left = self.base.join(self.sub, self.base.c.id==self.sub.c.id) - right = self.base.join(self.sub_w_sub_rel, self.base.c.id==self.sub_w_sub_rel.c.id) + left = self.base.join(self.sub, self.base.c.id == self.sub.c.id) + right = self.base.join(self.sub_w_sub_rel, + self.base.c.id == self.sub_w_sub_rel.c.id) return relationships.JoinCondition( left, right, self.sub, self.sub_w_sub_rel, - primaryjoin=self.sub.c.id==self.sub_w_sub_rel.c.sub_id + primaryjoin=self.sub.c.id == self.sub_w_sub_rel.c.sub_id ) def _join_fixture_m2o_sub_to_joined_sub(self, **kw): # see test.orm.test_mapper:MapperTest.test_add_column_prop_deannotate, right = self.base.join(self.right_w_base_rel, - self.base.c.id==self.right_w_base_rel.c.id) + self.base.c.id == self.right_w_base_rel.c.id) return relationships.JoinCondition( self.right_w_base_rel, right, @@ -343,19 +345,19 @@ class _JoinFixtures(object): def _join_fixture_m2o_sub_to_joined_sub_func(self, **kw): # see test.orm.test_mapper:MapperTest.test_add_column_prop_deannotate, right = self.base.join(self.right_w_base_rel, - self.base.c.id==self.right_w_base_rel.c.id) + self.base.c.id == self.right_w_base_rel.c.id) return relationships.JoinCondition( self.right_w_base_rel, right, self.right_w_base_rel, self.right_w_base_rel, - primaryjoin=self.right_w_base_rel.c.base_id==\ + primaryjoin=self.right_w_base_rel.c.base_id == \ func.foo(self.base.c.id) ) def _join_fixture_o2o_joined_sub_to_base(self, **kw): left = self.base.join(self.sub, - self.base.c.id==self.sub.c.id) + self.base.c.id == self.sub.c.id) # see test_relationships->AmbiguousJoinInterpretedAsSelfRef return relationships.JoinCondition( @@ -371,7 +373,7 @@ class _JoinFixtures(object): self.right, self.left, self.right, - primaryjoin=self.left.c.id== + primaryjoin=self.left.c.id == foreign(func.foo(self.right.c.lid)), **kw ) @@ -382,7 +384,7 @@ class _JoinFixtures(object): self.right, self.left, self.right, - primaryjoin=self.left.c.id== + primaryjoin=self.left.c.id == func.foo(self.right.c.lid), consider_as_foreign_keys=[self.right.c.lid], **kw @@ -399,7 +401,7 @@ class _JoinFixtures(object): ) def _assert_raises_no_relevant_fks(self, fn, expr, relname, - primary, *arg, **kw): + primary, *arg, **kw): assert_raises_message( exc.ArgumentError, r"Could not locate any relevant foreign key columns " @@ -414,9 +416,9 @@ class _JoinFixtures(object): ) def _assert_raises_no_equality(self, fn, expr, relname, - primary, *arg, **kw): + primary, *arg, **kw): assert_raises_message( - sa.exc.ArgumentError, + exc.ArgumentError, "Could not locate any simple equality expressions " "involving locally mapped foreign key columns for %s join " "condition '%s' on relationship %s. " @@ -431,7 +433,7 @@ class _JoinFixtures(object): ) def _assert_raises_ambig_join(self, fn, relname, secondary_arg, - *arg, **kw): + *arg, **kw): if secondary_arg is not None: assert_raises_message( exc.AmbiguousForeignKeysError, @@ -455,7 +457,7 @@ class _JoinFixtures(object): fn, *arg, **kw) def _assert_raises_no_join(self, fn, relname, secondary_arg, - *arg, **kw): + *arg, **kw): if secondary_arg is not None: assert_raises_message( exc.NoForeignKeysError, @@ -463,7 +465,8 @@ class _JoinFixtures(object): "parent/child tables on relationship %s - " "there are no foreign keys linking these tables " "via secondary table '%s'. " - "Ensure that referencing columns are associated with a ForeignKey " + "Ensure that referencing columns are associated " + "with a ForeignKey " "or ForeignKeyConstraint, or specify 'primaryjoin' and " "'secondaryjoin' expressions" % (relname, secondary_arg), @@ -474,14 +477,16 @@ class _JoinFixtures(object): "Could not determine join condition between " "parent/child tables on relationship %s - " "there are no foreign keys linking these tables. " - "Ensure that referencing columns are associated with a ForeignKey " + "Ensure that referencing columns are associated " + "with a ForeignKey " "or ForeignKeyConstraint, or specify a 'primaryjoin' " "expression." % (relname,), fn, *arg, **kw) -class ColumnCollectionsTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): +class ColumnCollectionsTest(_JoinFixtures, fixtures.TestBase, + AssertsCompiledSQL): def test_determine_local_remote_pairs_o2o_joined_sub_to_base(self): joincond = self._join_fixture_o2o_joined_sub_to_base() eq_( @@ -580,7 +585,7 @@ class ColumnCollectionsTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL ] ) - def test_determine_local_remote_compound_1(self): + def test_determine_local_remote_compound_3(self): joincond = self._join_fixture_compound_expression_1() eq_( joincond.local_remote_pairs, @@ -627,8 +632,10 @@ class ColumnCollectionsTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL eq_( joincond.local_remote_pairs, [ - (self.composite_selfref.c.group_id, self.composite_selfref.c.group_id), - (self.composite_selfref.c.id, self.composite_selfref.c.parent_id), + (self.composite_selfref.c.group_id, + self.composite_selfref.c.group_id), + (self.composite_selfref.c.id, + self.composite_selfref.c.parent_id), ] ) @@ -647,8 +654,10 @@ class ColumnCollectionsTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL eq_( joincond.local_remote_pairs, [ - (self.composite_selfref.c.group_id, self.composite_selfref.c.group_id), - (self.composite_selfref.c.id, self.composite_selfref.c.parent_id), + (self.composite_selfref.c.group_id, + self.composite_selfref.c.group_id), + (self.composite_selfref.c.id, + self.composite_selfref.c.parent_id), ] ) @@ -713,8 +722,8 @@ class ColumnCollectionsTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL eq_( j2.local_remote_pairs, [ - (self.m2mright.c.id, self.m2msecondary.c.rid), - (self.m2mleft.c.id, self.m2msecondary.c.lid), + (self.m2mright.c.id, self.m2msecondary.c.rid), + (self.m2mleft.c.id, self.m2msecondary.c.lid), ] ) @@ -997,19 +1006,22 @@ class AdaptedJoinTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): ) class LazyClauseTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = 'default' - def _test_lazy_clause_o2m(self): + def test_lazy_clause_o2m(self): joincond = self._join_fixture_o2m() + lazywhere, bind_to_col, equated_columns = joincond.create_lazy_clause() self.assert_compile( - relationships.create_lazy_clause(joincond), - "" + lazywhere, + ":param_1 = rgt.lid" ) - def _test_lazy_clause_o2m_reverse(self): + def test_lazy_clause_o2m_reverse(self): joincond = self._join_fixture_o2m() + lazywhere, bind_to_col, equated_columns =\ + joincond.create_lazy_clause(reverse_direction=True) self.assert_compile( - relationships.create_lazy_clause(joincond, - reverse_direction=True), - "" + lazywhere, + "lft.id = :param_1" ) diff --git a/test/orm/test_session.py b/test/orm/test_session.py index 5c8968842..7c2e8a3b8 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -857,6 +857,150 @@ class SessionStateWFixtureTest(_fixtures.FixtureTest): assert sa.orm.attributes.instance_state(a).session_id is None +class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): + """Test the instance_state._strong_obj link that it + is present only on persistent/pending objects and never + transient/detached. + + """ + run_inserts = None + + def setup(self): + mapper(self.classes.User, self.tables.users) + + def _assert_modified(self, u1): + assert sa.orm.attributes.instance_state(u1).modified + + def _assert_not_modified(self, u1): + assert not sa.orm.attributes.instance_state(u1).modified + + def _assert_cycle(self, u1): + assert sa.orm.attributes.instance_state(u1)._strong_obj is not None + + def _assert_no_cycle(self, u1): + assert sa.orm.attributes.instance_state(u1)._strong_obj is None + + def _persistent_fixture(self): + User = self.classes.User + u1 = User() + u1.name = "ed" + sess = Session() + sess.add(u1) + sess.flush() + return sess, u1 + + def test_transient(self): + User = self.classes.User + u1 = User() + u1.name = 'ed' + self._assert_no_cycle(u1) + self._assert_modified(u1) + + def test_transient_to_pending(self): + User = self.classes.User + u1 = User() + u1.name = 'ed' + self._assert_modified(u1) + self._assert_no_cycle(u1) + sess = Session() + sess.add(u1) + self._assert_cycle(u1) + sess.flush() + self._assert_no_cycle(u1) + self._assert_not_modified(u1) + + def test_dirty_persistent_to_detached_via_expunge(self): + sess, u1 = self._persistent_fixture() + u1.name = 'edchanged' + self._assert_cycle(u1) + sess.expunge(u1) + self._assert_no_cycle(u1) + + def test_dirty_persistent_to_detached_via_close(self): + sess, u1 = self._persistent_fixture() + u1.name = 'edchanged' + self._assert_cycle(u1) + sess.close() + self._assert_no_cycle(u1) + + def test_clean_persistent_to_detached_via_close(self): + sess, u1 = self._persistent_fixture() + self._assert_no_cycle(u1) + self._assert_not_modified(u1) + sess.close() + u1.name = 'edchanged' + self._assert_modified(u1) + self._assert_no_cycle(u1) + + def test_detached_to_dirty_deleted(self): + sess, u1 = self._persistent_fixture() + sess.expunge(u1) + u1.name = 'edchanged' + self._assert_no_cycle(u1) + sess.delete(u1) + self._assert_cycle(u1) + + def test_detached_to_dirty_persistent(self): + sess, u1 = self._persistent_fixture() + sess.expunge(u1) + u1.name = 'edchanged' + self._assert_modified(u1) + self._assert_no_cycle(u1) + sess.add(u1) + self._assert_cycle(u1) + self._assert_modified(u1) + + def test_detached_to_clean_persistent(self): + sess, u1 = self._persistent_fixture() + sess.expunge(u1) + self._assert_no_cycle(u1) + self._assert_not_modified(u1) + sess.add(u1) + self._assert_no_cycle(u1) + self._assert_not_modified(u1) + + def test_move_persistent_clean(self): + sess, u1 = self._persistent_fixture() + sess.close() + s2 = Session() + s2.add(u1) + self._assert_no_cycle(u1) + self._assert_not_modified(u1) + + def test_move_persistent_dirty(self): + sess, u1 = self._persistent_fixture() + u1.name = 'edchanged' + self._assert_cycle(u1) + self._assert_modified(u1) + sess.close() + self._assert_no_cycle(u1) + s2 = Session() + s2.add(u1) + self._assert_cycle(u1) + self._assert_modified(u1) + + @testing.requires.predictable_gc + def test_move_gc_session_persistent_dirty(self): + sess, u1 = self._persistent_fixture() + u1.name = 'edchanged' + self._assert_cycle(u1) + self._assert_modified(u1) + del sess + gc_collect() + self._assert_cycle(u1) + s2 = Session() + s2.add(u1) + self._assert_cycle(u1) + self._assert_modified(u1) + + def test_persistent_dirty_to_expired(self): + sess, u1 = self._persistent_fixture() + u1.name = 'edchanged' + self._assert_cycle(u1) + self._assert_modified(u1) + sess.expire(u1) + self._assert_no_cycle(u1) + self._assert_not_modified(u1) class WeakIdentityMapTest(_fixtures.FixtureTest): run_inserts = None diff --git a/test/orm/test_subquery_relations.py b/test/orm/test_subquery_relations.py index a4cc830ee..3ee94cae9 100644 --- a/test/orm/test_subquery_relations.py +++ b/test/orm/test_subquery_relations.py @@ -976,6 +976,166 @@ class OrderBySecondaryTest(fixtures.MappedTest): ]) self.assert_sql_count(testing.db, go, 2) + +from .inheritance._poly_fixtures import _Polymorphic, Person, Engineer, Paperwork + +class BaseRelationFromJoinedSubclassTest(_Polymorphic): + @classmethod + def define_tables(cls, metadata): + people = Table('people', metadata, + Column('person_id', Integer, + primary_key=True, + test_needs_autoincrement=True), + Column('name', String(50)), + Column('type', String(30))) + + # to test fully, PK of engineers table must be + # named differently from that of people + engineers = Table('engineers', metadata, + Column('engineer_id', Integer, + ForeignKey('people.person_id'), + primary_key=True), + Column('primary_language', String(50))) + + paperwork = Table('paperwork', metadata, + Column('paperwork_id', Integer, + primary_key=True, + test_needs_autoincrement=True), + Column('description', String(50)), + Column('person_id', Integer, + ForeignKey('people.person_id'))) + + @classmethod + def setup_mappers(cls): + people = cls.tables.people + engineers = cls.tables.engineers + paperwork = cls.tables.paperwork + + mapper(Person, people, + polymorphic_on=people.c.type, + polymorphic_identity='person', + properties={ + 'paperwork': relationship( + Paperwork, order_by=paperwork.c.paperwork_id)}) + + mapper(Engineer, engineers, + inherits=Person, + polymorphic_identity='engineer') + + mapper(Paperwork, paperwork) + + @classmethod + def insert_data(cls): + + e1 = Engineer(primary_language="java") + e2 = Engineer(primary_language="c++") + e1.paperwork = [Paperwork(description="tps report #1"), + Paperwork(description="tps report #2")] + e2.paperwork = [Paperwork(description="tps report #3")] + sess = create_session() + sess.add_all([e1, e2]) + sess.flush() + + def test_correct_subquery_nofrom(self): + sess = create_session() + # use Person.paperwork here just to give the least + # amount of context + q = sess.query(Engineer).\ + filter(Engineer.primary_language == 'java').\ + options(subqueryload(Person.paperwork)) + def go(): + eq_(q.all()[0].paperwork, + [Paperwork(description="tps report #1"), + Paperwork(description="tps report #2")], + + ) + self.assert_sql_execution( + testing.db, + go, + CompiledSQL( + "SELECT people.person_id AS people_person_id, " + "people.name AS people_name, people.type AS people_type, " + "engineers.engineer_id AS engineers_engineer_id, " + "engineers.primary_language AS engineers_primary_language " + "FROM people JOIN engineers ON " + "people.person_id = engineers.engineer_id " + "WHERE engineers.primary_language = :primary_language_1", + {"primary_language_1": "java"} + ), + # ensure we get "people JOIN engineer" here, even though + # primary key "people.person_id" is against "Person" + # *and* the path comes out as "Person.paperwork", still + # want to select from "Engineer" entity + CompiledSQL( + "SELECT paperwork.paperwork_id AS paperwork_paperwork_id, " + "paperwork.description AS paperwork_description, " + "paperwork.person_id AS paperwork_person_id, " + "anon_1.people_person_id AS anon_1_people_person_id " + "FROM (SELECT people.person_id AS people_person_id " + "FROM people JOIN engineers " + "ON people.person_id = engineers.engineer_id " + "WHERE engineers.primary_language = " + ":primary_language_1) AS anon_1 " + "JOIN paperwork " + "ON anon_1.people_person_id = paperwork.person_id " + "ORDER BY anon_1.people_person_id, paperwork.paperwork_id", + {"primary_language_1": "java"} + ) + ) + + def test_correct_subquery_existingfrom(self): + sess = create_session() + # use Person.paperwork here just to give the least + # amount of context + q = sess.query(Engineer).\ + filter(Engineer.primary_language == 'java').\ + join(Engineer.paperwork).\ + filter(Paperwork.description == "tps report #2").\ + options(subqueryload(Person.paperwork)) + def go(): + eq_(q.one().paperwork, + [Paperwork(description="tps report #1"), + Paperwork(description="tps report #2")], + + ) + self.assert_sql_execution( + testing.db, + go, + CompiledSQL( + "SELECT people.person_id AS people_person_id, " + "people.name AS people_name, people.type AS people_type, " + "engineers.engineer_id AS engineers_engineer_id, " + "engineers.primary_language AS engineers_primary_language " + "FROM people JOIN engineers " + "ON people.person_id = engineers.engineer_id " + "JOIN paperwork ON people.person_id = paperwork.person_id " + "WHERE engineers.primary_language = :primary_language_1 " + "AND paperwork.description = :description_1", + {"primary_language_1": "java", + "description_1": "tps report #2"} + ), + CompiledSQL( + "SELECT paperwork.paperwork_id AS paperwork_paperwork_id, " + "paperwork.description AS paperwork_description, " + "paperwork.person_id AS paperwork_person_id, " + "anon_1.people_person_id AS anon_1_people_person_id " + "FROM (SELECT people.person_id AS people_person_id " + "FROM people JOIN engineers ON people.person_id = " + "engineers.engineer_id JOIN paperwork " + "ON people.person_id = paperwork.person_id " + "WHERE engineers.primary_language = :primary_language_1 AND " + "paperwork.description = :description_1) AS anon_1 " + "JOIN paperwork ON anon_1.people_person_id = " + "paperwork.person_id " + "ORDER BY anon_1.people_person_id, paperwork.paperwork_id", + {"primary_language_1": "java", + "description_1": "tps report #2"} + ) + ) + + + + class SelfReferentialTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py index 7df6ecf91..64b05a131 100644 --- a/test/orm/test_transaction.py +++ b/test/orm/test_transaction.py @@ -358,18 +358,80 @@ class SessionTransactionTest(FixtureTest): sess.begin, subtransactions=True) sess.close() - def test_no_sql_during_prepare(self): + def test_no_sql_during_commit(self): sess = create_session(bind=testing.db, autocommit=False) @event.listens_for(sess, "after_commit") def go(session): session.execute("select 1") assert_raises_message(sa_exc.InvalidRequestError, - "This session is in 'prepared' state, where no " - "further SQL can be emitted until the " - "transaction is fully committed.", + "This session is in 'committed' state; no further " + "SQL can be emitted within this transaction.", sess.commit) + def test_no_sql_during_prepare(self): + sess = create_session(bind=testing.db, autocommit=False, twophase=True) + + sess.prepare() + + assert_raises_message(sa_exc.InvalidRequestError, + "This session is in 'prepared' state; no further " + "SQL can be emitted within this transaction.", + sess.execute, "select 1") + + def test_no_prepare_wo_twophase(self): + sess = create_session(bind=testing.db, autocommit=False) + + assert_raises_message(sa_exc.InvalidRequestError, + "'twophase' mode not enabled, or not root " + "transaction; can't prepare.", + sess.prepare) + + def test_closed_status_check(self): + sess = create_session() + trans = sess.begin() + trans.rollback() + assert_raises_message( + sa_exc.ResourceClosedError, + "This transaction is closed", + trans.rollback + ) + assert_raises_message( + sa_exc.ResourceClosedError, + "This transaction is closed", + trans.commit + ) + + def test_deactive_status_check(self): + sess = create_session() + trans = sess.begin() + trans2 = sess.begin(subtransactions=True) + trans2.rollback() + assert_raises_message( + sa_exc.InvalidRequestError, + "This Session's transaction has been rolled back by a nested " + "rollback\(\) call. To begin a new transaction, issue " + "Session.rollback\(\) first.", + trans.commit + ) + + def test_deactive_status_check_w_exception(self): + sess = create_session() + trans = sess.begin() + trans2 = sess.begin(subtransactions=True) + try: + raise Exception("test") + except: + trans2.rollback(_capture_exception=True) + assert_raises_message( + sa_exc.InvalidRequestError, + "This Session's transaction has been rolled back due to a " + "previous exception during flush. To begin a new transaction " + "with this Session, first issue Session.rollback\(\). " + "Original exception was: test", + trans.commit + ) + def _inactive_flushed_session_fixture(self): users, User = self.tables.users, self.classes.User diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py index 7fc728f1d..6be1672e1 100644 --- a/test/orm/test_unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -616,19 +616,25 @@ class ExtraPassiveDeletesTest(fixtures.MappedTest): def test_assertions(self): myothertable, MyOtherClass = self.tables.myothertable, self.classes.MyOtherClass + mytable, MyClass = self.tables.mytable, self.classes.MyClass + mapper(MyClass, mytable, properties={ + 'foo': relationship(MyOtherClass, + passive_deletes='all', + cascade="all") + }) mapper(MyOtherClass, myothertable) + assert_raises_message( sa.exc.ArgumentError, - "Can't set passive_deletes='all' in conjunction with 'delete' " + "On MyClass.foo, can't set passive_deletes='all' in conjunction with 'delete' " "or 'delete-orphan' cascade", - relationship, MyOtherClass, - passive_deletes='all', - cascade="all" + sa.orm.configure_mappers ) def test_extra_passive(self): - myothertable, MyClass, MyOtherClass, mytable = (self.tables.myothertable, + myothertable, MyClass, MyOtherClass, mytable = ( + self.tables.myothertable, self.classes.MyClass, self.classes.MyOtherClass, self.tables.mytable) diff --git a/test/perf/stress_all.py b/test/perf/stress_all.py index d17028530..890ef24a3 100644 --- a/test/perf/stress_all.py +++ b/test/perf/stress_all.py @@ -1,6 +1,6 @@ # -*- encoding: utf8 -*- from datetime import * -from sqlalchemy.util.compat import decimal +import decimal #from fastdec import mpd as Decimal from cPickle import dumps, loads diff --git a/test/profiles.txt b/test/profiles.txt index d83280c2c..d465fa3be 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -33,6 +33,10 @@ test.aaa_profiling.test_compiler.CompileTest.test_select 2.7_postgresql_psycopg2 test.aaa_profiling.test_compiler.CompileTest.test_select 2.7_sqlite_pysqlite_cextensions 135 test.aaa_profiling.test_compiler.CompileTest.test_select 2.7_sqlite_pysqlite_nocextensions 135 +# TEST: test.aaa_profiling.test_compiler.CompileTest.test_select_labels + +test.aaa_profiling.test_compiler.CompileTest.test_select_labels 2.7_sqlite_pysqlite_nocextensions 177 + # TEST: test.aaa_profiling.test_compiler.CompileTest.test_update test.aaa_profiling.test_compiler.CompileTest.test_update 2.5_sqlite_pysqlite_nocextensions 65 @@ -107,6 +111,7 @@ test.aaa_profiling.test_orm.MergeTest.test_merge_no_load 2.7_mysql_mysqldb_nocex test.aaa_profiling.test_orm.MergeTest.test_merge_no_load 2.7_postgresql_psycopg2_cextensions 122,18 test.aaa_profiling.test_orm.MergeTest.test_merge_no_load 2.7_postgresql_psycopg2_nocextensions 122,18 test.aaa_profiling.test_orm.MergeTest.test_merge_no_load 2.7_sqlite_pysqlite_cextensions 122,18 +test.aaa_profiling.test_orm.MergeTest.test_merge_no_load 2.7_sqlite_pysqlite_nocextensions 122,18 # TEST: test.aaa_profiling.test_pool.QueuePoolTest.test_first_connect @@ -116,6 +121,7 @@ test.aaa_profiling.test_pool.QueuePoolTest.test_first_connect 2.7_mysql_mysqldb_ test.aaa_profiling.test_pool.QueuePoolTest.test_first_connect 2.7_postgresql_psycopg2_cextensions 82 test.aaa_profiling.test_pool.QueuePoolTest.test_first_connect 2.7_postgresql_psycopg2_nocextensions 82 test.aaa_profiling.test_pool.QueuePoolTest.test_first_connect 2.7_sqlite_pysqlite_cextensions 82 +test.aaa_profiling.test_pool.QueuePoolTest.test_first_connect 2.7_sqlite_pysqlite_nocextensions 82 # TEST: test.aaa_profiling.test_pool.QueuePoolTest.test_second_connect diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 3b8aed23f..9cd893c1a 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -14,8 +14,8 @@ from sqlalchemy.testing import eq_, is_, assert_raises, assert_raises_message from sqlalchemy import testing from sqlalchemy.testing import fixtures, AssertsCompiledSQL from sqlalchemy import Integer, String, MetaData, Table, Column, select, \ - func, not_, cast, text, tuple_, exists, delete, update, bindparam,\ - insert, literal, and_, null, type_coerce, alias, or_, literal_column,\ + func, not_, cast, text, tuple_, exists, update, bindparam,\ + literal, and_, null, type_coerce, alias, or_, literal_column,\ Float, TIMESTAMP, Numeric, Date, Text, collate, union, except_,\ intersect, union_all, Boolean, distinct, join, outerjoin, asc, desc,\ over, subquery, case @@ -87,6 +87,7 @@ keyed = Table('keyed', metadata, Column('z', Integer), ) + class SelectTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = 'default' @@ -424,35 +425,6 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "AS z FROM keyed) AS anon_2) AS anon_1" ) - def test_dont_overcorrelate(self): - self.assert_compile(select([table1], from_obj=[table1, - table1.select()]), - "SELECT mytable.myid, mytable.name, " - "mytable.description FROM mytable, (SELECT " - "mytable.myid AS myid, mytable.name AS " - "name, mytable.description AS description " - "FROM mytable)") - - def test_full_correlate(self): - # intentional - t = table('t', column('a'), column('b')) - s = select([t.c.a]).where(t.c.a == 1).correlate(t).as_scalar() - - s2 = select([t.c.a, s]) - self.assert_compile(s2, - "SELECT t.a, (SELECT t.a WHERE t.a = :a_1) AS anon_1 FROM t") - - # unintentional - t2 = table('t2', column('c'), column('d')) - s = select([t.c.a]).where(t.c.a == t2.c.d).as_scalar() - s2 = select([t, t2, s]) - assert_raises(exc.InvalidRequestError, str, s2) - - # intentional again - s = s.correlate(t, t2) - s2 = select([t, t2, s]) - self.assert_compile(s, "SELECT t.a WHERE t.a = t2.d") - def test_exists(self): s = select([table1.c.myid]).where(table1.c.myid == 5) @@ -2239,14 +2211,14 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): assert_raises_message( exc.CompileError, - "Cannot compile Column object until it's 'name' is assigned.", + "Cannot compile Column object until its 'name' is assigned.", str, sel2 ) sel3 = select([my_str]).as_scalar() assert_raises_message( exc.CompileError, - "Cannot compile Column object until it's 'name' is assigned.", + "Cannot compile Column object until its 'name' is assigned.", str, sel3 ) @@ -2488,326 +2460,6 @@ class KwargPropagationTest(fixtures.TestBase): class CRUDTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = 'default' - def test_insert(self): - # generic insert, will create bind params for all columns - self.assert_compile(insert(table1), - "INSERT INTO mytable (myid, name, description) " - "VALUES (:myid, :name, :description)") - - # insert with user-supplied bind params for specific columns, - # cols provided literally - self.assert_compile( - insert(table1, { - table1.c.myid: bindparam('userid'), - table1.c.name: bindparam('username')}), - "INSERT INTO mytable (myid, name) VALUES (:userid, :username)") - - # insert with user-supplied bind params for specific columns, cols - # provided as strings - self.assert_compile( - insert(table1, dict(myid=3, name='jack')), - "INSERT INTO mytable (myid, name) VALUES (:myid, :name)" - ) - - # test with a tuple of params instead of named - self.assert_compile( - insert(table1, (3, 'jack', 'mydescription')), - "INSERT INTO mytable (myid, name, description) VALUES " - "(:myid, :name, :description)", - checkparams={ - 'myid': 3, 'name': 'jack', 'description': 'mydescription'} - ) - - self.assert_compile( - insert(table1, values={ - table1.c.myid: bindparam('userid') - }).values( - {table1.c.name: bindparam('username')}), - "INSERT INTO mytable (myid, name) VALUES (:userid, :username)" - ) - - self.assert_compile( - insert(table1, values=dict(myid=func.lala())), - "INSERT INTO mytable (myid) VALUES (lala())") - - def test_insert_prefix(self): - stmt = table1.insert().prefix_with("A", "B", dialect="mysql").\ - prefix_with("C", "D") - self.assert_compile(stmt, - "INSERT A B C D INTO mytable (myid, name, description) " - "VALUES (%s, %s, %s)", dialect=mysql.dialect() - ) - self.assert_compile(stmt, - "INSERT C D INTO mytable (myid, name, description) " - "VALUES (:myid, :name, :description)") - - def test_inline_default_insert(self): - metadata = MetaData() - table = Table('sometable', metadata, - Column('id', Integer, primary_key=True), - Column('foo', Integer, default=func.foobar())) - self.assert_compile( - table.insert(values={}, inline=True), - "INSERT INTO sometable (foo) VALUES (foobar())") - self.assert_compile( - table.insert(inline=True), - "INSERT INTO sometable (foo) VALUES (foobar())", params={}) - - def test_insert_returning_not_in_default(self): - stmt = table1.insert().returning(table1.c.myid) - assert_raises_message( - exc.CompileError, - "RETURNING is not supported by this dialect's statement compiler.", - stmt.compile - ) - - def test_empty_insert_default(self): - stmt = table1.insert().values({}) # hide from 2to3 - self.assert_compile(stmt, "INSERT INTO mytable () VALUES ()") - - def test_empty_insert_default_values(self): - stmt = table1.insert().values({}) # hide from 2to3 - dialect = default.DefaultDialect() - dialect.supports_empty_insert = dialect.supports_default_values = True - self.assert_compile(stmt, "INSERT INTO mytable DEFAULT VALUES", - dialect=dialect) - - def test_empty_insert_not_supported(self): - stmt = table1.insert().values({}) # hide from 2to3 - dialect = default.DefaultDialect() - dialect.supports_empty_insert = dialect.supports_default_values = False - assert_raises_message( - exc.CompileError, - "The 'default' dialect with current database version " - "settings does not support empty inserts.", - stmt.compile, dialect=dialect - ) - - def test_multivalues_insert_not_supported(self): - stmt = table1.insert().values([{"myid": 1}, {"myid": 2}]) - dialect = default.DefaultDialect() - assert_raises_message( - exc.CompileError, - "The 'default' dialect with current database version settings " - "does not support in-place multirow inserts.", - stmt.compile, dialect=dialect - ) - - def test_multivalues_insert_named(self): - stmt = table1.insert().\ - values([{"myid": 1, "name": 'a', "description": 'b'}, - {"myid": 2, "name": 'c', "description": 'd'}, - {"myid": 3, "name": 'e', "description": 'f'} - ]) - - result = "INSERT INTO mytable (myid, name, description) VALUES " \ - "(:myid_0, :name_0, :description_0), " \ - "(:myid_1, :name_1, :description_1), " \ - "(:myid_2, :name_2, :description_2)" - - dialect = default.DefaultDialect() - dialect.supports_multivalues_insert = True - self.assert_compile(stmt, result, - checkparams={ - 'description_2': 'f', 'name_2': 'e', - 'name_0': 'a', 'name_1': 'c', 'myid_2': 3, - 'description_0': 'b', 'myid_0': 1, - 'myid_1': 2, 'description_1': 'd' - }, - dialect=dialect) - - def test_multivalues_insert_positional(self): - stmt = table1.insert().\ - values([{"myid": 1, "name": 'a', "description": 'b'}, - {"myid": 2, "name": 'c', "description": 'd'}, - {"myid": 3, "name": 'e', "description": 'f'} - ]) - - result = "INSERT INTO mytable (myid, name, description) VALUES " \ - "(%s, %s, %s), " \ - "(%s, %s, %s), " \ - "(%s, %s, %s)" \ - - dialect = default.DefaultDialect() - dialect.supports_multivalues_insert = True - dialect.paramstyle = "format" - dialect.positional = True - self.assert_compile(stmt, result, - checkpositional=(1, 'a', 'b', 2, 'c', 'd', 3, 'e', 'f'), - dialect=dialect) - - def test_multirow_inline_default_insert(self): - metadata = MetaData() - table = Table('sometable', metadata, - Column('id', Integer, primary_key=True), - Column('data', String), - Column('foo', Integer, default=func.foobar())) - - stmt = table.insert().\ - values([ - {"id": 1, "data": "data1"}, - {"id": 2, "data": "data2", "foo": "plainfoo"}, - {"id": 3, "data": "data3"}, - ]) - result = "INSERT INTO sometable (id, data, foo) VALUES "\ - "(%(id_0)s, %(data_0)s, foobar()), "\ - "(%(id_1)s, %(data_1)s, %(foo_1)s), "\ - "(%(id_2)s, %(data_2)s, foobar())" - - self.assert_compile(stmt, result, - checkparams={'data_2': 'data3', 'id_0': 1, 'id_2': 3, - 'foo_1': 'plainfoo', 'data_1': 'data2', - 'id_1': 2, 'data_0': 'data1'}, - dialect=postgresql.dialect()) - - def test_multirow_server_default_insert(self): - metadata = MetaData() - table = Table('sometable', metadata, - Column('id', Integer, primary_key=True), - Column('data', String), - Column('foo', Integer, server_default=func.foobar())) - - stmt = table.insert().\ - values([ - {"id": 1, "data": "data1"}, - {"id": 2, "data": "data2", "foo": "plainfoo"}, - {"id": 3, "data": "data3"}, - ]) - result = "INSERT INTO sometable (id, data) VALUES "\ - "(%(id_0)s, %(data_0)s), "\ - "(%(id_1)s, %(data_1)s), "\ - "(%(id_2)s, %(data_2)s)" - - self.assert_compile(stmt, result, - checkparams={'data_2': 'data3', 'id_0': 1, 'id_2': 3, - 'data_1': 'data2', - 'id_1': 2, 'data_0': 'data1'}, - dialect=postgresql.dialect()) - - stmt = table.insert().\ - values([ - {"id": 1, "data": "data1", "foo": "plainfoo"}, - {"id": 2, "data": "data2"}, - {"id": 3, "data": "data3", "foo": "otherfoo"}, - ]) - - # note the effect here is that the first set of params - # takes effect for the rest of them, when one is absent - result = "INSERT INTO sometable (id, data, foo) VALUES "\ - "(%(id_0)s, %(data_0)s, %(foo_0)s), "\ - "(%(id_1)s, %(data_1)s, %(foo_0)s), "\ - "(%(id_2)s, %(data_2)s, %(foo_2)s)" - - self.assert_compile(stmt, result, - checkparams={'data_2': 'data3', 'id_0': 1, 'id_2': 3, - 'data_1': 'data2', - "foo_0": "plainfoo", - "foo_2": "otherfoo", - 'id_1': 2, 'data_0': 'data1'}, - dialect=postgresql.dialect()) - - def test_update(self): - self.assert_compile( - update(table1, table1.c.myid == 7), - "UPDATE mytable SET name=:name WHERE mytable.myid = :myid_1", - params={table1.c.name: 'fred'}) - self.assert_compile( - table1.update().where(table1.c.myid == 7). - values({table1.c.myid: 5}), - "UPDATE mytable SET myid=:myid WHERE mytable.myid = :myid_1", - checkparams={'myid': 5, 'myid_1': 7}) - self.assert_compile( - update(table1, table1.c.myid == 7), - "UPDATE mytable SET name=:name WHERE mytable.myid = :myid_1", - params={'name': 'fred'}) - self.assert_compile( - update(table1, values={table1.c.name: table1.c.myid}), - "UPDATE mytable SET name=mytable.myid") - self.assert_compile( - update(table1, - whereclause=table1.c.name == bindparam('crit'), - values={table1.c.name: 'hi'}), - "UPDATE mytable SET name=:name WHERE mytable.name = :crit", - params={'crit': 'notthere'}, - checkparams={'crit': 'notthere', 'name': 'hi'}) - self.assert_compile( - update(table1, table1.c.myid == 12, - values={table1.c.name: table1.c.myid}), - "UPDATE mytable SET name=mytable.myid, description=" - ":description WHERE mytable.myid = :myid_1", - params={'description': 'test'}, - checkparams={'description': 'test', 'myid_1': 12}) - self.assert_compile( - update(table1, table1.c.myid == 12, - values={table1.c.myid: 9}), - "UPDATE mytable SET myid=:myid, description=:description " - "WHERE mytable.myid = :myid_1", - params={'myid_1': 12, 'myid': 9, 'description': 'test'}) - self.assert_compile( - update(table1, table1.c.myid == 12), - "UPDATE mytable SET myid=:myid WHERE mytable.myid = :myid_1", - params={'myid': 18}, checkparams={'myid': 18, 'myid_1': 12}) - s = table1.update(table1.c.myid == 12, values={table1.c.name: 'lala'}) - c = s.compile(column_keys=['id', 'name']) - self.assert_compile( - update(table1, table1.c.myid == 12, - values={table1.c.name: table1.c.myid} - ).values({table1.c.name: table1.c.name + 'foo'}), - "UPDATE mytable SET name=(mytable.name || :name_1), " - "description=:description WHERE mytable.myid = :myid_1", - params={'description': 'test'}) - eq_(str(s), str(c)) - - self.assert_compile(update(table1, - (table1.c.myid == func.hoho(4)) & - (table1.c.name == literal('foo') + - table1.c.name + literal('lala')), - values={ - table1.c.name: table1.c.name + "lala", - table1.c.myid: func.do_stuff(table1.c.myid, literal('hoho')) - }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :param_1), " - "name=(mytable.name || :name_1) " - "WHERE mytable.myid = hoho(:hoho_1) " - "AND mytable.name = :param_2 || " - "mytable.name || :param_3") - - def test_update_prefix(self): - stmt = table1.update().prefix_with("A", "B", dialect="mysql").\ - prefix_with("C", "D") - self.assert_compile(stmt, - "UPDATE A B C D mytable SET myid=%s, name=%s, description=%s", - dialect=mysql.dialect() - ) - self.assert_compile(stmt, - "UPDATE C D mytable SET myid=:myid, name=:name, " - "description=:description") - - def test_aliased_update(self): - talias1 = table1.alias('t1') - self.assert_compile( - update(talias1, talias1.c.myid == 7), - "UPDATE mytable AS t1 SET name=:name WHERE t1.myid = :myid_1", - params={table1.c.name: 'fred'}) - self.assert_compile( - update(talias1, table1.c.myid == 7), - "UPDATE mytable AS t1 SET name=:name FROM " - "mytable WHERE mytable.myid = :myid_1", - params={table1.c.name: 'fred'}) - - def test_update_to_expression(self): - """test update from an expression. - - this logic is triggered currently by a left side that doesn't - have a key. The current supported use case is updating the index - of a Postgresql ARRAY type. - - """ - expr = func.foo(table1.c.myid) - assert not hasattr(expr, "key") - self.assert_compile( - table1.update().values({expr: 'bar'}), - "UPDATE mytable SET foo(myid)=:param_1" - ) def test_correlated_update(self): # test against a straight text subquery @@ -2880,51 +2532,6 @@ class CRUDTest(fixtures.TestBase, AssertsCompiledSQL): "AND myothertable.othername = mytable_1.name", dialect=mssql.dialect()) - def test_delete(self): - self.assert_compile( - delete(table1, table1.c.myid == 7), - "DELETE FROM mytable WHERE mytable.myid = :myid_1") - self.assert_compile( - table1.delete().where(table1.c.myid == 7), - "DELETE FROM mytable WHERE mytable.myid = :myid_1") - self.assert_compile( - table1.delete().where(table1.c.myid == 7).\ - where(table1.c.name == 'somename'), - "DELETE FROM mytable WHERE mytable.myid = :myid_1 " - "AND mytable.name = :name_1") - - def test_delete_prefix(self): - stmt = table1.delete().prefix_with("A", "B", dialect="mysql").\ - prefix_with("C", "D") - self.assert_compile(stmt, - "DELETE A B C D FROM mytable", - dialect=mysql.dialect() - ) - self.assert_compile(stmt, - "DELETE C D FROM mytable") - - def test_aliased_delete(self): - talias1 = table1.alias('t1') - self.assert_compile( - delete(talias1).where(talias1.c.myid == 7), - "DELETE FROM mytable AS t1 WHERE t1.myid = :myid_1") - - def test_correlated_delete(self): - # test a non-correlated WHERE clause - s = select([table2.c.othername], table2.c.otherid == 7) - u = delete(table1, table1.c.name == s) - self.assert_compile(u, "DELETE FROM mytable WHERE mytable.name = " - "(SELECT myothertable.othername FROM myothertable " - "WHERE myothertable.otherid = :otherid_1)") - - # test one that is actually correlated... - s = select([table2.c.othername], table2.c.otherid == table1.c.myid) - u = table1.delete(table1.c.name == s) - self.assert_compile(u, - "DELETE FROM mytable WHERE mytable.name = (SELECT " - "myothertable.othername FROM myothertable WHERE " - "myothertable.otherid = mytable.myid)") - def test_binds_that_match_columns(self): """test bind params named after column names replace the normal SET/VALUES generation.""" @@ -3189,6 +2796,246 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL): "(:rem_id, :datatype_id, :value)") +class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = 'default' + + def test_dont_overcorrelate(self): + self.assert_compile(select([table1], from_obj=[table1, + table1.select()]), + "SELECT mytable.myid, mytable.name, " + "mytable.description FROM mytable, (SELECT " + "mytable.myid AS myid, mytable.name AS " + "name, mytable.description AS description " + "FROM mytable)") + + def _fixture(self): + t1 = table('t1', column('a')) + t2 = table('t2', column('a')) + return t1, t2, select([t1]).where(t1.c.a == t2.c.a) + + def _assert_where_correlated(self, stmt): + self.assert_compile( + stmt, + "SELECT t2.a FROM t2 WHERE t2.a = " + "(SELECT t1.a FROM t1 WHERE t1.a = t2.a)") + + def _assert_where_all_correlated(self, stmt): + self.assert_compile( + stmt, + "SELECT t1.a, t2.a FROM t1, t2 WHERE t2.a = " + "(SELECT t1.a WHERE t1.a = t2.a)") + + def _assert_where_backwards_correlated(self, stmt): + self.assert_compile( + stmt, + "SELECT t2.a FROM t2 WHERE t2.a = " + "(SELECT t1.a FROM t2 WHERE t1.a = t2.a)") + + def _assert_column_correlated(self, stmt): + self.assert_compile(stmt, + "SELECT t2.a, (SELECT t1.a FROM t1 WHERE t1.a = t2.a) " + "AS anon_1 FROM t2") + + def _assert_column_all_correlated(self, stmt): + self.assert_compile(stmt, + "SELECT t1.a, t2.a, " + "(SELECT t1.a WHERE t1.a = t2.a) AS anon_1 FROM t1, t2") + + def _assert_column_backwards_correlated(self, stmt): + self.assert_compile(stmt, + "SELECT t2.a, (SELECT t1.a FROM t2 WHERE t1.a = t2.a) " + "AS anon_1 FROM t2") + + def _assert_having_correlated(self, stmt): + self.assert_compile(stmt, + "SELECT t2.a FROM t2 HAVING t2.a = " + "(SELECT t1.a FROM t1 WHERE t1.a = t2.a)") + + def _assert_from_uncorrelated(self, stmt): + self.assert_compile(stmt, + "SELECT t2.a, anon_1.a FROM t2, " + "(SELECT t1.a AS a FROM t1, t2 WHERE t1.a = t2.a) AS anon_1") + + def _assert_from_all_uncorrelated(self, stmt): + self.assert_compile(stmt, + "SELECT t1.a, t2.a, anon_1.a FROM t1, t2, " + "(SELECT t1.a AS a FROM t1, t2 WHERE t1.a = t2.a) AS anon_1") + + def _assert_where_uncorrelated(self, stmt): + self.assert_compile(stmt, + "SELECT t2.a FROM t2 WHERE t2.a = " + "(SELECT t1.a FROM t1, t2 WHERE t1.a = t2.a)") + + def _assert_column_uncorrelated(self, stmt): + self.assert_compile(stmt, + "SELECT t2.a, (SELECT t1.a FROM t1, t2 " + "WHERE t1.a = t2.a) AS anon_1 FROM t2") + + def _assert_having_uncorrelated(self, stmt): + self.assert_compile(stmt, + "SELECT t2.a FROM t2 HAVING t2.a = " + "(SELECT t1.a FROM t1, t2 WHERE t1.a = t2.a)") + + def _assert_where_single_full_correlated(self, stmt): + self.assert_compile(stmt, + "SELECT t1.a FROM t1 WHERE t1.a = (SELECT t1.a)") + + def test_correlate_semiauto_where(self): + t1, t2, s1 = self._fixture() + self._assert_where_correlated( + select([t2]).where(t2.c.a == s1.correlate(t2))) + + def test_correlate_semiauto_column(self): + t1, t2, s1 = self._fixture() + self._assert_column_correlated( + select([t2, s1.correlate(t2).as_scalar()])) + + def test_correlate_semiauto_from(self): + t1, t2, s1 = self._fixture() + self._assert_from_uncorrelated( + select([t2, s1.correlate(t2).alias()])) + + def test_correlate_semiauto_having(self): + t1, t2, s1 = self._fixture() + self._assert_having_correlated( + select([t2]).having(t2.c.a == s1.correlate(t2))) + + def test_correlate_except_inclusion_where(self): + t1, t2, s1 = self._fixture() + self._assert_where_correlated( + select([t2]).where(t2.c.a == s1.correlate_except(t1))) + + def test_correlate_except_exclusion_where(self): + t1, t2, s1 = self._fixture() + self._assert_where_backwards_correlated( + select([t2]).where(t2.c.a == s1.correlate_except(t2))) + + def test_correlate_except_inclusion_column(self): + t1, t2, s1 = self._fixture() + self._assert_column_correlated( + select([t2, s1.correlate_except(t1).as_scalar()])) + + def test_correlate_except_exclusion_column(self): + t1, t2, s1 = self._fixture() + self._assert_column_backwards_correlated( + select([t2, s1.correlate_except(t2).as_scalar()])) + + def test_correlate_except_inclusion_from(self): + t1, t2, s1 = self._fixture() + self._assert_from_uncorrelated( + select([t2, s1.correlate_except(t1).alias()])) + + def test_correlate_except_exclusion_from(self): + t1, t2, s1 = self._fixture() + self._assert_from_uncorrelated( + select([t2, s1.correlate_except(t2).alias()])) + + def test_correlate_except_having(self): + t1, t2, s1 = self._fixture() + self._assert_having_correlated( + select([t2]).having(t2.c.a == s1.correlate_except(t1))) + + def test_correlate_auto_where(self): + t1, t2, s1 = self._fixture() + self._assert_where_correlated( + select([t2]).where(t2.c.a == s1)) + + def test_correlate_auto_column(self): + t1, t2, s1 = self._fixture() + self._assert_column_correlated( + select([t2, s1.as_scalar()])) + + def test_correlate_auto_from(self): + t1, t2, s1 = self._fixture() + self._assert_from_uncorrelated( + select([t2, s1.alias()])) + + def test_correlate_auto_having(self): + t1, t2, s1 = self._fixture() + self._assert_having_correlated( + select([t2]).having(t2.c.a == s1)) + + def test_correlate_disabled_where(self): + t1, t2, s1 = self._fixture() + self._assert_where_uncorrelated( + select([t2]).where(t2.c.a == s1.correlate(None))) + + def test_correlate_disabled_column(self): + t1, t2, s1 = self._fixture() + self._assert_column_uncorrelated( + select([t2, s1.correlate(None).as_scalar()])) + + def test_correlate_disabled_from(self): + t1, t2, s1 = self._fixture() + self._assert_from_uncorrelated( + select([t2, s1.correlate(None).alias()])) + + def test_correlate_disabled_having(self): + t1, t2, s1 = self._fixture() + self._assert_having_uncorrelated( + select([t2]).having(t2.c.a == s1.correlate(None))) + + def test_correlate_all_where(self): + t1, t2, s1 = self._fixture() + self._assert_where_all_correlated( + select([t1, t2]).where(t2.c.a == s1.correlate(t1, t2))) + + def test_correlate_all_column(self): + t1, t2, s1 = self._fixture() + self._assert_column_all_correlated( + select([t1, t2, s1.correlate(t1, t2).as_scalar()])) + + def test_correlate_all_from(self): + t1, t2, s1 = self._fixture() + self._assert_from_all_uncorrelated( + select([t1, t2, s1.correlate(t1, t2).alias()])) + + def test_correlate_where_all_unintentional(self): + t1, t2, s1 = self._fixture() + assert_raises_message( + exc.InvalidRequestError, + "returned no FROM clauses due to auto-correlation", + select([t1, t2]).where(t2.c.a == s1).compile + ) + + def test_correlate_from_all_ok(self): + t1, t2, s1 = self._fixture() + self.assert_compile( + select([t1, t2, s1]), + "SELECT t1.a, t2.a, a FROM t1, t2, " + "(SELECT t1.a AS a FROM t1, t2 WHERE t1.a = t2.a)" + ) + + def test_correlate_auto_where_singlefrom(self): + t1, t2, s1 = self._fixture() + s = select([t1.c.a]) + s2 = select([t1]).where(t1.c.a == s) + self.assert_compile(s2, + "SELECT t1.a FROM t1 WHERE t1.a = " + "(SELECT t1.a FROM t1)") + + def test_correlate_semiauto_where_singlefrom(self): + t1, t2, s1 = self._fixture() + + s = select([t1.c.a]) + + s2 = select([t1]).where(t1.c.a == s.correlate(t1)) + self._assert_where_single_full_correlated(s2) + + def test_correlate_except_semiauto_where_singlefrom(self): + t1, t2, s1 = self._fixture() + + s = select([t1.c.a]) + + s2 = select([t1]).where(t1.c.a == s.correlate_except(t2)) + self._assert_where_single_full_correlated(s2) + + def test_correlate_alone_noeffect(self): + # new as of #2668 + t1, t2, s1 = self._fixture() + self.assert_compile(s1.correlate(t1, t2), + "SELECT t1.a FROM t1, t2 WHERE t1.a = t2.a") + class CoercionTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = 'default' @@ -3315,4 +3162,4 @@ class ResultMapTest(fixtures.TestBase): ) is_( comp.result_map['t1_a'][1][2], t1.c.a - )
\ No newline at end of file + ) diff --git a/test/sql/test_constraints.py b/test/sql/test_constraints.py index ab294e1eb..026095c3b 100644 --- a/test/sql/test_constraints.py +++ b/test/sql/test_constraints.py @@ -7,6 +7,7 @@ from sqlalchemy import testing from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing.assertsql import AllOf, RegexSQL, ExactSQL, CompiledSQL +from sqlalchemy.sql import table, column class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): __dialect__ = 'default' @@ -753,6 +754,18 @@ class ConstraintAPITest(fixtures.TestBase): c = Index('foo', t.c.a) assert c in t.indexes + def test_auto_append_lowercase_table(self): + t = table('t', column('a')) + t2 = table('t2', column('a')) + for c in ( + UniqueConstraint(t.c.a), + CheckConstraint(t.c.a > 5), + ForeignKeyConstraint([t.c.a], [t2.c.a]), + PrimaryKeyConstraint(t.c.a), + Index('foo', t.c.a) + ): + assert True + def test_tometadata_ok(self): m = MetaData() diff --git a/test/sql/test_delete.py b/test/sql/test_delete.py new file mode 100644 index 000000000..b56731515 --- /dev/null +++ b/test/sql/test_delete.py @@ -0,0 +1,86 @@ +#! coding:utf-8 + +from sqlalchemy import Column, Integer, String, Table, delete, select +from sqlalchemy.dialects import mysql +from sqlalchemy.testing import AssertsCompiledSQL, fixtures + + +class _DeleteTestBase(object): + @classmethod + def define_tables(cls, metadata): + Table('mytable', metadata, + Column('myid', Integer), + Column('name', String(30)), + Column('description', String(50))) + Table('myothertable', metadata, + Column('otherid', Integer), + Column('othername', String(30))) + + +class DeleteTest(_DeleteTestBase, fixtures.TablesTest, AssertsCompiledSQL): + __dialect__ = 'default' + + def test_delete(self): + table1 = self.tables.mytable + + self.assert_compile( + delete(table1, table1.c.myid == 7), + 'DELETE FROM mytable WHERE mytable.myid = :myid_1') + + self.assert_compile( + table1.delete().where(table1.c.myid == 7), + 'DELETE FROM mytable WHERE mytable.myid = :myid_1') + + self.assert_compile( + table1.delete(). + where(table1.c.myid == 7). + where(table1.c.name == 'somename'), + 'DELETE FROM mytable ' + 'WHERE mytable.myid = :myid_1 ' + 'AND mytable.name = :name_1') + + def test_prefix_with(self): + table1 = self.tables.mytable + + stmt = table1.delete().\ + prefix_with('A', 'B', dialect='mysql').\ + prefix_with('C', 'D') + + self.assert_compile(stmt, + 'DELETE C D FROM mytable') + + self.assert_compile(stmt, + 'DELETE A B C D FROM mytable', + dialect=mysql.dialect()) + + def test_alias(self): + table1 = self.tables.mytable + + talias1 = table1.alias('t1') + stmt = delete(talias1).where(talias1.c.myid == 7) + + self.assert_compile(stmt, + 'DELETE FROM mytable AS t1 WHERE t1.myid = :myid_1') + + def test_correlated(self): + table1, table2 = self.tables.mytable, self.tables.myothertable + + # test a non-correlated WHERE clause + s = select([table2.c.othername], table2.c.otherid == 7) + self.assert_compile(delete(table1, table1.c.name == s), + 'DELETE FROM mytable ' + 'WHERE mytable.name = (' + 'SELECT myothertable.othername ' + 'FROM myothertable ' + 'WHERE myothertable.otherid = :otherid_1' + ')') + + # test one that is actually correlated... + s = select([table2.c.othername], table2.c.otherid == table1.c.myid) + self.assert_compile(table1.delete(table1.c.name == s), + 'DELETE FROM mytable ' + 'WHERE mytable.name = (' + 'SELECT myothertable.othername ' + 'FROM myothertable ' + 'WHERE myothertable.otherid = mytable.myid' + ')') diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index ae8e28e24..b325b7763 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -8,7 +8,7 @@ from sqlalchemy.testing.engines import all_dialects from sqlalchemy import types as sqltypes from sqlalchemy.sql import functions from sqlalchemy.sql.functions import GenericFunction -from sqlalchemy.util.compat import decimal +import decimal from sqlalchemy import testing from sqlalchemy.testing import fixtures, AssertsCompiledSQL, engines from sqlalchemy.dialects import sqlite, postgresql, mysql, oracle diff --git a/test/sql/test_generative.py b/test/sql/test_generative.py index e868cbe88..8b2abef0e 100644 --- a/test/sql/test_generative.py +++ b/test/sql/test_generative.py @@ -590,13 +590,18 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): def test_correlated_select(self): s = select(['*'], t1.c.col1 == t2.c.col1, from_obj=[t1, t2]).correlate(t2) + class Vis(CloningVisitor): def visit_select(self, select): select.append_whereclause(t1.c.col2 == 7) - self.assert_compile(Vis().traverse(s), - "SELECT * FROM table1 WHERE table1.col1 = table2.col1 " - "AND table1.col2 = :col2_1") + self.assert_compile( + select([t2]).where(t2.c.col1 == Vis().traverse(s)), + "SELECT table2.col1, table2.col2, table2.col3 " + "FROM table2 WHERE table2.col1 = " + "(SELECT * FROM table1 WHERE table1.col1 = table2.col1 " + "AND table1.col2 = :col2_1)" + ) def test_this_thing(self): s = select([t1]).where(t1.c.col1 == 'foo').alias() @@ -616,35 +621,49 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): 'AS table1_1 WHERE table1_1.col1 = ' ':col1_1) AS anon_1') - def test_select_fromtwice(self): + def test_select_fromtwice_one(self): t1a = t1.alias() - s = select([1], t1.c.col1 == t1a.c.col1, from_obj=t1a).correlate(t1) + s = select([1], t1.c.col1 == t1a.c.col1, from_obj=t1a).correlate(t1a) + s = select([t1]).where(t1.c.col1 == s) self.assert_compile(s, - 'SELECT 1 FROM table1 AS table1_1 WHERE ' - 'table1.col1 = table1_1.col1') - + "SELECT table1.col1, table1.col2, table1.col3 FROM table1 " + "WHERE table1.col1 = " + "(SELECT 1 FROM table1, table1 AS table1_1 " + "WHERE table1.col1 = table1_1.col1)" + ) s = CloningVisitor().traverse(s) self.assert_compile(s, - 'SELECT 1 FROM table1 AS table1_1 WHERE ' - 'table1.col1 = table1_1.col1') + "SELECT table1.col1, table1.col2, table1.col3 FROM table1 " + "WHERE table1.col1 = " + "(SELECT 1 FROM table1, table1 AS table1_1 " + "WHERE table1.col1 = table1_1.col1)") + def test_select_fromtwice_two(self): s = select([t1]).where(t1.c.col1 == 'foo').alias() s2 = select([1], t1.c.col1 == s.c.col1, from_obj=s).correlate(t1) - self.assert_compile(s2, - 'SELECT 1 FROM (SELECT table1.col1 AS ' - 'col1, table1.col2 AS col2, table1.col3 AS ' - 'col3 FROM table1 WHERE table1.col1 = ' - ':col1_1) AS anon_1 WHERE table1.col1 = ' - 'anon_1.col1') - s2 = ReplacingCloningVisitor().traverse(s2) - self.assert_compile(s2, - 'SELECT 1 FROM (SELECT table1.col1 AS ' - 'col1, table1.col2 AS col2, table1.col3 AS ' - 'col3 FROM table1 WHERE table1.col1 = ' - ':col1_1) AS anon_1 WHERE table1.col1 = ' - 'anon_1.col1') + s3 = select([t1]).where(t1.c.col1 == s2) + self.assert_compile(s3, + "SELECT table1.col1, table1.col2, table1.col3 " + "FROM table1 WHERE table1.col1 = " + "(SELECT 1 FROM " + "(SELECT table1.col1 AS col1, table1.col2 AS col2, " + "table1.col3 AS col3 FROM table1 " + "WHERE table1.col1 = :col1_1) " + "AS anon_1 WHERE table1.col1 = anon_1.col1)" + ) + + s4 = ReplacingCloningVisitor().traverse(s3) + self.assert_compile(s4, + "SELECT table1.col1, table1.col2, table1.col3 " + "FROM table1 WHERE table1.col1 = " + "(SELECT 1 FROM " + "(SELECT table1.col1 AS col1, table1.col2 AS col2, " + "table1.col3 AS col3 FROM table1 " + "WHERE table1.col1 = :col1_1) " + "AS anon_1 WHERE table1.col1 = anon_1.col1)" + ) class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = 'default' @@ -763,67 +782,125 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): 'FROM addresses WHERE users_1.id = ' 'addresses.user_id') - def test_table_to_alias(self): - + def test_table_to_alias_1(self): t1alias = t1.alias('t1alias') vis = sql_util.ClauseAdapter(t1alias) ff = vis.traverse(func.count(t1.c.col1).label('foo')) assert list(_from_objects(ff)) == [t1alias] + def test_table_to_alias_2(self): + t1alias = t1.alias('t1alias') + vis = sql_util.ClauseAdapter(t1alias) self.assert_compile(vis.traverse(select(['*'], from_obj=[t1])), 'SELECT * FROM table1 AS t1alias') + + def test_table_to_alias_3(self): + t1alias = t1.alias('t1alias') + vis = sql_util.ClauseAdapter(t1alias) self.assert_compile(select(['*'], t1.c.col1 == t2.c.col2), 'SELECT * FROM table1, table2 WHERE ' 'table1.col1 = table2.col2') + + def test_table_to_alias_4(self): + t1alias = t1.alias('t1alias') + vis = sql_util.ClauseAdapter(t1alias) self.assert_compile(vis.traverse(select(['*'], t1.c.col1 == t2.c.col2)), 'SELECT * FROM table1 AS t1alias, table2 ' 'WHERE t1alias.col1 = table2.col2') + + def test_table_to_alias_5(self): + t1alias = t1.alias('t1alias') + vis = sql_util.ClauseAdapter(t1alias) self.assert_compile(vis.traverse(select(['*'], t1.c.col1 == t2.c.col2, from_obj=[t1, t2])), 'SELECT * FROM table1 AS t1alias, table2 ' 'WHERE t1alias.col1 = table2.col2') - self.assert_compile(vis.traverse(select(['*'], t1.c.col1 - == t2.c.col2, from_obj=[t1, - t2]).correlate(t1)), - 'SELECT * FROM table2 WHERE t1alias.col1 = ' - 'table2.col2') - self.assert_compile(vis.traverse(select(['*'], t1.c.col1 - == t2.c.col2, from_obj=[t1, - t2]).correlate(t2)), - 'SELECT * FROM table1 AS t1alias WHERE ' - 't1alias.col1 = table2.col2') + + def test_table_to_alias_6(self): + t1alias = t1.alias('t1alias') + vis = sql_util.ClauseAdapter(t1alias) + self.assert_compile( + select([t1alias, t2]).where(t1alias.c.col1 == + vis.traverse(select(['*'], + t1.c.col1 == t2.c.col2, + from_obj=[t1, t2]).correlate(t1))), + "SELECT t1alias.col1, t1alias.col2, t1alias.col3, " + "table2.col1, table2.col2, table2.col3 " + "FROM table1 AS t1alias, table2 WHERE t1alias.col1 = " + "(SELECT * FROM table2 WHERE t1alias.col1 = table2.col2)" + ) + + def test_table_to_alias_7(self): + t1alias = t1.alias('t1alias') + vis = sql_util.ClauseAdapter(t1alias) + self.assert_compile( + select([t1alias, t2]).where(t1alias.c.col1 == + vis.traverse(select(['*'], + t1.c.col1 == t2.c.col2, + from_obj=[t1, t2]).correlate(t2))), + "SELECT t1alias.col1, t1alias.col2, t1alias.col3, " + "table2.col1, table2.col2, table2.col3 " + "FROM table1 AS t1alias, table2 " + "WHERE t1alias.col1 = " + "(SELECT * FROM table1 AS t1alias " + "WHERE t1alias.col1 = table2.col2)") + + def test_table_to_alias_8(self): + t1alias = t1.alias('t1alias') + vis = sql_util.ClauseAdapter(t1alias) self.assert_compile(vis.traverse(case([(t1.c.col1 == 5, t1.c.col2)], else_=t1.c.col1)), 'CASE WHEN (t1alias.col1 = :col1_1) THEN ' 't1alias.col2 ELSE t1alias.col1 END') + + def test_table_to_alias_9(self): + t1alias = t1.alias('t1alias') + vis = sql_util.ClauseAdapter(t1alias) self.assert_compile(vis.traverse(case([(5, t1.c.col2)], value=t1.c.col1, else_=t1.c.col1)), 'CASE t1alias.col1 WHEN :param_1 THEN ' 't1alias.col2 ELSE t1alias.col1 END') + def test_table_to_alias_10(self): s = select(['*'], from_obj=[t1]).alias('foo') self.assert_compile(s.select(), 'SELECT foo.* FROM (SELECT * FROM table1) ' 'AS foo') + + def test_table_to_alias_11(self): + s = select(['*'], from_obj=[t1]).alias('foo') + t1alias = t1.alias('t1alias') + vis = sql_util.ClauseAdapter(t1alias) self.assert_compile(vis.traverse(s.select()), 'SELECT foo.* FROM (SELECT * FROM table1 ' 'AS t1alias) AS foo') + + def test_table_to_alias_12(self): + s = select(['*'], from_obj=[t1]).alias('foo') self.assert_compile(s.select(), 'SELECT foo.* FROM (SELECT * FROM table1) ' 'AS foo') + + def test_table_to_alias_13(self): + t1alias = t1.alias('t1alias') + vis = sql_util.ClauseAdapter(t1alias) ff = vis.traverse(func.count(t1.c.col1).label('foo')) self.assert_compile(select([ff]), 'SELECT count(t1alias.col1) AS foo FROM ' 'table1 AS t1alias') assert list(_from_objects(ff)) == [t1alias] + #def test_table_to_alias_2(self): # TODO: self.assert_compile(vis.traverse(select([func.count(t1.c # .col1).l abel('foo')]), clone=True), "SELECT # count(t1alias.col1) AS foo FROM table1 AS t1alias") + def test_table_to_alias_14(self): + t1alias = t1.alias('t1alias') + vis = sql_util.ClauseAdapter(t1alias) t2alias = t2.alias('t2alias') vis.chain(sql_util.ClauseAdapter(t2alias)) self.assert_compile(vis.traverse(select(['*'], t1.c.col1 @@ -831,28 +908,59 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): 'SELECT * FROM table1 AS t1alias, table2 ' 'AS t2alias WHERE t1alias.col1 = ' 't2alias.col2') + + def test_table_to_alias_15(self): + t1alias = t1.alias('t1alias') + vis = sql_util.ClauseAdapter(t1alias) + t2alias = t2.alias('t2alias') + vis.chain(sql_util.ClauseAdapter(t2alias)) self.assert_compile(vis.traverse(select(['*'], t1.c.col1 == t2.c.col2, from_obj=[t1, t2])), 'SELECT * FROM table1 AS t1alias, table2 ' 'AS t2alias WHERE t1alias.col1 = ' 't2alias.col2') - self.assert_compile(vis.traverse(select(['*'], t1.c.col1 - == t2.c.col2, from_obj=[t1, - t2]).correlate(t1)), - 'SELECT * FROM table2 AS t2alias WHERE ' - 't1alias.col1 = t2alias.col2') - self.assert_compile(vis.traverse(select(['*'], t1.c.col1 - == t2.c.col2, from_obj=[t1, - t2]).correlate(t2)), - 'SELECT * FROM table1 AS t1alias WHERE ' - 't1alias.col1 = t2alias.col2') + + def test_table_to_alias_16(self): + t1alias = t1.alias('t1alias') + vis = sql_util.ClauseAdapter(t1alias) + t2alias = t2.alias('t2alias') + vis.chain(sql_util.ClauseAdapter(t2alias)) + self.assert_compile( + select([t1alias, t2alias]).where( + t1alias.c.col1 == + vis.traverse(select(['*'], + t1.c.col1 == t2.c.col2, + from_obj=[t1, t2]).correlate(t1)) + ), + "SELECT t1alias.col1, t1alias.col2, t1alias.col3, " + "t2alias.col1, t2alias.col2, t2alias.col3 " + "FROM table1 AS t1alias, table2 AS t2alias " + "WHERE t1alias.col1 = " + "(SELECT * FROM table2 AS t2alias " + "WHERE t1alias.col1 = t2alias.col2)" + ) + + def test_table_to_alias_17(self): + t1alias = t1.alias('t1alias') + vis = sql_util.ClauseAdapter(t1alias) + t2alias = t2.alias('t2alias') + vis.chain(sql_util.ClauseAdapter(t2alias)) + self.assert_compile( + t2alias.select().where(t2alias.c.col2 == + vis.traverse(select(['*'], + t1.c.col1 == t2.c.col2, + from_obj=[t1, t2]).correlate(t2))), + 'SELECT t2alias.col1, t2alias.col2, t2alias.col3 ' + 'FROM table2 AS t2alias WHERE t2alias.col2 = ' + '(SELECT * FROM table1 AS t1alias WHERE ' + 't1alias.col1 = t2alias.col2)') def test_include_exclude(self): m = MetaData() - a=Table( 'a',m, - Column( 'id', Integer, primary_key=True), - Column( 'xxx_id', Integer, - ForeignKey( 'a.id', name='adf',use_alter=True ) + a = Table('a', m, + Column('id', Integer, primary_key=True), + Column('xxx_id', Integer, + ForeignKey('a.id', name='adf', use_alter=True) ) ) @@ -1167,93 +1275,6 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): 'SELECT table1.col1, table1.col2, ' 'table1.col3 FROM table1') - def test_correlation(self): - s = select([t2], t1.c.col1 == t2.c.col1) - self.assert_compile(s, - 'SELECT table2.col1, table2.col2, ' - 'table2.col3 FROM table2, table1 WHERE ' - 'table1.col1 = table2.col1') - s2 = select([t1], t1.c.col2 == s.c.col2) - # dont correlate in a FROM entry - self.assert_compile(s2, - 'SELECT table1.col1, table1.col2, ' - 'table1.col3 FROM table1, (SELECT ' - 'table2.col1 AS col1, table2.col2 AS col2, ' - 'table2.col3 AS col3 FROM table2, table1 WHERE ' - 'table1.col1 = table2.col1) WHERE ' - 'table1.col2 = col2') - s3 = s.correlate(None) - self.assert_compile(select([t1], t1.c.col2 == s3.c.col2), - 'SELECT table1.col1, table1.col2, ' - 'table1.col3 FROM table1, (SELECT ' - 'table2.col1 AS col1, table2.col2 AS col2, ' - 'table2.col3 AS col3 FROM table2, table1 ' - 'WHERE table1.col1 = table2.col1) WHERE ' - 'table1.col2 = col2') - # dont correlate in a FROM entry - self.assert_compile(select([t1], t1.c.col2 == s.c.col2), - 'SELECT table1.col1, table1.col2, ' - 'table1.col3 FROM table1, (SELECT ' - 'table2.col1 AS col1, table2.col2 AS col2, ' - 'table2.col3 AS col3 FROM table2, table1 WHERE ' - 'table1.col1 = table2.col1) WHERE ' - 'table1.col2 = col2') - - # but correlate in a WHERE entry - s_w = select([t2.c.col1]).where(t1.c.col1 == t2.c.col1) - self.assert_compile(select([t1], t1.c.col2 == s_w), - 'SELECT table1.col1, table1.col2, table1.col3 ' - 'FROM table1 WHERE table1.col2 = ' - '(SELECT table2.col1 FROM table2 ' - 'WHERE table1.col1 = table2.col1)' - ) - - - s4 = s3.correlate(t1) - self.assert_compile(select([t1], t1.c.col2 == s4.c.col2), - 'SELECT table1.col1, table1.col2, ' - 'table1.col3 FROM table1, (SELECT ' - 'table2.col1 AS col1, table2.col2 AS col2, ' - 'table2.col3 AS col3 FROM table2 WHERE ' - 'table1.col1 = table2.col1) WHERE ' - 'table1.col2 = col2') - - self.assert_compile(select([t1], t1.c.col2 == s3.c.col2), - 'SELECT table1.col1, table1.col2, ' - 'table1.col3 FROM table1, (SELECT ' - 'table2.col1 AS col1, table2.col2 AS col2, ' - 'table2.col3 AS col3 FROM table2, table1 ' - 'WHERE table1.col1 = table2.col1) WHERE ' - 'table1.col2 = col2') - - self.assert_compile(t1.select().where(t1.c.col1 - == 5).order_by(t1.c.col3), - 'SELECT table1.col1, table1.col2, ' - 'table1.col3 FROM table1 WHERE table1.col1 ' - '= :col1_1 ORDER BY table1.col3') - - # dont correlate in FROM - self.assert_compile(t1.select().select_from(select([t2], - t2.c.col1 - == t1.c.col1)).order_by(t1.c.col3), - 'SELECT table1.col1, table1.col2, ' - 'table1.col3 FROM table1, (SELECT ' - 'table2.col1 AS col1, table2.col2 AS col2, ' - 'table2.col3 AS col3 FROM table2, table1 WHERE ' - 'table2.col1 = table1.col1) ORDER BY ' - 'table1.col3') - - # still works if you actually add that table to correlate() - s = select([t2], t2.c.col1 == t1.c.col1) - s = s.correlate(t1).order_by(t2.c.col3) - - self.assert_compile(t1.select().select_from(s).order_by(t1.c.col3), - 'SELECT table1.col1, table1.col2, ' - 'table1.col3 FROM table1, (SELECT ' - 'table2.col1 AS col1, table2.col2 AS col2, ' - 'table2.col3 AS col3 FROM table2 WHERE ' - 'table2.col1 = table1.col1 ORDER BY ' - 'table2.col3) ORDER BY table1.col3') def test_prefixes(self): s = t1.select() diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py new file mode 100644 index 000000000..cd040538f --- /dev/null +++ b/test/sql/test_insert.py @@ -0,0 +1,312 @@ +#! coding:utf-8 + +from sqlalchemy import Column, Integer, MetaData, String, Table,\ + bindparam, exc, func, insert +from sqlalchemy.dialects import mysql, postgresql +from sqlalchemy.engine import default +from sqlalchemy.testing import AssertsCompiledSQL,\ + assert_raises_message, fixtures + + +class _InsertTestBase(object): + @classmethod + def define_tables(cls, metadata): + Table('mytable', metadata, + Column('myid', Integer), + Column('name', String(30)), + Column('description', String(30))) + Table('myothertable', metadata, + Column('otherid', Integer), + Column('othername', String(30))) + + +class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): + __dialect__ = 'default' + + def test_generic_insert_bind_params_all_columns(self): + table1 = self.tables.mytable + + self.assert_compile(insert(table1), + 'INSERT INTO mytable (myid, name, description) ' + 'VALUES (:myid, :name, :description)') + + def test_insert_with_values_dict(self): + table1 = self.tables.mytable + + checkparams = { + 'myid': 3, + 'name': 'jack' + } + + self.assert_compile(insert(table1, dict(myid=3, name='jack')), + 'INSERT INTO mytable (myid, name) VALUES (:myid, :name)', + checkparams=checkparams) + + def test_insert_with_values_tuple(self): + table1 = self.tables.mytable + + checkparams = { + 'myid': 3, + 'name': 'jack', + 'description': 'mydescription' + } + + self.assert_compile(insert(table1, (3, 'jack', 'mydescription')), + 'INSERT INTO mytable (myid, name, description) ' + 'VALUES (:myid, :name, :description)', + checkparams=checkparams) + + def test_insert_with_values_func(self): + table1 = self.tables.mytable + + self.assert_compile(insert(table1, values=dict(myid=func.lala())), + 'INSERT INTO mytable (myid) VALUES (lala())') + + def test_insert_with_user_supplied_bind_params(self): + table1 = self.tables.mytable + + values = { + table1.c.myid: bindparam('userid'), + table1.c.name: bindparam('username') + } + + self.assert_compile(insert(table1, values), + 'INSERT INTO mytable (myid, name) VALUES (:userid, :username)') + + def test_insert_values(self): + table1 = self.tables.mytable + + values1 = {table1.c.myid: bindparam('userid')} + values2 = {table1.c.name: bindparam('username')} + + self.assert_compile(insert(table1, values=values1).values(values2), + 'INSERT INTO mytable (myid, name) VALUES (:userid, :username)') + + def test_prefix_with(self): + table1 = self.tables.mytable + + stmt = table1.insert().\ + prefix_with('A', 'B', dialect='mysql').\ + prefix_with('C', 'D') + + self.assert_compile(stmt, + 'INSERT C D INTO mytable (myid, name, description) ' + 'VALUES (:myid, :name, :description)') + + self.assert_compile(stmt, + 'INSERT A B C D INTO mytable (myid, name, description) ' + 'VALUES (%s, %s, %s)', dialect=mysql.dialect()) + + def test_inline_default(self): + metadata = MetaData() + table = Table('sometable', metadata, + Column('id', Integer, primary_key=True), + Column('foo', Integer, default=func.foobar())) + + self.assert_compile(table.insert(values={}, inline=True), + 'INSERT INTO sometable (foo) VALUES (foobar())') + + self.assert_compile(table.insert(inline=True), + 'INSERT INTO sometable (foo) VALUES (foobar())', params={}) + + def test_insert_returning_not_in_default(self): + table1 = self.tables.mytable + + stmt = table1.insert().returning(table1.c.myid) + assert_raises_message( + exc.CompileError, + "RETURNING is not supported by this dialect's statement compiler.", + stmt.compile, + dialect=default.DefaultDialect() + ) + +class EmptyTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): + __dialect__ = 'default' + + def test_empty_insert_default(self): + table1 = self.tables.mytable + + stmt = table1.insert().values({}) # hide from 2to3 + self.assert_compile(stmt, 'INSERT INTO mytable () VALUES ()') + + def test_supports_empty_insert_true(self): + table1 = self.tables.mytable + + dialect = default.DefaultDialect() + dialect.supports_empty_insert = dialect.supports_default_values = True + + stmt = table1.insert().values({}) # hide from 2to3 + self.assert_compile(stmt, + 'INSERT INTO mytable DEFAULT VALUES', + dialect=dialect) + + def test_supports_empty_insert_false(self): + table1 = self.tables.mytable + + dialect = default.DefaultDialect() + dialect.supports_empty_insert = dialect.supports_default_values = False + + stmt = table1.insert().values({}) # hide from 2to3 + assert_raises_message(exc.CompileError, + "The 'default' dialect with current database version " + "settings does not support empty inserts.", + stmt.compile, dialect=dialect) + + +class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): + __dialect__ = 'default' + + def test_not_supported(self): + table1 = self.tables.mytable + + dialect = default.DefaultDialect() + stmt = table1.insert().values([{'myid': 1}, {'myid': 2}]) + assert_raises_message( + exc.CompileError, + "The 'default' dialect with current database version settings " + "does not support in-place multirow inserts.", + stmt.compile, dialect=dialect) + + def test_named(self): + table1 = self.tables.mytable + + values = [ + {'myid': 1, 'name': 'a', 'description': 'b'}, + {'myid': 2, 'name': 'c', 'description': 'd'}, + {'myid': 3, 'name': 'e', 'description': 'f'} + ] + + checkparams = { + 'myid_0': 1, + 'myid_1': 2, + 'myid_2': 3, + 'name_0': 'a', + 'name_1': 'c', + 'name_2': 'e', + 'description_0': 'b', + 'description_1': 'd', + 'description_2': 'f', + } + + dialect = default.DefaultDialect() + dialect.supports_multivalues_insert = True + + self.assert_compile(table1.insert().values(values), + 'INSERT INTO mytable (myid, name, description) VALUES ' + '(:myid_0, :name_0, :description_0), ' + '(:myid_1, :name_1, :description_1), ' + '(:myid_2, :name_2, :description_2)', + checkparams=checkparams, dialect=dialect) + + def test_positional(self): + table1 = self.tables.mytable + + values = [ + {'myid': 1, 'name': 'a', 'description': 'b'}, + {'myid': 2, 'name': 'c', 'description': 'd'}, + {'myid': 3, 'name': 'e', 'description': 'f'} + ] + + checkpositional = (1, 'a', 'b', 2, 'c', 'd', 3, 'e', 'f') + + dialect = default.DefaultDialect() + dialect.supports_multivalues_insert = True + dialect.paramstyle = 'format' + dialect.positional = True + + self.assert_compile(table1.insert().values(values), + 'INSERT INTO mytable (myid, name, description) VALUES ' + '(%s, %s, %s), (%s, %s, %s), (%s, %s, %s)', + checkpositional=checkpositional, dialect=dialect) + + def test_inline_default(self): + metadata = MetaData() + table = Table('sometable', metadata, + Column('id', Integer, primary_key=True), + Column('data', String), + Column('foo', Integer, default=func.foobar())) + + values = [ + {'id': 1, 'data': 'data1'}, + {'id': 2, 'data': 'data2', 'foo': 'plainfoo'}, + {'id': 3, 'data': 'data3'}, + ] + + checkparams = { + 'id_0': 1, + 'id_1': 2, + 'id_2': 3, + 'data_0': 'data1', + 'data_1': 'data2', + 'data_2': 'data3', + 'foo_1': 'plainfoo', + } + + self.assert_compile(table.insert().values(values), + 'INSERT INTO sometable (id, data, foo) VALUES ' + '(%(id_0)s, %(data_0)s, foobar()), ' + '(%(id_1)s, %(data_1)s, %(foo_1)s), ' + '(%(id_2)s, %(data_2)s, foobar())', + checkparams=checkparams, dialect=postgresql.dialect()) + + def test_server_default(self): + metadata = MetaData() + table = Table('sometable', metadata, + Column('id', Integer, primary_key=True), + Column('data', String), + Column('foo', Integer, server_default=func.foobar())) + + values = [ + {'id': 1, 'data': 'data1'}, + {'id': 2, 'data': 'data2', 'foo': 'plainfoo'}, + {'id': 3, 'data': 'data3'}, + ] + + checkparams = { + 'id_0': 1, + 'id_1': 2, + 'id_2': 3, + 'data_0': 'data1', + 'data_1': 'data2', + 'data_2': 'data3', + } + + self.assert_compile(table.insert().values(values), + 'INSERT INTO sometable (id, data) VALUES ' + '(%(id_0)s, %(data_0)s), ' + '(%(id_1)s, %(data_1)s), ' + '(%(id_2)s, %(data_2)s)', + checkparams=checkparams, dialect=postgresql.dialect()) + + def test_server_default_absent_value(self): + metadata = MetaData() + table = Table('sometable', metadata, + Column('id', Integer, primary_key=True), + Column('data', String), + Column('foo', Integer, server_default=func.foobar())) + + values = [ + {'id': 1, 'data': 'data1', 'foo': 'plainfoo'}, + {'id': 2, 'data': 'data2'}, + {'id': 3, 'data': 'data3', 'foo': 'otherfoo'}, + ] + + checkparams = { + 'id_0': 1, + 'id_1': 2, + 'id_2': 3, + 'data_0': 'data1', + 'data_1': 'data2', + 'data_2': 'data3', + 'foo_0': 'plainfoo', + 'foo_2': 'otherfoo', + } + + # note the effect here is that the first set of params + # takes effect for the rest of them, when one is absent + self.assert_compile(table.insert().values(values), + 'INSERT INTO sometable (id, data, foo) VALUES ' + '(%(id_0)s, %(data_0)s, %(foo_0)s), ' + '(%(id_1)s, %(data_1)s, %(foo_0)s), ' + '(%(id_2)s, %(data_2)s, %(foo_2)s)', + checkparams=checkparams, dialect=postgresql.dialect()) diff --git a/test/sql/test_labels.py b/test/sql/test_labels.py index d7cb8db4a..fd45d303f 100644 --- a/test/sql/test_labels.py +++ b/test/sql/test_labels.py @@ -1,19 +1,15 @@ - -from sqlalchemy import exc as exceptions -from sqlalchemy import testing -from sqlalchemy.testing import engines -from sqlalchemy import select, MetaData, Integer, or_ +from sqlalchemy import exc as exceptions, select, MetaData, Integer, or_ from sqlalchemy.engine import default from sqlalchemy.sql import table, column -from sqlalchemy.testing import assert_raises, eq_ -from sqlalchemy.testing import fixtures, AssertsCompiledSQL -from sqlalchemy.testing.engines import testing_engine +from sqlalchemy.testing import AssertsCompiledSQL, assert_raises, engines,\ + fixtures from sqlalchemy.testing.schema import Table, Column IDENT_LENGTH = 29 class MaxIdentTest(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = 'DefaultDialect' table1 = table('some_large_named_table', column('this_is_the_primarykey_column'), @@ -25,9 +21,6 @@ class MaxIdentTest(fixtures.TestBase, AssertsCompiledSQL): column('this_is_the_data_column') ) - __dialect__ = 'DefaultDialect' - - def _length_fixture(self, length=IDENT_LENGTH, positional=False): dialect = default.DefaultDialect() dialect.max_identifier_length = length @@ -60,7 +53,7 @@ class MaxIdentTest(fixtures.TestBase, AssertsCompiledSQL): ta = table2.alias() on = table1.c.this_is_the_data_column == ta.c.this_is_the_data_column self.assert_compile( - select([table1, ta]).select_from(table1.join(ta, on)).\ + select([table1, ta]).select_from(table1.join(ta, on)). where(ta.c.this_is_the_data_column == 'data3'), 'SELECT ' 'some_large_named_table.this_is_the_primarykey_column, ' @@ -87,16 +80,9 @@ class MaxIdentTest(fixtures.TestBase, AssertsCompiledSQL): t = Table('this_name_is_too_long_for_what_were_doing_in_this_test', m, Column('foo', Integer)) eng = self._engine_fixture() - for meth in ( - t.create, - t.drop, - m.create_all, - m.drop_all - ): - assert_raises( - exceptions.IdentifierError, - meth, eng - ) + methods = (t.create, t.drop, m.create_all, m.drop_all) + for meth in methods: + assert_raises(exceptions.IdentifierError, meth, eng) def _assert_labeled_table1_select(self, s): table1 = self.table1 @@ -263,7 +249,9 @@ class MaxIdentTest(fixtures.TestBase, AssertsCompiledSQL): dialect=self._length_fixture(positional=True) ) + class LabelLengthTest(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = 'DefaultDialect' table1 = table('some_large_named_table', column('this_is_the_primarykey_column'), @@ -275,8 +263,6 @@ class LabelLengthTest(fixtures.TestBase, AssertsCompiledSQL): column('this_is_the_data_column') ) - __dialect__ = 'DefaultDialect' - def test_adjustable_1(self): table1 = self.table1 q = table1.select( @@ -404,27 +390,27 @@ class LabelLengthTest(fixtures.TestBase, AssertsCompiledSQL): 'AS _1', dialect=compile_dialect) - def test_adjustable_result_schema_column_1(self): table1 = self.table1 + q = table1.select( table1.c.this_is_the_primarykey_column == 4).apply_labels().\ alias('foo') - dialect = default.DefaultDialect(label_length=10) + dialect = default.DefaultDialect(label_length=10) compiled = q.compile(dialect=dialect) + assert set(compiled.result_map['some_2'][1]).issuperset([ - table1.c.this_is_the_data_column, - 'some_large_named_table_this_is_the_data_column', - 'some_2' + table1.c.this_is_the_data_column, + 'some_large_named_table_this_is_the_data_column', + 'some_2' + ]) - ]) assert set(compiled.result_map['some_1'][1]).issuperset([ - table1.c.this_is_the_primarykey_column, - 'some_large_named_table_this_is_the_primarykey_column', - 'some_1' - - ]) + table1.c.this_is_the_primarykey_column, + 'some_large_named_table_this_is_the_primarykey_column', + 'some_1' + ]) def test_adjustable_result_schema_column_2(self): table1 = self.table1 @@ -434,20 +420,17 @@ class LabelLengthTest(fixtures.TestBase, AssertsCompiledSQL): x = select([q]) dialect = default.DefaultDialect(label_length=10) - compiled = x.compile(dialect=dialect) + assert set(compiled.result_map['this_2'][1]).issuperset([ - q.corresponding_column(table1.c.this_is_the_data_column), - 'this_is_the_data_column', - 'this_2' + q.corresponding_column(table1.c.this_is_the_data_column), + 'this_is_the_data_column', + 'this_2']) - ]) assert set(compiled.result_map['this_1'][1]).issuperset([ - q.corresponding_column(table1.c.this_is_the_primarykey_column), - 'this_is_the_primarykey_column', - 'this_1' - - ]) + q.corresponding_column(table1.c.this_is_the_primarykey_column), + 'this_is_the_primarykey_column', + 'this_1']) def test_table_plus_column_exceeds_length(self): """test that the truncation only occurs when tablename + colname are @@ -490,7 +473,6 @@ class LabelLengthTest(fixtures.TestBase, AssertsCompiledSQL): 'other_thirty_characters_table_.thirty_characters_table_id', dialect=compile_dialect) - def test_colnames_longer_than_labels_lowercase(self): t1 = table('a', column('abcde')) self._test_colnames_longer_than_labels(t1) @@ -507,30 +489,18 @@ class LabelLengthTest(fixtures.TestBase, AssertsCompiledSQL): # 'abcde' is longer than 4, but rendered as itself # needs to have all characters s = select([a1]) - self.assert_compile( - select([a1]), - "SELECT asdf.abcde FROM a AS asdf", - dialect=dialect - ) + self.assert_compile(select([a1]), + 'SELECT asdf.abcde FROM a AS asdf', + dialect=dialect) compiled = s.compile(dialect=dialect) assert set(compiled.result_map['abcde'][1]).issuperset([ - 'abcde', - a1.c.abcde, - 'abcde' - ]) + 'abcde', a1.c.abcde, 'abcde']) # column still there, but short label s = select([a1]).apply_labels() - self.assert_compile( - s, - "SELECT asdf.abcde AS _1 FROM a AS asdf", - dialect=dialect - ) + self.assert_compile(s, + 'SELECT asdf.abcde AS _1 FROM a AS asdf', + dialect=dialect) compiled = s.compile(dialect=dialect) assert set(compiled.result_map['_1'][1]).issuperset([ - 'asdf_abcde', - a1.c.abcde, - '_1' - ]) - - + 'asdf_abcde', a1.c.abcde, '_1']) diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 1b8068f22..db2eaa4fa 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -54,7 +54,7 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): Column(Integer(), ForeignKey('bat.blah'), doc="this is a col"), Column('bar', Integer(), ForeignKey('bat.blah'), primary_key=True, key='bar'), - Column('bar', Integer(), info={'foo':'bar'}), + Column('bar', Integer(), info={'foo': 'bar'}), ]: c2 = col.copy() for attr in ('name', 'type', 'nullable', @@ -148,21 +148,21 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): def test_dupe_tables(self): metadata = self.metadata Table('table1', metadata, - Column('col1', Integer, primary_key=True), - Column('col2', String(20))) + Column('col1', Integer, primary_key=True), + Column('col2', String(20))) metadata.create_all() Table('table1', metadata, autoload=True) def go(): Table('table1', metadata, - Column('col1', Integer, primary_key=True), - Column('col2', String(20))) + Column('col1', Integer, primary_key=True), + Column('col2', String(20))) assert_raises_message( tsa.exc.InvalidRequestError, - "Table 'table1' is already defined for this "\ - "MetaData instance. Specify 'extend_existing=True' "\ - "to redefine options and columns on an existing "\ - "Table object.", + "Table 'table1' is already defined for this " + "MetaData instance. Specify 'extend_existing=True' " + "to redefine options and columns on an existing " + "Table object.", go ) @@ -544,23 +544,23 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): for i, (name, metadata, schema, quote_schema, exp_schema, exp_quote_schema) in enumerate([ - ('t1', m1, None, None, 'sch1', None), - ('t2', m1, 'sch2', None, 'sch2', None), - ('t3', m1, 'sch2', True, 'sch2', True), - ('t4', m1, 'sch1', None, 'sch1', None), - ('t1', m2, None, None, 'sch1', True), - ('t2', m2, 'sch2', None, 'sch2', None), - ('t3', m2, 'sch2', True, 'sch2', True), - ('t4', m2, 'sch1', None, 'sch1', None), - ('t1', m3, None, None, 'sch1', False), - ('t2', m3, 'sch2', None, 'sch2', None), - ('t3', m3, 'sch2', True, 'sch2', True), - ('t4', m3, 'sch1', None, 'sch1', None), - ('t1', m4, None, None, None, None), - ('t2', m4, 'sch2', None, 'sch2', None), - ('t3', m4, 'sch2', True, 'sch2', True), - ('t4', m4, 'sch1', None, 'sch1', None), - ]): + ('t1', m1, None, None, 'sch1', None), + ('t2', m1, 'sch2', None, 'sch2', None), + ('t3', m1, 'sch2', True, 'sch2', True), + ('t4', m1, 'sch1', None, 'sch1', None), + ('t1', m2, None, None, 'sch1', True), + ('t2', m2, 'sch2', None, 'sch2', None), + ('t3', m2, 'sch2', True, 'sch2', True), + ('t4', m2, 'sch1', None, 'sch1', None), + ('t1', m3, None, None, 'sch1', False), + ('t2', m3, 'sch2', None, 'sch2', None), + ('t3', m3, 'sch2', True, 'sch2', True), + ('t4', m3, 'sch1', None, 'sch1', None), + ('t1', m4, None, None, None, None), + ('t2', m4, 'sch2', None, 'sch2', None), + ('t3', m4, 'sch2', True, 'sch2', True), + ('t4', m4, 'sch1', None, 'sch1', None), + ]): kw = {} if schema is not None: kw['schema'] = schema @@ -568,10 +568,12 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): kw['quote_schema'] = quote_schema t = Table(name, metadata, **kw) eq_(t.schema, exp_schema, "test %d, table schema" % i) - eq_(t.quote_schema, exp_quote_schema, "test %d, table quote_schema" % i) + eq_(t.quote_schema, exp_quote_schema, + "test %d, table quote_schema" % i) seq = Sequence(name, metadata=metadata, **kw) eq_(seq.schema, exp_schema, "test %d, seq schema" % i) - eq_(seq.quote_schema, exp_quote_schema, "test %d, seq quote_schema" % i) + eq_(seq.quote_schema, exp_quote_schema, + "test %d, seq quote_schema" % i) def test_manual_dependencies(self): meta = MetaData() @@ -696,8 +698,8 @@ class TableTest(fixtures.TestBase, AssertsCompiledSQL): Column("col1", Integer), prefixes=["VIRTUAL"]) self.assert_compile( - schema.CreateTable(table2), - "CREATE VIRTUAL TABLE temporary_table_2 (col1 INTEGER)" + schema.CreateTable(table2), + "CREATE VIRTUAL TABLE temporary_table_2 (col1 INTEGER)" ) def test_table_info(self): @@ -940,7 +942,7 @@ class UseExistingTest(fixtures.TablesTest): Unicode), autoload=True) assert_raises_message( exc.InvalidRequestError, - "Table 'users' is already defined for this "\ + "Table 'users' is already defined for this "\ "MetaData instance.", go ) @@ -1551,7 +1553,8 @@ class CatchAllEventsTest(fixtures.TestBase): def test_all_events(self): canary = [] def before_attach(obj, parent): - canary.append("%s->%s" % (obj.__class__.__name__, parent.__class__.__name__)) + canary.append("%s->%s" % (obj.__class__.__name__, + parent.__class__.__name__)) def after_attach(obj, parent): canary.append("%s->%s" % (obj.__class__.__name__, parent)) @@ -1586,7 +1589,8 @@ class CatchAllEventsTest(fixtures.TestBase): def evt(target): def before_attach(obj, parent): - canary.append("%s->%s" % (target.__name__, parent.__class__.__name__)) + canary.append("%s->%s" % (target.__name__, + parent.__class__.__name__)) def after_attach(obj, parent): canary.append("%s->%s" % (target.__name__, parent)) @@ -1594,7 +1598,8 @@ class CatchAllEventsTest(fixtures.TestBase): event.listen(target, "after_parent_attach", after_attach) for target in [ - schema.ForeignKeyConstraint, schema.PrimaryKeyConstraint, schema.UniqueConstraint, + schema.ForeignKeyConstraint, schema.PrimaryKeyConstraint, + schema.UniqueConstraint, schema.CheckConstraint ]: evt(target) @@ -1615,11 +1620,12 @@ class CatchAllEventsTest(fixtures.TestBase): eq_( canary, [ - 'PrimaryKeyConstraint->Table', 'PrimaryKeyConstraint->t1', - 'ForeignKeyConstraint->Table', 'ForeignKeyConstraint->t1', - 'UniqueConstraint->Table', 'UniqueConstraint->t1', - 'PrimaryKeyConstraint->Table', 'PrimaryKeyConstraint->t2', - 'CheckConstraint->Table', 'CheckConstraint->t2', - 'UniqueConstraint->Table', 'UniqueConstraint->t2' + 'PrimaryKeyConstraint->Table', 'PrimaryKeyConstraint->t1', + 'ForeignKeyConstraint->Table', 'ForeignKeyConstraint->t1', + 'UniqueConstraint->Table', 'UniqueConstraint->t1', + 'PrimaryKeyConstraint->Table', 'PrimaryKeyConstraint->t2', + 'CheckConstraint->Table', 'CheckConstraint->t2', + 'UniqueConstraint->Table', 'UniqueConstraint->t2' ] ) + diff --git a/test/sql/test_query.py b/test/sql/test_query.py index b5f50aeea..293e629c8 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -190,10 +190,27 @@ class QueryTest(fixtures.TestBase): try: table.create(bind=engine, checkfirst=True) i = insert_values(engine, table, values) - assert i == assertvalues, "tablename: %s %r %r" % (table.name, repr(i), repr(assertvalues)) + assert i == assertvalues, "tablename: %s %r %r" % \ + (table.name, repr(i), repr(assertvalues)) finally: table.drop(bind=engine) + @testing.only_on('sqlite+pysqlite') + @testing.provide_metadata + def test_lastrowid_zero(self): + from sqlalchemy.dialects import sqlite + eng = engines.testing_engine() + class ExcCtx(sqlite.base.SQLiteExecutionContext): + def get_lastrowid(self): + return 0 + eng.dialect.execution_ctx_cls = ExcCtx + t = Table('t', MetaData(), Column('x', Integer, primary_key=True), + Column('y', Integer)) + t.create(eng) + r = eng.execute(t.insert().values(y=5)) + eq_(r.inserted_primary_key, [0]) + + @testing.fails_on('sqlite', "sqlite autoincremnt doesn't work with composite pks") def test_misordered_lastrow(self): related = Table('related', metadata, @@ -1011,6 +1028,22 @@ class QueryTest(fixtures.TestBase): lambda: row[u2.c.user_id] ) + def test_ambiguous_column_contains(self): + # ticket 2702. in 0.7 we'd get True, False. + # in 0.8, both columns are present so it's True; + # but when they're fetched you'll get the ambiguous error. + users.insert().execute(user_id=1, user_name='john') + result = select([ + users.c.user_id, + addresses.c.user_id]).\ + select_from(users.outerjoin(addresses)).execute() + row = result.first() + + eq_( + set([users.c.user_id in row, addresses.c.user_id in row]), + set([True]) + ) + def test_ambiguous_column_by_col_plus_label(self): users.insert().execute(user_id=1, user_name='john') result = select([users.c.user_id, diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index a182444e9..6a42b0625 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -88,26 +88,6 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): eq_(result.fetchall(), [(1,)]) - @testing.fails_on('postgresql', 'undefined behavior') - @testing.fails_on('oracle+cx_oracle', 'undefined behavior') - @testing.crashes('mssql+mxodbc', 'Raises an error') - def test_insert_returning_execmany(self): - - # return value is documented as failing with psycopg2/executemany - result2 = table.insert().returning(table).execute( - [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}]) - - if testing.against('mssql+zxjdbc'): - # jtds apparently returns only the first row - eq_(result2.fetchall(), [(2, 2, False, None)]) - elif testing.against('firebird', 'mssql', 'oracle'): - # Multiple inserts only return the last row - eq_(result2.fetchall(), [(3, 3, True, None)]) - else: - # nobody does this as far as we know (pg8000?) - eq_(result2.fetchall(), [(2, 2, False, None), (3, 3, True, None)]) - - @testing.requires.multivalues_inserts def test_multirow_returning(self): ins = table.insert().returning(table.c.id, table.c.persons).values( diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index 30052a806..e881298a7 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -1587,3 +1587,156 @@ class AnnotationsTest(fixtures.TestBase): comp2 = c2.comparator assert (c2 == 5).left._annotations == {"foo": "bar", "bat": "hoho"} + +class WithLabelsTest(fixtures.TestBase): + def _assert_labels_warning(self, s): + assert_raises_message( + exc.SAWarning, + "replaced by another column with the same key", + lambda: s.c + ) + + def _assert_result_keys(self, s, keys): + compiled = s.compile() + eq_(set(compiled.result_map), set(keys)) + + def _assert_subq_result_keys(self, s, keys): + compiled = s.select().compile() + eq_(set(compiled.result_map), set(keys)) + + def _names_overlap(self): + m = MetaData() + t1 = Table('t1', m, Column('x', Integer)) + t2 = Table('t2', m, Column('x', Integer)) + return select([t1, t2]) + + def test_names_overlap_nolabel(self): + sel = self._names_overlap() + self._assert_labels_warning(sel) + self._assert_result_keys(sel, ['x']) + + def test_names_overlap_label(self): + sel = self._names_overlap().apply_labels() + eq_( + sel.c.keys(), + ['t1_x', 't2_x'] + ) + self._assert_result_keys(sel, ['t1_x', 't2_x']) + + def _names_overlap_keys_dont(self): + m = MetaData() + t1 = Table('t1', m, Column('x', Integer, key='a')) + t2 = Table('t2', m, Column('x', Integer, key='b')) + return select([t1, t2]) + + def test_names_overlap_keys_dont_nolabel(self): + sel = self._names_overlap_keys_dont() + eq_( + sel.c.keys(), + ['a', 'b'] + ) + self._assert_result_keys(sel, ['x']) + + def test_names_overlap_keys_dont_label(self): + sel = self._names_overlap_keys_dont().apply_labels() + eq_( + sel.c.keys(), + ['t1_a', 't2_b'] + ) + self._assert_result_keys(sel, ['t1_x', 't2_x']) + + def _labels_overlap(self): + m = MetaData() + t1 = Table('t', m, Column('x_id', Integer)) + t2 = Table('t_x', m, Column('id', Integer)) + return select([t1, t2]) + + def test_labels_overlap_nolabel(self): + sel = self._labels_overlap() + eq_( + sel.c.keys(), + ['x_id', 'id'] + ) + self._assert_result_keys(sel, ['x_id', 'id']) + + def test_labels_overlap_label(self): + sel = self._labels_overlap().apply_labels() + t2 = sel.froms[1] + eq_( + sel.c.keys(), + ['t_x_id', t2.c.id.anon_label] + ) + self._assert_result_keys(sel, ['t_x_id', 'id_1']) + self._assert_subq_result_keys(sel, ['t_x_id', 'id_1']) + + def _labels_overlap_keylabels_dont(self): + m = MetaData() + t1 = Table('t', m, Column('x_id', Integer, key='a')) + t2 = Table('t_x', m, Column('id', Integer, key='b')) + return select([t1, t2]) + + def test_labels_overlap_keylabels_dont_nolabel(self): + sel = self._labels_overlap_keylabels_dont() + eq_(sel.c.keys(), ['a', 'b']) + self._assert_result_keys(sel, ['x_id', 'id']) + + def test_labels_overlap_keylabels_dont_label(self): + sel = self._labels_overlap_keylabels_dont().apply_labels() + eq_(sel.c.keys(), ['t_a', 't_x_b']) + self._assert_result_keys(sel, ['t_x_id', 'id_1']) + + def _keylabels_overlap_labels_dont(self): + m = MetaData() + t1 = Table('t', m, Column('a', Integer, key='x_id')) + t2 = Table('t_x', m, Column('b', Integer, key='id')) + return select([t1, t2]) + + def test_keylabels_overlap_labels_dont_nolabel(self): + sel = self._keylabels_overlap_labels_dont() + eq_(sel.c.keys(), ['x_id', 'id']) + self._assert_result_keys(sel, ['a', 'b']) + + def test_keylabels_overlap_labels_dont_label(self): + sel = self._keylabels_overlap_labels_dont().apply_labels() + t2 = sel.froms[1] + eq_(sel.c.keys(), ['t_x_id', t2.c.id.anon_label]) + self._assert_result_keys(sel, ['t_a', 't_x_b']) + self._assert_subq_result_keys(sel, ['t_a', 't_x_b']) + + def _keylabels_overlap_labels_overlap(self): + m = MetaData() + t1 = Table('t', m, Column('x_id', Integer, key='x_a')) + t2 = Table('t_x', m, Column('id', Integer, key='a')) + return select([t1, t2]) + + def test_keylabels_overlap_labels_overlap_nolabel(self): + sel = self._keylabels_overlap_labels_overlap() + eq_(sel.c.keys(), ['x_a', 'a']) + self._assert_result_keys(sel, ['x_id', 'id']) + self._assert_subq_result_keys(sel, ['x_id', 'id']) + + def test_keylabels_overlap_labels_overlap_label(self): + sel = self._keylabels_overlap_labels_overlap().apply_labels() + t2 = sel.froms[1] + eq_(sel.c.keys(), ['t_x_a', t2.c.a.anon_label]) + self._assert_result_keys(sel, ['t_x_id', 'id_1']) + self._assert_subq_result_keys(sel, ['t_x_id', 'id_1']) + + def _keys_overlap_names_dont(self): + m = MetaData() + t1 = Table('t1', m, Column('a', Integer, key='x')) + t2 = Table('t2', m, Column('b', Integer, key='x')) + return select([t1, t2]) + + def test_keys_overlap_names_dont_nolabel(self): + sel = self._keys_overlap_names_dont() + self._assert_labels_warning(sel) + self._assert_result_keys(sel, ['a', 'b']) + + def test_keys_overlap_names_dont_label(self): + sel = self._keys_overlap_names_dont().apply_labels() + eq_( + sel.c.keys(), + ['t1_x', 't2_x'] + ) + self._assert_result_keys(sel, ['t1_a', 't2_b']) diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 3c981e539..fac22a205 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -15,7 +15,6 @@ from sqlalchemy import testing from sqlalchemy.testing import AssertsCompiledSQL, AssertsExecutionResults, \ engines, pickleable from sqlalchemy.testing.util import picklers -from sqlalchemy.util.compat import decimal from sqlalchemy.testing.util import round_decimal from sqlalchemy.testing import fixtures diff --git a/test/sql/test_update.py b/test/sql/test_update.py index b46489cd2..a8df86cd2 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -1,55 +1,53 @@ -from sqlalchemy.testing import eq_, assert_raises_message, assert_raises, AssertsCompiledSQL -import datetime from sqlalchemy import * -from sqlalchemy import exc, sql, util -from sqlalchemy.engine import default, base from sqlalchemy import testing -from sqlalchemy.testing import fixtures -from sqlalchemy.testing.schema import Table, Column from sqlalchemy.dialects import mysql +from sqlalchemy.testing import AssertsCompiledSQL, eq_, fixtures +from sqlalchemy.testing.schema import Table, Column + class _UpdateFromTestBase(object): @classmethod def define_tables(cls, metadata): + Table('mytable', metadata, + Column('myid', Integer), + Column('name', String(30)), + Column('description', String(50))) + Table('myothertable', metadata, + Column('otherid', Integer), + Column('othername', String(30))) Table('users', metadata, Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30), nullable=False), - ) - + test_needs_autoincrement=True), + Column('name', String(30), nullable=False)) Table('addresses', metadata, Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), + test_needs_autoincrement=True), Column('user_id', None, ForeignKey('users.id')), Column('name', String(30), nullable=False), - Column('email_address', String(50), nullable=False), - ) - - Table("dingalings", metadata, + Column('email_address', String(50), nullable=False)) + Table('dingalings', metadata, Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), + test_needs_autoincrement=True), Column('address_id', None, ForeignKey('addresses.id')), - Column('data', String(30)), - ) + Column('data', String(30))) @classmethod def fixtures(cls): return dict( - users = ( + users=( ('id', 'name'), (7, 'jack'), (8, 'ed'), (9, 'fred'), (10, 'chuck') ), - addresses = ( ('id', 'user_id', 'name', 'email_address'), - (1, 7, 'x', "jack@bean.com"), - (2, 8, 'x', "ed@wood.com"), - (3, 8, 'x', "ed@bettyboop.com"), - (4, 8, 'x', "ed@lala.com"), - (5, 9, 'x', "fred@fred.com") + (1, 7, 'x', 'jack@bean.com'), + (2, 8, 'x', 'ed@wood.com'), + (3, 8, 'x', 'ed@bettyboop.com'), + (4, 8, 'x', 'ed@lala.com'), + (5, 9, 'x', 'fred@fred.com') ), dingalings = ( ('id', 'address_id', 'data'), @@ -59,288 +57,462 @@ class _UpdateFromTestBase(object): ) -class UpdateFromCompileTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): +class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): + __dialect__ = 'default' + + def test_update_1(self): + table1 = self.tables.mytable + + self.assert_compile( + update(table1, table1.c.myid == 7), + 'UPDATE mytable SET name=:name WHERE mytable.myid = :myid_1', + params={table1.c.name: 'fred'}) + + def test_update_2(self): + table1 = self.tables.mytable + + self.assert_compile( + table1.update(). + where(table1.c.myid == 7). + values({table1.c.myid: 5}), + 'UPDATE mytable SET myid=:myid WHERE mytable.myid = :myid_1', + checkparams={'myid': 5, 'myid_1': 7}) + + def test_update_3(self): + table1 = self.tables.mytable + + self.assert_compile( + update(table1, table1.c.myid == 7), + 'UPDATE mytable SET name=:name WHERE mytable.myid = :myid_1', + params={'name': 'fred'}) + + def test_update_4(self): + table1 = self.tables.mytable + + self.assert_compile( + update(table1, values={table1.c.name: table1.c.myid}), + 'UPDATE mytable SET name=mytable.myid') + + def test_update_5(self): + table1 = self.tables.mytable + + self.assert_compile( + update(table1, + whereclause=table1.c.name == bindparam('crit'), + values={table1.c.name: 'hi'}), + 'UPDATE mytable SET name=:name WHERE mytable.name = :crit', + params={'crit': 'notthere'}, + checkparams={'crit': 'notthere', 'name': 'hi'}) + + def test_update_6(self): + table1 = self.tables.mytable + + self.assert_compile( + update(table1, + table1.c.myid == 12, + values={table1.c.name: table1.c.myid}), + 'UPDATE mytable ' + 'SET name=mytable.myid, description=:description ' + 'WHERE mytable.myid = :myid_1', + params={'description': 'test'}, + checkparams={'description': 'test', 'myid_1': 12}) + + def test_update_7(self): + table1 = self.tables.mytable + + self.assert_compile( + update(table1, table1.c.myid == 12, values={table1.c.myid: 9}), + 'UPDATE mytable ' + 'SET myid=:myid, description=:description ' + 'WHERE mytable.myid = :myid_1', + params={'myid_1': 12, 'myid': 9, 'description': 'test'}) + + def test_update_8(self): + table1 = self.tables.mytable + + self.assert_compile( + update(table1, table1.c.myid == 12), + 'UPDATE mytable SET myid=:myid WHERE mytable.myid = :myid_1', + params={'myid': 18}, checkparams={'myid': 18, 'myid_1': 12}) + + def test_update_9(self): + table1 = self.tables.mytable + + s = table1.update(table1.c.myid == 12, values={table1.c.name: 'lala'}) + c = s.compile(column_keys=['id', 'name']) + eq_(str(s), str(c)) + + def test_update_10(self): + table1 = self.tables.mytable + + v1 = {table1.c.name: table1.c.myid} + v2 = {table1.c.name: table1.c.name + 'foo'} + self.assert_compile( + update(table1, table1.c.myid == 12, values=v1).values(v2), + 'UPDATE mytable ' + 'SET ' + 'name=(mytable.name || :name_1), ' + 'description=:description ' + 'WHERE mytable.myid = :myid_1', + params={'description': 'test'}) + + def test_update_11(self): + table1 = self.tables.mytable + + values = { + table1.c.name: table1.c.name + 'lala', + table1.c.myid: func.do_stuff(table1.c.myid, literal('hoho')) + } + self.assert_compile(update(table1, + (table1.c.myid == func.hoho(4)) & + (table1.c.name == literal('foo') + + table1.c.name + literal('lala')), + values=values), + 'UPDATE mytable ' + 'SET ' + 'myid=do_stuff(mytable.myid, :param_1), ' + 'name=(mytable.name || :name_1) ' + 'WHERE ' + 'mytable.myid = hoho(:hoho_1) AND ' + 'mytable.name = :param_2 || mytable.name || :param_3') + + def test_prefix_with(self): + table1 = self.tables.mytable + + stmt = table1.update().\ + prefix_with('A', 'B', dialect='mysql').\ + prefix_with('C', 'D') + + self.assert_compile(stmt, + 'UPDATE C D mytable SET myid=:myid, name=:name, ' + 'description=:description') + + self.assert_compile(stmt, + 'UPDATE A B C D mytable SET myid=%s, name=%s, description=%s', + dialect=mysql.dialect()) + + def test_alias(self): + table1 = self.tables.mytable + talias1 = table1.alias('t1') + + self.assert_compile(update(talias1, talias1.c.myid == 7), + 'UPDATE mytable AS t1 ' + 'SET name=:name ' + 'WHERE t1.myid = :myid_1', + params={table1.c.name: 'fred'}) + + self.assert_compile(update(talias1, table1.c.myid == 7), + 'UPDATE mytable AS t1 ' + 'SET name=:name ' + 'FROM mytable ' + 'WHERE mytable.myid = :myid_1', + params={table1.c.name: 'fred'}) + + def test_update_to_expression(self): + """test update from an expression. + + this logic is triggered currently by a left side that doesn't + have a key. The current supported use case is updating the index + of a Postgresql ARRAY type. + + """ + table1 = self.tables.mytable + expr = func.foo(table1.c.myid) + assert not hasattr(expr, 'key') + self.assert_compile(table1.update().values({expr: 'bar'}), + 'UPDATE mytable SET foo(myid)=:param_1') + + +class UpdateFromCompileTest(_UpdateFromTestBase, fixtures.TablesTest, + AssertsCompiledSQL): __dialect__ = 'default' run_create_tables = run_inserts = run_deletes = None def test_render_table(self): users, addresses = self.tables.users, self.tables.addresses + self.assert_compile( - users.update().\ - values(name='newname').\ - where(users.c.id==addresses.c.user_id).\ - where(addresses.c.email_address=='e1'), - "UPDATE users SET name=:name FROM addresses " - "WHERE users.id = addresses.user_id AND " - "addresses.email_address = :email_address_1", - checkparams={u'email_address_1': 'e1', 'name': 'newname'} - ) + users.update(). + values(name='newname'). + where(users.c.id == addresses.c.user_id). + where(addresses.c.email_address == 'e1'), + 'UPDATE users ' + 'SET name=:name FROM addresses ' + 'WHERE ' + 'users.id = addresses.user_id AND ' + 'addresses.email_address = :email_address_1', + checkparams={u'email_address_1': 'e1', 'name': 'newname'}) def test_render_multi_table(self): - users, addresses, dingalings = \ - self.tables.users, \ - self.tables.addresses, \ - self.tables.dingalings + users = self.tables.users + addresses = self.tables.addresses + dingalings = self.tables.dingalings + + checkparams = { + u'email_address_1': 'e1', + u'id_1': 2, + 'name': 'newname' + } + self.assert_compile( - users.update().\ - values(name='newname').\ - where(users.c.id==addresses.c.user_id).\ - where(addresses.c.email_address=='e1').\ - where(addresses.c.id==dingalings.c.address_id).\ - where(dingalings.c.id==2), - "UPDATE users SET name=:name FROM addresses, " - "dingalings WHERE users.id = addresses.user_id " - "AND addresses.email_address = :email_address_1 " - "AND addresses.id = dingalings.address_id AND " - "dingalings.id = :id_1", - checkparams={u'email_address_1': 'e1', u'id_1': 2, - 'name': 'newname'} - ) + users.update(). + values(name='newname'). + where(users.c.id == addresses.c.user_id). + where(addresses.c.email_address == 'e1'). + where(addresses.c.id == dingalings.c.address_id). + where(dingalings.c.id == 2), + 'UPDATE users ' + 'SET name=:name ' + 'FROM addresses, dingalings ' + 'WHERE ' + 'users.id = addresses.user_id AND ' + 'addresses.email_address = :email_address_1 AND ' + 'addresses.id = dingalings.address_id AND ' + 'dingalings.id = :id_1', + checkparams=checkparams) def test_render_table_mysql(self): users, addresses = self.tables.users, self.tables.addresses + self.assert_compile( - users.update().\ - values(name='newname').\ - where(users.c.id==addresses.c.user_id).\ - where(addresses.c.email_address=='e1'), - "UPDATE users, addresses SET users.name=%s " - "WHERE users.id = addresses.user_id AND " - "addresses.email_address = %s", + users.update(). + values(name='newname'). + where(users.c.id == addresses.c.user_id). + where(addresses.c.email_address == 'e1'), + 'UPDATE users, addresses ' + 'SET users.name=%s ' + 'WHERE ' + 'users.id = addresses.user_id AND ' + 'addresses.email_address = %s', checkparams={u'email_address_1': 'e1', 'name': 'newname'}, - dialect=mysql.dialect() - ) + dialect=mysql.dialect()) def test_render_subquery(self): users, addresses = self.tables.users, self.tables.addresses - subq = select([addresses.c.id, - addresses.c.user_id, - addresses.c.email_address]).\ - where(addresses.c.id==7).alias() + + checkparams = { + u'email_address_1': 'e1', + u'id_1': 7, + 'name': 'newname' + } + + cols = [ + addresses.c.id, + addresses.c.user_id, + addresses.c.email_address + ] + + subq = select(cols).where(addresses.c.id == 7).alias() self.assert_compile( - users.update().\ - values(name='newname').\ - where(users.c.id==subq.c.user_id).\ - where(subq.c.email_address=='e1'), - "UPDATE users SET name=:name FROM " - "(SELECT addresses.id AS id, addresses.user_id " - "AS user_id, addresses.email_address AS " - "email_address FROM addresses WHERE addresses.id = " - ":id_1) AS anon_1 WHERE users.id = anon_1.user_id " - "AND anon_1.email_address = :email_address_1", - checkparams={u'email_address_1': 'e1', - u'id_1': 7, 'name': 'newname'} - ) + users.update(). + values(name='newname'). + where(users.c.id == subq.c.user_id). + where(subq.c.email_address == 'e1'), + 'UPDATE users ' + 'SET name=:name FROM (' + 'SELECT ' + 'addresses.id AS id, ' + 'addresses.user_id AS user_id, ' + 'addresses.email_address AS email_address ' + 'FROM addresses ' + 'WHERE addresses.id = :id_1' + ') AS anon_1 ' + 'WHERE users.id = anon_1.user_id ' + 'AND anon_1.email_address = :email_address_1', + checkparams=checkparams) + class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): @testing.requires.update_from def test_exec_two_table(self): users, addresses = self.tables.users, self.tables.addresses + testing.db.execute( - addresses.update().\ - values(email_address=users.c.name).\ - where(users.c.id==addresses.c.user_id).\ - where(users.c.name=='ed') - ) - eq_( - testing.db.execute( - addresses.select().\ - order_by(addresses.c.id)).fetchall(), - [ - (1, 7, 'x', "jack@bean.com"), - (2, 8, 'x', "ed"), - (3, 8, 'x', "ed"), - (4, 8, 'x', "ed"), - (5, 9, 'x', "fred@fred.com") - ] - ) + addresses.update(). + values(email_address=users.c.name). + where(users.c.id == addresses.c.user_id). + where(users.c.name == 'ed')) + + expected = [ + (1, 7, 'x', 'jack@bean.com'), + (2, 8, 'x', 'ed'), + (3, 8, 'x', 'ed'), + (4, 8, 'x', 'ed'), + (5, 9, 'x', 'fred@fred.com')] + self._assert_addresses(addresses, expected) @testing.requires.update_from def test_exec_two_table_plus_alias(self): users, addresses = self.tables.users, self.tables.addresses - a1 = addresses.alias() + a1 = addresses.alias() testing.db.execute( - addresses.update().\ - values(email_address=users.c.name).\ - where(users.c.id==a1.c.user_id).\ - where(users.c.name=='ed').\ - where(a1.c.id==addresses.c.id) - ) - eq_( - testing.db.execute( - addresses.select().\ - order_by(addresses.c.id)).fetchall(), - [ - (1, 7, 'x', "jack@bean.com"), - (2, 8, 'x', "ed"), - (3, 8, 'x', "ed"), - (4, 8, 'x', "ed"), - (5, 9, 'x', "fred@fred.com") - ] + addresses.update(). + values(email_address=users.c.name). + where(users.c.id == a1.c.user_id). + where(users.c.name == 'ed'). + where(a1.c.id == addresses.c.id) ) + expected = [ + (1, 7, 'x', 'jack@bean.com'), + (2, 8, 'x', 'ed'), + (3, 8, 'x', 'ed'), + (4, 8, 'x', 'ed'), + (5, 9, 'x', 'fred@fred.com')] + self._assert_addresses(addresses, expected) + @testing.requires.update_from def test_exec_three_table(self): - users, addresses, dingalings = \ - self.tables.users, \ - self.tables.addresses, \ - self.tables.dingalings + users = self.tables.users + addresses = self.tables.addresses + dingalings = self.tables.dingalings + testing.db.execute( - addresses.update().\ - values(email_address=users.c.name).\ - where(users.c.id==addresses.c.user_id).\ - where(users.c.name=='ed'). - where(addresses.c.id==dingalings.c.address_id).\ - where(dingalings.c.id==1), - ) - eq_( - testing.db.execute( - addresses.select().order_by(addresses.c.id) - ).fetchall(), - [ - (1, 7, 'x', "jack@bean.com"), - (2, 8, 'x', "ed"), - (3, 8, 'x', "ed@bettyboop.com"), - (4, 8, 'x', "ed@lala.com"), - (5, 9, 'x', "fred@fred.com") - ] - ) + addresses.update(). + values(email_address=users.c.name). + where(users.c.id == addresses.c.user_id). + where(users.c.name == 'ed'). + where(addresses.c.id == dingalings.c.address_id). + where(dingalings.c.id == 1)) + + expected = [ + (1, 7, 'x', 'jack@bean.com'), + (2, 8, 'x', 'ed'), + (3, 8, 'x', 'ed@bettyboop.com'), + (4, 8, 'x', 'ed@lala.com'), + (5, 9, 'x', 'fred@fred.com')] + self._assert_addresses(addresses, expected) @testing.only_on('mysql', 'Multi table update') def test_exec_multitable(self): users, addresses = self.tables.users, self.tables.addresses + + values = { + addresses.c.email_address: users.c.name, + users.c.name: 'ed2' + } + testing.db.execute( - addresses.update().\ - values({ - addresses.c.email_address:users.c.name, - users.c.name:'ed2' - }).\ - where(users.c.id==addresses.c.user_id).\ - where(users.c.name=='ed') - ) - eq_( - testing.db.execute( - addresses.select().order_by(addresses.c.id)).fetchall(), - [ - (1, 7, 'x', "jack@bean.com"), - (2, 8, 'x', "ed"), - (3, 8, 'x', "ed"), - (4, 8, 'x', "ed"), - (5, 9, 'x', "fred@fred.com") - ] - ) - eq_( - testing.db.execute( - users.select().order_by(users.c.id)).fetchall(), - [ - (7, 'jack'), - (8, 'ed2'), - (9, 'fred'), - (10, 'chuck') - ] - ) + addresses.update(). + values(values). + where(users.c.id == addresses.c.user_id). + where(users.c.name == 'ed')) + + expected = [ + (1, 7, 'x', 'jack@bean.com'), + (2, 8, 'x', 'ed'), + (3, 8, 'x', 'ed'), + (4, 8, 'x', 'ed'), + (5, 9, 'x', 'fred@fred.com')] + self._assert_addresses(addresses, expected) + + expected = [ + (7, 'jack'), + (8, 'ed2'), + (9, 'fred'), + (10, 'chuck')] + self._assert_users(users, expected) -class UpdateFromMultiTableUpdateDefaultsTest(_UpdateFromTestBase, fixtures.TablesTest): + def _assert_addresses(self, addresses, expected): + stmt = addresses.select().order_by(addresses.c.id) + eq_(testing.db.execute(stmt).fetchall(), expected) + + def _assert_users(self, users, expected): + stmt = users.select().order_by(users.c.id) + eq_(testing.db.execute(stmt).fetchall(), expected) + + +class UpdateFromMultiTableUpdateDefaultsTest(_UpdateFromTestBase, + fixtures.TablesTest): @classmethod def define_tables(cls, metadata): Table('users', metadata, Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), + test_needs_autoincrement=True), Column('name', String(30), nullable=False), - Column('some_update', String(30), onupdate="im the update") - ) + Column('some_update', String(30), onupdate='im the update')) Table('addresses', metadata, Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), + test_needs_autoincrement=True), Column('user_id', None, ForeignKey('users.id')), - Column('email_address', String(50), nullable=False), - ) + Column('email_address', String(50), nullable=False)) @classmethod def fixtures(cls): return dict( - users = ( + users=( ('id', 'name', 'some_update'), (8, 'ed', 'value'), (9, 'fred', 'value'), ), - - addresses = ( + addresses=( ('id', 'user_id', 'email_address'), - (2, 8, "ed@wood.com"), - (3, 8, "ed@bettyboop.com"), - (4, 9, "fred@fred.com") + (2, 8, 'ed@wood.com'), + (3, 8, 'ed@bettyboop.com'), + (4, 9, 'fred@fred.com') ), ) @testing.only_on('mysql', 'Multi table update') def test_defaults_second_table(self): users, addresses = self.tables.users, self.tables.addresses + + values = { + addresses.c.email_address: users.c.name, + users.c.name: 'ed2' + } + ret = testing.db.execute( - addresses.update().\ - values({ - addresses.c.email_address:users.c.name, - users.c.name:'ed2' - }).\ - where(users.c.id==addresses.c.user_id).\ - where(users.c.name=='ed') - ) - eq_( - set(ret.prefetch_cols()), - set([users.c.some_update]) - ) - eq_( - testing.db.execute( - addresses.select().order_by(addresses.c.id)).fetchall(), - [ - (2, 8, "ed"), - (3, 8, "ed"), - (4, 9, "fred@fred.com") - ] - ) - eq_( - testing.db.execute( - users.select().order_by(users.c.id)).fetchall(), - [ - (8, 'ed2', 'im the update'), - (9, 'fred', 'value'), - ] - ) + addresses.update(). + values(values). + where(users.c.id == addresses.c.user_id). + where(users.c.name == 'ed')) + + eq_(set(ret.prefetch_cols()), set([users.c.some_update])) + + expected = [ + (2, 8, 'ed'), + (3, 8, 'ed'), + (4, 9, 'fred@fred.com')] + self._assert_addresses(addresses, expected) + + expected = [ + (8, 'ed2', 'im the update'), + (9, 'fred', 'value')] + self._assert_users(users, expected) @testing.only_on('mysql', 'Multi table update') def test_no_defaults_second_table(self): users, addresses = self.tables.users, self.tables.addresses + ret = testing.db.execute( - addresses.update().\ - values({ - 'email_address':users.c.name, - }).\ - where(users.c.id==addresses.c.user_id).\ - where(users.c.name=='ed') - ) - eq_( - ret.prefetch_cols(),[] - ) - eq_( - testing.db.execute( - addresses.select().order_by(addresses.c.id)).fetchall(), - [ - (2, 8, "ed"), - (3, 8, "ed"), - (4, 9, "fred@fred.com") - ] - ) - # users table not actually updated, - # so no onupdate - eq_( - testing.db.execute( - users.select().order_by(users.c.id)).fetchall(), - [ - (8, 'ed', 'value'), - (9, 'fred', 'value'), - ] - ) + addresses.update(). + values({'email_address': users.c.name}). + where(users.c.id == addresses.c.user_id). + where(users.c.name == 'ed')) + + eq_(ret.prefetch_cols(), []) + + expected = [ + (2, 8, 'ed'), + (3, 8, 'ed'), + (4, 9, 'fred@fred.com')] + self._assert_addresses(addresses, expected) + + # users table not actually updated, so no onupdate + expected = [ + (8, 'ed', 'value'), + (9, 'fred', 'value')] + self._assert_users(users, expected) + + def _assert_addresses(self, addresses, expected): + stmt = addresses.select().order_by(addresses.c.id) + eq_(testing.db.execute(stmt).fetchall(), expected) + + def _assert_users(self, users, expected): + stmt = users.select().order_by(users.c.id) + eq_(testing.db.execute(stmt).fetchall(), expected) |