diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-07-07 11:12:31 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-07-11 14:20:10 -0400 |
commit | aceefb508ccd0911f52ff0e50324b3fefeaa3f16 (patch) | |
tree | e57124d3ea8b0e2cd7fe1d3ad22170fa956bcafb /test/base/test_utils.py | |
parent | 5c16367ee78fa1a41d6b715152dcc58f45323d2e (diff) | |
download | sqlalchemy-aceefb508ccd0911f52ff0e50324b3fefeaa3f16.tar.gz |
Allow duplicate columns in from clauses and selectables
The :func:`.select` construct and related constructs now allow for
duplication of column labels and columns themselves in the columns clause,
mirroring exactly how column expressions were passed in. This allows
the tuples returned by an executed result to match what was SELECTed
in the first place, which is how the ORM :class:`.Query` works, so
this establishes better cross-compatibility between the two constructs.
Additionally, it allows column-positioning-sensitive structures such as
UNIONs (i.e. :class:`.CompoundSelect`) to be more intuitively constructed
in those cases where a particular column might appear in more than one
place. To support this change, the :class:`.ColumnCollection` has been
revised to support duplicate columns as well as to allow integer index
access.
Fixes: #4753
Change-Id: Ie09a8116f05c367995c1e43623c51e07971d3bf0
Diffstat (limited to 'test/base/test_utils.py')
-rw-r--r-- | test/base/test_utils.py | 559 |
1 files changed, 447 insertions, 112 deletions
diff --git a/test/base/test_utils.py b/test/base/test_utils.py index d7e4deb28..4f073ebfe 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -9,7 +9,7 @@ from sqlalchemy import sql from sqlalchemy import testing from sqlalchemy import util from sqlalchemy.sql import column -from sqlalchemy.sql.base import SeparateKeyColumnCollection +from sqlalchemy.sql.base import DedupeColumnCollection from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ @@ -18,6 +18,8 @@ from sqlalchemy.testing import fails_if from sqlalchemy.testing import fixtures from sqlalchemy.testing import in_ from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_false +from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing import ne_ from sqlalchemy.testing.util import gc_collect @@ -442,60 +444,110 @@ class ToListTest(fixtures.TestBase): ) -class SeparateKeysColumnCollectionTest( - testing.AssertsCompiledSQL, fixtures.TestBase -): - def test_in(self): - cc = SeparateKeyColumnCollection() - cc["kcol1"] = sql.column("col1") - cc["kcol2"] = sql.column("col2") - cc["kcol3"] = sql.column("col3") - assert "col1" not in cc - assert "kcol2" in cc +class ColumnCollectionCommon(testing.AssertsCompiledSQL): + def _assert_collection_integrity(self, coll): + eq_(coll._colset, set(c for k, c in coll._collection)) + d = {} + for k, col in coll._collection: + d.setdefault(k, col) + d.update({idx: col for idx, (k, col) in enumerate(coll._collection)}) + eq_(coll._index, d) + + def test_keys(self): + c1, c2, c3 = sql.column("c1"), sql.column("c2"), sql.column("c3") + c2.key = "foo" + cc = self._column_collection( + columns=[("c1", c1), ("foo", c2), ("c3", c3)] + ) + eq_(cc.keys(), ["c1", "foo", "c3"]) - def test_get(self): - c1, c2 = sql.column("col1"), sql.column("col2") - cc = SeparateKeyColumnCollection([("kcol1", c1), ("kcol2", c2)]) - is_(cc.kcol1, c1) - 
is_(cc.kcol2, c2) + ci = cc.as_immutable() + eq_(ci.keys(), ["c1", "foo", "c3"]) + + def test_key_index_error(self): + cc = self._column_collection( + columns=[ + ("col1", sql.column("col1")), + ("col2", sql.column("col2")), + ] + ) + assert_raises(KeyError, lambda: cc["foo"]) + assert_raises(KeyError, lambda: cc[object()]) + assert_raises(IndexError, lambda: cc[5]) - def test_all_cols(self): - c1, c2 = sql.column("col1"), sql.column("col2") - cc = SeparateKeyColumnCollection([("kcol1", c1), ("kcol2", c2)]) - eq_(cc._all_columns, [c1, c2]) + def test_contains_column(self): + c1, c2, c3 = sql.column("c1"), sql.column("c2"), sql.column("c3") + cc = self._column_collection(columns=[("c1", c1), ("c2", c2)]) + is_true(cc.contains_column(c1)) + is_false(cc.contains_column(c3)) -class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): def test_in(self): - cc = sql.ColumnCollection() - cc.add(sql.column("col1")) - cc.add(sql.column("col2")) - cc.add(sql.column("col3")) + col1 = sql.column("col1") + cc = self._column_collection( + columns=[ + ("col1", col1), + ("col2", sql.column("col2")), + ("col3", sql.column("col3")), + ] + ) assert "col1" in cc assert "col2" in cc - try: - cc["col1"] in cc - assert False - except exc.ArgumentError as e: - eq_(str(e), "__contains__ requires a string argument") + assert_raises_message( + exc.ArgumentError, + "__contains__ requires a string argument", + lambda: col1 in cc, + ) def test_compare(self): - cc1 = sql.ColumnCollection() - cc2 = sql.ColumnCollection() - cc3 = sql.ColumnCollection() c1 = sql.column("col1") c2 = c1.label("col2") c3 = sql.column("col3") - cc1.add(c1) - cc2.add(c2) - cc3.add(c3) - assert (cc1 == cc2).compare(c1 == c2) - assert not (cc1 == cc3).compare(c2 == c3) - @testing.emits_warning("Column ") + is_true( + self._column_collection( + [("col1", c1), ("col2", c2), ("col3", c3)] + ).compare( + self._column_collection( + [("col1", c1), ("col2", c2), ("col3", c3)] + ) + ) + ) + is_false( + 
self._column_collection( + [("col1", c1), ("col2", c2), ("col3", c3)] + ).compare(self._column_collection([("col1", c1), ("col2", c2)])) + ) + + +class ColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase): + def _column_collection(self, columns=None): + return sql.ColumnCollection(columns=columns) + + def test_separate_key_all_cols(self): + c1, c2 = sql.column("col1"), sql.column("col2") + cc = self._column_collection([("kcol1", c1), ("kcol2", c2)]) + eq_(cc._all_columns, [c1, c2]) + + def test_separate_key_get(self): + c1, c2 = sql.column("col1"), sql.column("col2") + cc = self._column_collection([("kcol1", c1), ("kcol2", c2)]) + is_(cc.kcol1, c1) + is_(cc.kcol2, c2) + + def test_separate_key_in(self): + cc = self._column_collection( + columns=[ + ("kcol1", sql.column("col1")), + ("kcol2", sql.column("col2")), + ("kcol3", sql.column("col3")), + ] + ) + assert "col1" not in cc + assert "kcol2" in cc + def test_dupes_add(self): - cc = sql.ColumnCollection() c1, c2a, c3, c2b = ( column("c1"), @@ -504,27 +556,198 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): column("c2"), ) + cc = sql.ColumnCollection() + cc.add(c1) - cc.add(c2a) + cc.add(c2a, "c2") cc.add(c3) cc.add(c2b) eq_(cc._all_columns, [c1, c2a, c3, c2b]) + eq_(list(cc), [c1, c2a, c3, c2b]) + eq_(cc.keys(), ["c1", "c2", "c3", "c2"]) + + assert cc.contains_column(c2a) + assert cc.contains_column(c2b) + + # this is deterministic + is_(cc["c2"], c2a) + + self._assert_collection_integrity(cc) + + ci = cc.as_immutable() + eq_(ci._all_columns, [c1, c2a, c3, c2b]) + eq_(list(ci), [c1, c2a, c3, c2b]) + eq_(ci.keys(), ["c1", "c2", "c3", "c2"]) + + def test_dupes_construct(self): + + c1, c2a, c3, c2b = ( + column("c1"), + column("c2"), + column("c3"), + column("c2"), + ) + + cc = sql.ColumnCollection( + columns=[("c1", c1), ("c2", c2a), ("c3", c3), ("c2", c2b)] + ) + + eq_(cc._all_columns, [c1, c2a, c3, c2b]) + + eq_(list(cc), [c1, c2a, c3, c2b]) + eq_(cc.keys(), ["c1", "c2", 
"c3", "c2"]) + + assert cc.contains_column(c2a) + assert cc.contains_column(c2b) + + # this is deterministic + is_(cc["c2"], c2a) + + self._assert_collection_integrity(cc) + + ci = cc.as_immutable() + eq_(ci._all_columns, [c1, c2a, c3, c2b]) + eq_(list(ci), [c1, c2a, c3, c2b]) + eq_(ci.keys(), ["c1", "c2", "c3", "c2"]) + + def test_identical_dupe_construct(self): + + c1, c2, c3 = (column("c1"), column("c2"), column("c3")) + + cc = sql.ColumnCollection( + columns=[("c1", c1), ("c2", c2), ("c3", c3), ("c2", c2)] + ) + + eq_(cc._all_columns, [c1, c2, c3, c2]) + # for iter, c2a is replaced by c2b, ordering # is maintained in that way. ideally, iter would be # the same as the "_all_columns" collection. + eq_(list(cc), [c1, c2, c3, c2]) + + assert cc.contains_column(c2) + self._assert_collection_integrity(cc) + + ci = cc.as_immutable() + eq_(ci._all_columns, [c1, c2, c3, c2]) + eq_(list(ci), [c1, c2, c3, c2]) + + +class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase): + def _column_collection(self, columns=None): + return DedupeColumnCollection(columns=columns) + + def test_separate_key_cols(self): + c1, c2 = sql.column("col1"), sql.column("col2") + assert_raises_message( + exc.ArgumentError, + "DedupeColumnCollection requires columns be under " + "the same key as their .key", + self._column_collection, + [("kcol1", c1), ("kcol2", c2)], + ) + + cc = self._column_collection() + assert_raises_message( + exc.ArgumentError, + "DedupeColumnCollection requires columns be under " + "the same key as their .key", + cc.add, + c1, + "kcol1", + ) + + def test_pickle_w_mutation(self): + c1, c2, c3 = sql.column("c1"), sql.column("c2"), sql.column("c3") + + c2.key = "foo" + + cc = self._column_collection(columns=[("c1", c1), ("foo", c2)]) + ci = cc.as_immutable() + + d = {"cc": cc, "ci": ci} + + for loads, dumps in picklers(): + dp = loads(dumps(d)) + + cp = dp["cc"] + cpi = dp["ci"] + self._assert_collection_integrity(cp) + self._assert_collection_integrity(cpi) 
+ + assert cp._colset is cpi._colset + assert cp._index is cpi._index + assert cp._collection is cpi._collection + + cp.add(c3) + + eq_(cp.keys(), ["c1", "foo", "c3"]) + eq_(cpi.keys(), ["c1", "foo", "c3"]) + + assert cp.contains_column(c3) + assert cpi.contains_column(c3) + + def test_keys_after_replace(self): + c1, c2, c3 = sql.column("c1"), sql.column("c2"), sql.column("c3") + c2.key = "foo" + cc = self._column_collection( + columns=[("c1", c1), ("foo", c2), ("c3", c3)] + ) + eq_(cc.keys(), ["c1", "foo", "c3"]) + + c4 = sql.column("c3") + cc.replace(c4) + eq_(cc.keys(), ["c1", "foo", "c3"]) + self._assert_collection_integrity(cc) + + def test_dupes_add_dedupe(self): + cc = DedupeColumnCollection() + + c1, c2a, c3, c2b = ( + column("c1"), + column("c2"), + column("c3"), + column("c2"), + ) + + cc.add(c1) + cc.add(c2a) + cc.add(c3) + cc.add(c2b) + + eq_(cc._all_columns, [c1, c2b, c3]) + eq_(list(cc), [c1, c2b, c3]) - assert cc.contains_column(c2a) + assert not cc.contains_column(c2a) assert cc.contains_column(c2b) + self._assert_collection_integrity(cc) - ci = cc.as_immutable() - eq_(ci._all_columns, [c1, c2a, c3, c2b]) - eq_(list(ci), [c1, c2b, c3]) + def test_dupes_construct_dedupe(self): - def test_identical_dupe_add(self): - cc = sql.ColumnCollection() + c1, c2a, c3, c2b = ( + column("c1"), + column("c2"), + column("c3"), + column("c2"), + ) + + cc = DedupeColumnCollection( + columns=[("c1", c1), ("c2", c2a), ("c3", c3), ("c2", c2b)] + ) + + eq_(cc._all_columns, [c1, c2b, c3]) + + eq_(list(cc), [c1, c2b, c3]) + + assert not cc.contains_column(c2a) + assert cc.contains_column(c2b) + self._assert_collection_integrity(cc) + + def test_identical_dupe_add_dedupes(self): + cc = DedupeColumnCollection() c1, c2, c3 = (column("c1"), column("c2"), column("c3")) @@ -535,27 +758,43 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): eq_(cc._all_columns, [c1, c2, c3]) - self.assert_compile( - cc == [c1, c2, c3], "c1 = c1 AND c2 = c2 AND c3 = c3" - 
) - # for iter, c2a is replaced by c2b, ordering # is maintained in that way. ideally, iter would be # the same as the "_all_columns" collection. eq_(list(cc), [c1, c2, c3]) assert cc.contains_column(c2) + self._assert_collection_integrity(cc) ci = cc.as_immutable() eq_(ci._all_columns, [c1, c2, c3]) eq_(list(ci), [c1, c2, c3]) - self.assert_compile( - ci == [c1, c2, c3], "c1 = c1 AND c2 = c2 AND c3 = c3" + def test_identical_dupe_construct_dedupes(self): + + c1, c2, c3 = (column("c1"), column("c2"), column("c3")) + + cc = DedupeColumnCollection( + columns=[("c1", c1), ("c2", c2), ("c3", c3), ("c2", c2)] ) + eq_(cc._all_columns, [c1, c2, c3]) + + # for iter, c2a is replaced by c2b, ordering + # is maintained in that way. ideally, iter would be + # the same as the "_all_columns" collection. + eq_(list(cc), [c1, c2, c3]) + + assert cc.contains_column(c2) + self._assert_collection_integrity(cc) + + ci = cc.as_immutable() + eq_(ci._all_columns, [c1, c2, c3]) + eq_(list(ci), [c1, c2, c3]) + def test_replace(self): - cc = sql.ColumnCollection() + cc = DedupeColumnCollection() + ci = cc.as_immutable() c1, c2a, c3, c2b = ( column("c1"), @@ -572,16 +811,49 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): eq_(cc._all_columns, [c1, c2b, c3]) eq_(list(cc), [c1, c2b, c3]) + is_(cc[1], c2b) assert not cc.contains_column(c2a) assert cc.contains_column(c2b) + self._assert_collection_integrity(cc) + eq_(ci._all_columns, [c1, c2b, c3]) + eq_(list(ci), [c1, c2b, c3]) + is_(ci[1], c2b) + + def test_replace_key_matches_name_of_another(self): + cc = DedupeColumnCollection() ci = cc.as_immutable() + + c1, c2a, c3, c2b = ( + column("c1"), + column("c2"), + column("c3"), + column("c4"), + ) + c2b.key = "c2" + + cc.add(c1) + cc.add(c2a) + cc.add(c3) + + cc.replace(c2b) + + eq_(cc._all_columns, [c1, c2b, c3]) + eq_(list(cc), [c1, c2b, c3]) + is_(cc[1], c2b) + self._assert_collection_integrity(cc) + + assert not cc.contains_column(c2a) + assert 
cc.contains_column(c2b) + eq_(ci._all_columns, [c1, c2b, c3]) eq_(list(ci), [c1, c2b, c3]) + is_(ci[1], c2b) def test_replace_key_matches(self): - cc = sql.ColumnCollection() + cc = DedupeColumnCollection() + ci = cc.as_immutable() c1, c2a, c3, c2b = ( column("c1"), @@ -599,16 +871,21 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): assert not cc.contains_column(c2a) assert cc.contains_column(c2b) + is_(cc[1], c2b) + assert_raises(IndexError, lambda: cc[3]) + self._assert_collection_integrity(cc) eq_(cc._all_columns, [c1, c2b, c3]) eq_(list(cc), [c1, c2b, c3]) - ci = cc.as_immutable() eq_(ci._all_columns, [c1, c2b, c3]) eq_(list(ci), [c1, c2b, c3]) + is_(ci[1], c2b) + assert_raises(IndexError, lambda: ci[3]) def test_replace_name_matches(self): - cc = sql.ColumnCollection() + cc = DedupeColumnCollection() + ci = cc.as_immutable() c1, c2a, c3, c2b = ( column("c1"), @@ -628,14 +905,19 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): assert cc.contains_column(c2b) eq_(cc._all_columns, [c1, c2b, c3]) - eq_(list(cc), [c1, c3, c2b]) + eq_(list(cc), [c1, c2b, c3]) + eq_(len(cc), 3) + is_(cc[1], c2b) + self._assert_collection_integrity(cc) - ci = cc.as_immutable() eq_(ci._all_columns, [c1, c2b, c3]) - eq_(list(ci), [c1, c3, c2b]) + eq_(list(ci), [c1, c2b, c3]) + eq_(len(ci), 3) + is_(ci[1], c2b) def test_replace_no_match(self): - cc = sql.ColumnCollection() + cc = DedupeColumnCollection() + ci = cc.as_immutable() c1, c2, c3, c4 = column("c1"), column("c2"), column("c3"), column("c4") c4.key = "X" @@ -651,42 +933,102 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): eq_(cc._all_columns, [c1, c2, c3, c4]) eq_(list(cc), [c1, c2, c3, c4]) + is_(cc[3], c4) + self._assert_collection_integrity(cc) - ci = cc.as_immutable() eq_(ci._all_columns, [c1, c2, c3, c4]) eq_(list(ci), [c1, c2, c3, c4]) + is_(ci[3], c4) - def test_dupes_extend(self): - cc = sql.ColumnCollection() + def 
test_replace_switch_key_name(self): + c1 = column("id") + c2 = column("street") + c3 = column("user_id") - c1, c2a, c3, c2b = ( - column("c1"), - column("c2"), - column("c3"), - column("c2"), + cc = DedupeColumnCollection( + columns=[("id", c1), ("street", c2), ("user_id", c3)] ) - cc.add(c1) - cc.add(c2a) + # for replace col with different key than name, it necessarily + # removes two columns - cc.extend([c3, c2b]) + c4 = column("id") + c4.key = "street" - eq_(cc._all_columns, [c1, c2a, c3, c2b]) + cc.replace(c4) - # for iter, c2a is replaced by c2b, ordering - # is maintained in that way. ideally, iter would be - # the same as the "_all_columns" collection. - eq_(list(cc), [c1, c2b, c3]) + eq_(list(cc), [c4, c3]) + self._assert_collection_integrity(cc) - assert cc.contains_column(c2a) - assert cc.contains_column(c2b) + def test_remove(self): + + c1, c2, c3 = column("c1"), column("c2"), column("c3") + cc = DedupeColumnCollection( + columns=[("c1", c1), ("c2", c2), ("c3", c3)] + ) ci = cc.as_immutable() - eq_(ci._all_columns, [c1, c2a, c3, c2b]) - eq_(list(ci), [c1, c2b, c3]) - def test_dupes_update(self): - cc = sql.ColumnCollection() + eq_(cc._all_columns, [c1, c2, c3]) + eq_(list(cc), [c1, c2, c3]) + assert cc.contains_column(c2) + assert "c2" in cc + + eq_(ci._all_columns, [c1, c2, c3]) + eq_(list(ci), [c1, c2, c3]) + assert ci.contains_column(c2) + assert "c2" in ci + + cc.remove(c2) + + eq_(cc._all_columns, [c1, c3]) + eq_(list(cc), [c1, c3]) + is_(cc[0], c1) + is_(cc[1], c3) + assert not cc.contains_column(c2) + assert "c2" not in cc + self._assert_collection_integrity(cc) + + eq_(ci._all_columns, [c1, c3]) + eq_(list(ci), [c1, c3]) + is_(ci[0], c1) + is_(ci[1], c3) + assert not ci.contains_column(c2) + assert "c2" not in ci + + assert_raises(IndexError, lambda: ci[2]) + + def test_remove_doesnt_change_iteration(self): + + c1, c2, c3, c4, c5 = ( + column("c1"), + column("c2"), + column("c3"), + column("c4"), + column("c5"), + ) + + cc = 
DedupeColumnCollection( + columns=[ + ("c1", c1), + ("c2", c2), + ("c3", c3), + ("c4", c4), + ("c5", c5), + ] + ) + + for col in cc: + if col.name not in ["c1", "c2"]: + cc.remove(col) + + eq_(cc.keys(), ["c1", "c2"]) + eq_([c.name for c in cc], ["c1", "c2"]) + self._assert_collection_integrity(cc) + + def test_dupes_extend(self): + cc = DedupeColumnCollection() + ci = cc.as_immutable() c1, c2a, c3, c2b = ( column("c1"), @@ -698,20 +1040,29 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): cc.add(c1) cc.add(c2a) - cc.update([(c3.key, c3), (c2b.key, c2b)]) + cc.extend([c3, c2b]) # this should remove c2a - eq_(cc._all_columns, [c1, c2a, c3, c2b]) + eq_(cc._all_columns, [c1, c2b, c3]) + eq_(list(cc), [c1, c2b, c3]) + is_(cc[1], c2b) + is_(cc[2], c3) + assert_raises(IndexError, lambda: cc[3]) + self._assert_collection_integrity(cc) - assert cc.contains_column(c2a) + assert not cc.contains_column(c2a) assert cc.contains_column(c2b) - # for iter, c2a is replaced by c2b, ordering - # is maintained in that way. ideally, iter would be - # the same as the "_all_columns" collection. 
- eq_(list(cc), [c1, c2b, c3]) + eq_(ci._all_columns, [c1, c2b, c3]) + eq_(list(ci), [c1, c2b, c3]) + is_(ci[1], c2b) + is_(ci[2], c3) + assert_raises(IndexError, lambda: ci[3]) - def test_extend_existing(self): - cc = sql.ColumnCollection() + assert not ci.contains_column(c2a) + assert ci.contains_column(c2b) + + def test_extend_existing_maintains_ordering(self): + cc = DedupeColumnCollection() c1, c2, c3, c4, c5 = ( column("c1"), @@ -723,32 +1074,16 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): cc.extend([c1, c2]) eq_(cc._all_columns, [c1, c2]) + self._assert_collection_integrity(cc) cc.extend([c3]) eq_(cc._all_columns, [c1, c2, c3]) - cc.extend([c4, c2, c5]) + self._assert_collection_integrity(cc) - eq_(cc._all_columns, [c1, c2, c3, c4, c5]) - - def test_update_existing(self): - cc = sql.ColumnCollection() - - c1, c2, c3, c4, c5 = ( - column("c1"), - column("c2"), - column("c3"), - column("c4"), - column("c5"), - ) - - cc.update([("c1", c1), ("c2", c2)]) - eq_(cc._all_columns, [c1, c2]) - - cc.update([("c3", c3)]) - eq_(cc._all_columns, [c1, c2, c3]) - cc.update([("c4", c4), ("c2", c2), ("c5", c5)]) + cc.extend([c4, c2, c5]) eq_(cc._all_columns, [c1, c2, c3, c4, c5]) + self._assert_collection_integrity(cc) class LRUTest(fixtures.TestBase): |