diff options
74 files changed, 4511 insertions, 2021 deletions
diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 000000000..5d6c2bdc4 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,5 @@ +[run] +include=lib/sqlalchemy/* + +[report] +omit=lib/sqlalchemy/testing/*
\ No newline at end of file diff --git a/.gitignore b/.gitignore index c22e53c39..55066f843 100644 --- a/.gitignore +++ b/.gitignore @@ -6,10 +6,12 @@ /doc/build/output/ /dogpile_data/ *.orig +*,cover /.tox .venv *.egg-info .coverage +coverage.xml .*,cover *.class *.so diff --git a/doc/build/changelog/changelog_09.rst b/doc/build/changelog/changelog_09.rst index 44a2add71..e3d9175cb 100644 --- a/doc/build/changelog/changelog_09.rst +++ b/doc/build/changelog/changelog_09.rst @@ -14,6 +14,101 @@ :version: 0.9.8 .. change:: + :tags: bug, sqlite + :versions: 1.0.0 + :tickets: 3211 + + When selecting from a UNION using an attached database file, + the pysqlite driver reports column names in cursor.description + as 'dbname.tablename.colname', instead of 'tablename.colname' as + it normally does for a UNION (note that it's supposed to just be + 'colname' for both, but we work around it). The column translation + logic here has been adjusted to retrieve the rightmost token, rather + than the second token, so it works in both cases. Workaround + courtesy Tony Roberts. + + .. change:: + :tags: bug, postgresql + :versions: 1.0.0 + :tickets: 3021 + + A revisit to this issue first patched in 0.9.5, apparently + psycopg2's ``.closed`` accessor is not as reliable as we assumed, + so we have added an explicit check for the exception messages + "SSL SYSCALL error: Bad file descriptor" and + "SSL SYSCALL error: EOF detected" when detecting an + is-disconnect scenario. We will continue to consult psycopg2's + connection.closed as a first check. + + .. change:: + :tags: bug, orm, engine + :versions: 1.0.0 + :tickets: 3197 + + Fixed bug that affected generally the same classes of event + as that of :ticket:`3199`, when the ``named=True`` parameter + would be used. Some events would fail to register, and others + would not invoke the event arguments correctly, generally in the + case of when an event was "wrapped" for adaption in some other way. 
+ The "named" mechanics have been rearranged to not interfere with + the argument signature expected by internal wrapper functions. + + .. change:: + :tags: bug, declarative + :versions: 1.0.0 + :tickets: 3208 + + Fixed an unlikely race condition observed in some exotic end-user + setups, where the attempt to check for "duplicate class name" in + declarative would hit upon a not-totally-cleaned-up weak reference + related to some other class being removed; the check here now ensures + the weakref still references an object before calling upon it further. + + .. change:: + :tags: bug, orm + :versions: 1.0.0 + :tickets: 3199 + + Fixed bug that affected many classes of event, particularly + ORM events but also engine events, where the usual logic of + "de duplicating" a redundant call to :func:`.event.listen` + with the same arguments would fail, for those events where the + listener function is wrapped. An assertion would be hit within + registry.py. This assertion has now been integrated into the + deduplication check, with the added bonus of a simpler means + of checking deduplication across the board. + + .. change:: + :tags: bug, mssql + :versions: 1.0.0 + :tickets: 3151 + + Fixed the version string detection in the pymssql dialect to + work with Microsoft SQL Azure, which changes the word "SQL Server" + to "SQL Azure". + + .. change:: + :tags: bug, orm + :versions: 1.0.0 + :tickets: 3194 + + Fixed warning that would emit when a complex self-referential + primaryjoin contained functions, while at the same time remote_side + was specified; the warning would suggest setting "remote side". + It now only emits if remote_side isn't present. + + .. change:: + :tags: bug, ext + :versions: 1.0.0 + :tickets: 3191 + + Fixed bug in ordering list where the order of items would be + thrown off during a collection replace event, if the + reorder_on_append flag were set to True. 
The fix ensures that the + ordering list only impacts the list that is explicitly associated + with the object. + + .. change:: :tags: bug, sql :versions: 1.0.0 :tickets: 3180 diff --git a/doc/build/changelog/changelog_10.rst b/doc/build/changelog/changelog_10.rst index 9c7f207cc..4d5ab1f06 100644 --- a/doc/build/changelog/changelog_10.rst +++ b/doc/build/changelog/changelog_10.rst @@ -22,6 +22,128 @@ on compatibility concerns, see :doc:`/changelog/migration_10`. .. change:: + :tags: bug, sql, engine + :tickets: 3215 + + Fixed bug where a "branched" connection, that is the kind you get + when you call :meth:`.Connection.connect`, would not share invalidation + status with the parent. The architecture of branching has been tweaked + a bit so that the branched connection defers to the parent for + all invalidation status and operations. + + .. change:: + :tags: bug, sql, engine + :tickets: 3190 + + Fixed bug where a "branched" connection, that is the kind you get + when you call :meth:`.Connection.connect`, would not share transaction + status with the parent. The architecture of branching has been tweaked + a bit so that the branched connection defers to the parent for + all transactional status and operations. + + .. change:: + :tags: bug, declarative + :tickets: 2670 + + A relationship set up with :class:`.declared_attr` on + a :class:`.AbstractConcreteBase` base class will now be configured + on the abstract base mapping automatically, in addition to being + set up on descendant concrete classes as usual. + + .. seealso:: + + :ref:`feature_3150` + + .. change:: + :tags: feature, declarative + :tickets: 3150 + + The :class:`.declared_attr` construct has newly improved + behaviors and features in conjunction with declarative. The + decorated function will now have access to the final column + copies present on the local mixin when invoked, and will also + be invoked exactly once for each mapped class, the returned result + being memoized. 
A new modifier :attr:`.declared_attr.cascading` + is added as well. + + .. seealso:: + + :ref:`feature_3150` + + .. change:: + :tags: feature, ext + :tickets: 3210 + + The :mod:`sqlalchemy.ext.automap` extension will now set + ``cascade="all, delete-orphan"`` automatically on a one-to-many + relationship/backref where the foreign key is detected as containing + one or more non-nullable columns. This argument is present in the + keywords passed to :func:`.automap.generate_relationship` in this + case and can still be overridden. Additionally, if the + :class:`.ForeignKeyConstraint` specifies ``ondelete="CASCADE"`` + for a non-nullable or ``ondelete="SET NULL"`` for a nullable set + of columns, the argument ``passive_deletes=True`` is also added to the + relationship. Note that not all backends support reflection of + ondelete, but backends that do include Postgresql and MySQL. + + .. change:: + :tags: feature, sql + :tickets: 3206 + + Added new method :meth:`.Select.with_statement_hint` and ORM + method :meth:`.Query.with_statement_hint` to support statement-level + hints that are not specific to a table. + + .. change:: + :tags: bug, sqlite + :tickets: 3203 + :pullreq: bitbucket:31 + + SQLite now supports reflection of unique constraints from + temp tables; previously, this would fail with a TypeError. + Pull request courtesy Johannes Erdfelt. + + .. seealso:: + + :ref:`change_3204` - changes regarding SQLite temporary + table and view reflection. + + .. change:: + :tags: bug, sqlite + :tickets: 3204 + + Added :meth:`.Inspector.get_temp_table_names` and + :meth:`.Inspector.get_temp_view_names`; currently, only the + SQLite and Oracle dialects support these methods. 
The return of + temporary table and view names has been **removed** from SQLite and + Oracle's version of :meth:`.Inspector.get_table_names` and + :meth:`.Inspector.get_view_names`; other database backends cannot + support this information (such as MySQL), and the scope of operation + is different in that the tables can be local to a session and + typically aren't supported in remote schemas. + + .. seealso:: + + :ref:`change_3204` + + .. change:: + :tags: feature, postgresql + :tickets: 2891 + :pullreq: github:128 + + Support has been added for reflection of materialized views + and foreign tables, as well as support for materialized views + within :meth:`.Inspector.get_view_names`, and a new method + :meth:`.PGInspector.get_foreign_table_names` available on the + Postgresql version of :class:`.Inspector`. Pull request courtesy + Rodrigo Menezes. + + .. seealso:: + + :ref:`feature_2891` + + + .. change:: :tags: feature, orm Added new event handlers :meth:`.AttributeEvents.init_collection` @@ -268,6 +390,11 @@ default, or a server-side default "eagerly" fetched via RETURNING. .. change:: + :tags: feature, oracle + + Added support for the Oracle table option ON COMMIT. + + .. change:: :tags: feature, postgresql :tickets: 2051 diff --git a/doc/build/changelog/migration_10.rst b/doc/build/changelog/migration_10.rst index 6a48b31fa..0e9dd8d7b 100644 --- a/doc/build/changelog/migration_10.rst +++ b/doc/build/changelog/migration_10.rst @@ -8,7 +8,7 @@ What's New in SQLAlchemy 1.0? undergoing maintenance releases as of May, 2014, and SQLAlchemy version 1.0, as of yet unreleased. - Document last updated: September 7, 2014 + Document last updated: September 25, 2014 Introduction ============ @@ -37,7 +37,7 @@ any SQL expression, in addition to integer values, as arguments. 
The ORM this is used to allow a bound parameter to be passed, which can be substituted with a value later:: - sel = select([table]).limit(bindparam('mylimit')).offset(bindparam('myoffset')) + sel = select([table]).limit(bindparam('mylimit')).offset(bindparam('myoffset')) Dialects which don't support non-integer LIMIT or OFFSET expressions may continue to not support this behavior; third party dialects may also need modification @@ -82,35 +82,35 @@ that a raw load of rows now populates ORM-based objects around 25% faster. Assuming a 1M row table, a script like the following illustrates the type of load that's improved the most:: - import time - from sqlalchemy import Integer, Column, create_engine, Table - from sqlalchemy.orm import Session - from sqlalchemy.ext.declarative import declarative_base + import time + from sqlalchemy import Integer, Column, create_engine, Table + from sqlalchemy.orm import Session + from sqlalchemy.ext.declarative import declarative_base - Base = declarative_base() + Base = declarative_base() - class Foo(Base): - __table__ = Table( - 'foo', Base.metadata, - Column('id', Integer, primary_key=True), - Column('a', Integer(), nullable=False), - Column('b', Integer(), nullable=False), - Column('c', Integer(), nullable=False), - ) + class Foo(Base): + __table__ = Table( + 'foo', Base.metadata, + Column('id', Integer, primary_key=True), + Column('a', Integer(), nullable=False), + Column('b', Integer(), nullable=False), + Column('c', Integer(), nullable=False), + ) - engine = create_engine( - 'mysql+mysqldb://scott:tiger@localhost/test', echo=True) + engine = create_engine( + 'mysql+mysqldb://scott:tiger@localhost/test', echo=True) - sess = Session(engine) + sess = Session(engine) - now = time.time() + now = time.time() - # avoid using all() so that we don't have the overhead of building - # a large list of full objects in memory - for obj in sess.query(Foo).yield_per(100).limit(1000000): - pass + # avoid using all() so that we don't have the 
overhead of building + # a large list of full objects in memory + for obj in sess.query(Foo).yield_per(100).limit(1000000): + pass - print("Total time: %d" % (time.time() - now)) + print("Total time: %d" % (time.time() - now)) Local MacBookPro results bench from 19 seconds for 0.9 down to 14 seconds for 1.0. The :meth:`.Query.yield_per` call is always a good idea when batching @@ -130,7 +130,7 @@ New KeyedTuple implementation dramatically faster We took a look into the :class:`.KeyedTuple` implementation in the hopes of improving queries like this:: - rows = sess.query(Foo.a, Foo.b, Foo.c).all() + rows = sess.query(Foo.a, Foo.b, Foo.c).all() The :class:`.KeyedTuple` class is used rather than Python's ``collections.namedtuple()``, because the latter has a very complex @@ -146,26 +146,26 @@ which scenario. In the "sweet spot", where we are both creating a good number of new types as well as fetching a good number of rows, the lightweight object totally smokes both namedtuple and KeyedTuple:: - ----------------- - size=10 num=10000 # few rows, lots of queries - namedtuple: 3.60302400589 # namedtuple falls over - keyedtuple: 0.255059957504 # KeyedTuple very fast - lw keyed tuple: 0.582715034485 # lw keyed trails right on KeyedTuple - ----------------- - size=100 num=1000 # <--- sweet spot - namedtuple: 0.365247011185 - keyedtuple: 0.24896979332 - lw keyed tuple: 0.0889317989349 # lw keyed blows both away! 
- ----------------- - size=10000 num=100 - namedtuple: 0.572599887848 - keyedtuple: 2.54251694679 - lw keyed tuple: 0.613876104355 - ----------------- - size=1000000 num=10 # few queries, lots of rows - namedtuple: 5.79669594765 # namedtuple very fast - keyedtuple: 28.856498003 # KeyedTuple falls over - lw keyed tuple: 6.74346804619 # lw keyed trails right on namedtuple + ----------------- + size=10 num=10000 # few rows, lots of queries + namedtuple: 3.60302400589 # namedtuple falls over + keyedtuple: 0.255059957504 # KeyedTuple very fast + lw keyed tuple: 0.582715034485 # lw keyed trails right on KeyedTuple + ----------------- + size=100 num=1000 # <--- sweet spot + namedtuple: 0.365247011185 + keyedtuple: 0.24896979332 + lw keyed tuple: 0.0889317989349 # lw keyed blows both away! + ----------------- + size=10000 num=100 + namedtuple: 0.572599887848 + keyedtuple: 2.54251694679 + lw keyed tuple: 0.613876104355 + ----------------- + size=1000000 num=10 # few queries, lots of rows + namedtuple: 5.79669594765 # namedtuple very fast + keyedtuple: 28.856498003 # KeyedTuple falls over + lw keyed tuple: 6.74346804619 # lw keyed trails right on namedtuple :ticket:`3176` @@ -195,27 +195,27 @@ them as duplicates. To illustrate, the following test script will show only ten warnings being emitted for ten of the parameter sets, out of a total of 1000:: - from sqlalchemy import create_engine, Unicode, select, cast - import random - import warnings + from sqlalchemy import create_engine, Unicode, select, cast + import random + import warnings - e = create_engine("sqlite://") + e = create_engine("sqlite://") - # Use the "once" filter (which is also the default for Python - # warnings). Exactly ten of these warnings will - # be emitted; beyond that, the Python warnings registry will accumulate - # new values as dupes of one of the ten existing. - warnings.filterwarnings("once") + # Use the "once" filter (which is also the default for Python + # warnings). 
Exactly ten of these warnings will + # be emitted; beyond that, the Python warnings registry will accumulate + # new values as dupes of one of the ten existing. + warnings.filterwarnings("once") - for i in range(1000): - e.execute(select([cast( - ('foo_%d' % random.randint(0, 1000000)).encode('ascii'), Unicode)])) + for i in range(1000): + e.execute(select([cast( + ('foo_%d' % random.randint(0, 1000000)).encode('ascii'), Unicode)])) The format of the warning here is:: - /path/lib/sqlalchemy/sql/sqltypes.py:186: SAWarning: Unicode type received - non-unicode bind param value 'foo_4852'. (this warning may be - suppressed after 10 occurrences) + /path/lib/sqlalchemy/sql/sqltypes.py:186: SAWarning: Unicode type received + non-unicode bind param value 'foo_4852'. (this warning may be + suppressed after 10 occurrences) :ticket:`3178` @@ -233,15 +233,15 @@ However, as these objects are class-bound descriptors, they must be accessed at the attribute. Below this is illustared using the :attr:`.Mapper.all_orm_descriptors` namespace:: - class SomeObject(Base): - # ... + class SomeObject(Base): + # ... - @hybrid_property - def some_prop(self): - return self.value + 5 + @hybrid_property + def some_prop(self): + return self.value + 5 - inspect(SomeObject).all_orm_descriptors.some_prop.info['foo'] = 'bar' + inspect(SomeObject).all_orm_descriptors.some_prop.info['foo'] = 'bar' It is also available as a constructor argument for all :class:`.SchemaItem` objects (e.g. :class:`.ForeignKey`, :class:`.UniqueConstraint` etc.) 
as well @@ -258,26 +258,26 @@ Change to single-table-inheritance criteria when using from_self(), count() Given a single-table inheritance mapping, such as:: - class Widget(Base): - __table__ = 'widget_table' + class Widget(Base): + __table__ = 'widget_table' - class FooWidget(Widget): - pass + class FooWidget(Widget): + pass Using :meth:`.Query.from_self` or :meth:`.Query.count` against a subclass would produce a subquery, but then add the "WHERE" criteria for subtypes to the outside:: - sess.query(FooWidget).from_self().all() + sess.query(FooWidget).from_self().all() rendering:: - SELECT - anon_1.widgets_id AS anon_1_widgets_id, - anon_1.widgets_type AS anon_1_widgets_type - FROM (SELECT widgets.id AS widgets_id, widgets.type AS widgets_type, - FROM widgets) AS anon_1 - WHERE anon_1.widgets_type IN (?) + SELECT + anon_1.widgets_id AS anon_1_widgets_id, + anon_1.widgets_type AS anon_1_widgets_type + FROM (SELECT widgets.id AS widgets_id, widgets.type AS widgets_type, + FROM widgets) AS anon_1 + WHERE anon_1.widgets_type IN (?) The issue with this is that if the inner query does not specify all columns, then we can't add the WHERE clause on the outside (it actually tries, @@ -286,27 +286,161 @@ apparently goes way back to 0.6.5 with the note "may need to make more adjustments to this". Well, those adjustments have arrived! 
So now the above query will render:: - SELECT - anon_1.widgets_id AS anon_1_widgets_id, - anon_1.widgets_type AS anon_1_widgets_type - FROM (SELECT widgets.id AS widgets_id, widgets.type AS widgets_type, - FROM widgets - WHERE widgets.type IN (?)) AS anon_1 + SELECT + anon_1.widgets_id AS anon_1_widgets_id, + anon_1.widgets_type AS anon_1_widgets_type + FROM (SELECT widgets.id AS widgets_id, widgets.type AS widgets_type, + FROM widgets + WHERE widgets.type IN (?)) AS anon_1 So that queries that don't include "type" will still work!:: - sess.query(FooWidget.id).count() + sess.query(FooWidget.id).count() Renders:: - SELECT count(*) AS count_1 - FROM (SELECT widgets.id AS widgets_id - FROM widgets - WHERE widgets.type IN (?)) AS anon_1 + SELECT count(*) AS count_1 + FROM (SELECT widgets.id AS widgets_id + FROM widgets + WHERE widgets.type IN (?)) AS anon_1 :ticket:`3177` +.. _feature_3150: + +Improvements to declarative mixins, ``@declared_attr`` and related features +---------------------------------------------------------------------------- + +The declarative system in conjunction with :class:`.declared_attr` has been +overhauled to support new capabilities. + +A function decorated with :class:`.declared_attr` is now called only **after** +any mixin-based column copies are generated. This means the function can +call upon mixin-established columns and will receive a reference to the correct +:class:`.Column` object:: + + class HasFooBar(object): + foobar = Column(Integer) + + @declared_attr + def foobar_prop(cls): + return column_property('foobar: ' + cls.foobar) + + class SomeClass(HasFooBar, Base): + __tablename__ = 'some_table' + id = Column(Integer, primary_key=True) + +Above, ``SomeClass.foobar_prop`` will be invoked against ``SomeClass``, +and ``SomeClass.foobar`` will be the final :class:`.Column` object that is +to be mapped to ``SomeClass``, as opposed to the non-copied object present +directly on ``HasFooBar``, even though the columns aren't mapped yet. 
+ +The :class:`.declared_attr` function now **memoizes** the value +that's returned on a per-class basis, so that repeated calls to the same +attribute will return the same value. We can alter the example to illustrate +this:: + + class HasFooBar(object): + @declared_attr + def foobar(cls): + return Column(Integer) + + @declared_attr + def foobar_prop(cls): + return column_property('foobar: ' + cls.foobar) + + class SomeClass(HasFooBar, Base): + __tablename__ = 'some_table' + id = Column(Integer, primary_key=True) + +Previously, ``SomeClass`` would be mapped with one particular copy of +the ``foobar`` column, but the ``foobar_prop`` by calling upon ``foobar`` +a second time would produce a different column. The value of +``SomeClass.foobar`` is now memoized during declarative setup time, so that +even before the attribute is mapped by the mapper, the interim column +value will remain consistent no matter how many times the +:class:`.declared_attr` is called upon. + +The two behaviors above should help considerably with declarative definition +of many types of mapper properties that derive from other attributes, where +the :class:`.declared_attr` function is called upon from other +:class:`.declared_attr` functions locally present before the class is +actually mapped. + +For a pretty slim edge case where one wishes to build a declarative mixin +that establishes distinct columns per subclass, a new modifier +:attr:`.declared_attr.cascading` is added. With this modifier, the +decorated function will be invoked individually for each class in the +mapped inheritance hierarchy. While this is already the behavior for +special attributes such as ``__table_args__`` and ``__mapper_args__``, +for columns and other properties the behavior by default assumes that attribute +is affixed to the base class only, and just inherited from subclasses. 
+With :attr:`.declared_attr.cascading`, individual behaviors can be +applied:: + + class HasSomeAttribute(object): + @declared_attr.cascading + def some_id(cls): + if has_inherited_table(cls): + return Column(ForeignKey('myclass.id'), primary_key=True) + else: + return Column(Integer, primary_key=True) + + return Column('id', Integer, primary_key=True) + + class MyClass(HasSomeAttribute, Base): + "" + # ... + + class MySubClass(MyClass): + "" + # ... + +.. seealso:: + + :ref:`mixin_inheritance_columns` + +Finally, the :class:`.AbstractConcreteBase` class has been reworked +so that a relationship or other mapper property can be set up inline +on the abstract base:: + + from sqlalchemy import Column, Integer, ForeignKey + from sqlalchemy.orm import relationship + from sqlalchemy.ext.declarative import (declarative_base, declared_attr, + AbstractConcreteBase) + + Base = declarative_base() + + class Something(Base): + __tablename__ = u'something' + id = Column(Integer, primary_key=True) + + + class Abstract(AbstractConcreteBase, Base): + id = Column(Integer, primary_key=True) + + @declared_attr + def something_id(cls): + return Column(ForeignKey(Something.id)) + + @declared_attr + def something(cls): + return relationship(Something) + + + class Concrete(Abstract): + __tablename__ = u'cca' + __mapper_args__ = {'polymorphic_identity': 'cca', 'concrete': True} + + +The above mapping will set up a table ``cca`` with both an ``id`` and +a ``something_id`` column, and ``Concrete`` will also have a relationship +``something``. The new feature is that ``Abstract`` will also have an +independently configured relationship ``something`` that builds against +the polymorphic union of the base. + +:ticket:`3150` :ticket:`2670` :ticket:`3149` :ticket:`2952` :ticket:`3050` .. _bug_3188: @@ -319,67 +453,67 @@ as the "order by label" logic introduced in 0.9 (see :ref:`migration_1068`). 
Given a mapping like the following:: - class A(Base): - __tablename__ = 'a' + class A(Base): + __tablename__ = 'a' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True) - class B(Base): - __tablename__ = 'b' + class B(Base): + __tablename__ = 'b' - id = Column(Integer, primary_key=True) - a_id = Column(ForeignKey('a.id')) + id = Column(Integer, primary_key=True) + a_id = Column(ForeignKey('a.id')) - A.b = column_property( - select([func.max(B.id)]).where(B.a_id == A.id).correlate(A) - ) + A.b = column_property( + select([func.max(B.id)]).where(B.a_id == A.id).correlate(A) + ) A simple scenario that included "A.b" twice would fail to render correctly:: - print sess.query(A, a1).order_by(a1.b) + print sess.query(A, a1).order_by(a1.b) This would order by the wrong column:: - SELECT a.id AS a_id, (SELECT max(b.id) AS max_1 FROM b - WHERE b.a_id = a.id) AS anon_1, a_1.id AS a_1_id, - (SELECT max(b.id) AS max_2 - FROM b WHERE b.a_id = a_1.id) AS anon_2 - FROM a, a AS a_1 ORDER BY anon_1 + SELECT a.id AS a_id, (SELECT max(b.id) AS max_1 FROM b + WHERE b.a_id = a.id) AS anon_1, a_1.id AS a_1_id, + (SELECT max(b.id) AS max_2 + FROM b WHERE b.a_id = a_1.id) AS anon_2 + FROM a, a AS a_1 ORDER BY anon_1 New output:: - SELECT a.id AS a_id, (SELECT max(b.id) AS max_1 - FROM b WHERE b.a_id = a.id) AS anon_1, a_1.id AS a_1_id, - (SELECT max(b.id) AS max_2 - FROM b WHERE b.a_id = a_1.id) AS anon_2 - FROM a, a AS a_1 ORDER BY anon_2 + SELECT a.id AS a_id, (SELECT max(b.id) AS max_1 + FROM b WHERE b.a_id = a.id) AS anon_1, a_1.id AS a_1_id, + (SELECT max(b.id) AS max_2 + FROM b WHERE b.a_id = a_1.id) AS anon_2 + FROM a, a AS a_1 ORDER BY anon_2 There were also many scenarios where the "order by" logic would fail to order by label, for example if the mapping were "polymorphic":: - class A(Base): - __tablename__ = 'a' + class A(Base): + __tablename__ = 'a' - id = Column(Integer, primary_key=True) - type = Column(String) + id = Column(Integer, 
primary_key=True) + type = Column(String) - __mapper_args__ = {'polymorphic_on': type, 'with_polymorphic': '*'} + __mapper_args__ = {'polymorphic_on': type, 'with_polymorphic': '*'} The order_by would fail to use the label, as it would be anonymized due to the polymorphic loading:: - SELECT a.id AS a_id, a.type AS a_type, (SELECT max(b.id) AS max_1 - FROM b WHERE b.a_id = a.id) AS anon_1 - FROM a ORDER BY (SELECT max(b.id) AS max_2 - FROM b WHERE b.a_id = a.id) + SELECT a.id AS a_id, a.type AS a_type, (SELECT max(b.id) AS max_1 + FROM b WHERE b.a_id = a.id) AS anon_1 + FROM a ORDER BY (SELECT max(b.id) AS max_2 + FROM b WHERE b.a_id = a.id) Now that the order by label tracks the anonymized label, this now works:: - SELECT a.id AS a_id, a.type AS a_type, (SELECT max(b.id) AS max_1 - FROM b WHERE b.a_id = a.id) AS anon_1 - FROM a ORDER BY anon_1 + SELECT a.id AS a_id, a.type AS a_type, (SELECT max(b.id) AS max_1 + FROM b WHERE b.a_id = a.id) AS anon_1 + FROM a ORDER BY anon_1 Included in these fixes are a variety of heisenbugs that could corrupt the state of an ``aliased()`` construct such that the labeling logic @@ -406,13 +540,13 @@ for :func:`.attributes.get_history` and related functions. Given an object with no state:: - >>> obj = Foo() + >>> obj = Foo() It has always been SQLAlchemy's behavior such that if we access a scalar or many-to-one attribute that was never set, it is returned as ``None``:: - >>> obj.someattr - None + >>> obj.someattr + None This value of ``None`` is in fact now part of the state of ``obj``, and is not unlike as though we had set the attribute explicitly, e.g. @@ -420,31 +554,31 @@ not unlike as though we had set the attribute explicitly, e.g. differently as far as history and events. 
It would not emit any attribute event, and additionally if we view history, we see this:: - >>> inspect(obj).attrs.someattr.history - History(added=(), unchanged=[None], deleted=()) # 0.9 and below + >>> inspect(obj).attrs.someattr.history + History(added=(), unchanged=[None], deleted=()) # 0.9 and below That is, it's as though the attribute were always ``None`` and were never changed. This is explicitly different from if we had set the attribute first instead:: - >>> obj = Foo() - >>> obj.someattr = None - >>> inspect(obj).attrs.someattr.history - History(added=[None], unchanged=(), deleted=()) # all versions + >>> obj = Foo() + >>> obj.someattr = None + >>> inspect(obj).attrs.someattr.history + History(added=[None], unchanged=(), deleted=()) # all versions The above means that the behavior of our "set" operation can be corrupted by the fact that the value was accessed via "get" earlier. In 1.0, this inconsistency has been resolved, by no longer actually setting anything when the default "getter" is used. - >>> obj = Foo() - >>> obj.someattr - None - >>> inspect(obj).attrs.someattr.history - History(added=(), unchanged=(), deleted=()) # 1.0 - >>> obj.someattr = None - >>> inspect(obj).attrs.someattr.history - History(added=[None], unchanged=(), deleted=()) + >>> obj = Foo() + >>> obj.someattr + None + >>> inspect(obj).attrs.someattr.history + History(added=(), unchanged=(), deleted=()) # 1.0 + >>> obj.someattr = None + >>> inspect(obj).attrs.someattr.history + History(added=[None], unchanged=(), deleted=()) The reason the above behavior hasn't had much impact is because the INSERT statement in relational databases considers a missing value to be @@ -482,17 +616,17 @@ with yield-per (subquery loading could be in theory, however). 
When this error is raised, the :func:`.lazyload` option can be sent with an asterisk:: - q = sess.query(Object).options(lazyload('*')).yield_per(100) + q = sess.query(Object).options(lazyload('*')).yield_per(100) or use :meth:`.Query.enable_eagerloads`:: - q = sess.query(Object).enable_eagerloads(False).yield_per(100) + q = sess.query(Object).enable_eagerloads(False).yield_per(100) The :func:`.lazyload` option has the advantage that additional many-to-one joined loader options can still be used:: - q = sess.query(Object).options( - lazyload('*'), joinedload("some_manytoone")).yield_per(100) + q = sess.query(Object).options( + lazyload('*'), joinedload("some_manytoone")).yield_per(100) .. _migration_migration_deprecated_orm_events: @@ -546,7 +680,7 @@ The unused ``result`` member is now removed:: .. seealso:: - :ref:`bundles` + :ref:`bundles` .. _migration_3008: @@ -565,12 +699,12 @@ As introduced in :ref:`feature_2976` from version 0.9, the behavior of join eager load will use a right-nested join. ``"nested"`` is now implied when using ``innerjoin=True``:: - query(User).options( - joinedload("orders", innerjoin=False).joinedload("items", innerjoin=True)) + query(User).options( + joinedload("orders", innerjoin=False).joinedload("items", innerjoin=True)) With the new default, this will render the FROM clause in the form:: - FROM users LEFT OUTER JOIN (orders JOIN items ON <onclause>) ON <onclause> + FROM users LEFT OUTER JOIN (orders JOIN items ON <onclause>) ON <onclause> That is, using a right-nested join for the INNER join so that the full result of ``users`` can be returned. The use of an INNER join is more efficient @@ -579,13 +713,13 @@ optimization parameter to take effect in all cases. 
To get the older behavior, use ``innerjoin="unnested"``:: - query(User).options( - joinedload("orders", innerjoin=False).joinedload("items", innerjoin="unnested")) + query(User).options( + joinedload("orders", innerjoin=False).joinedload("items", innerjoin="unnested")) This will avoid right-nested joins and chain the joins together using all OUTER joins despite the innerjoin directive:: - FROM users LEFT OUTER JOIN orders ON <onclause> LEFT OUTER JOIN items ON <onclause> + FROM users LEFT OUTER JOIN orders ON <onclause> LEFT OUTER JOIN items ON <onclause> As noted in the 0.9 notes, the only database backend that has difficulty with right-nested joins is SQLite; SQLAlchemy as of 0.9 converts a right-nested @@ -593,7 +727,7 @@ join into a subquery as a join target on SQLite. .. seealso:: - :ref:`feature_2976` - description of the feature as introduced in 0.9.4. + :ref:`feature_2976` - description of the feature as introduced in 0.9.4. :ticket:`3008` @@ -638,15 +772,15 @@ with SQL expressions into many functions, such as :meth:`.Select.where`, Note that by "SQL expressions" we mean a **full fragment of a SQL string**, such as:: - # the argument sent to where() is a full SQL expression - stmt = select([sometable]).where("somecolumn = 'value'") + # the argument sent to where() is a full SQL expression + stmt = select([sometable]).where("somecolumn = 'value'") and we are **not talking about string arguments**, that is, the normal behavior of passing string values that become parameterized:: - # This is a normal Core expression with a string argument - - # we aren't talking about this!! - stmt = select([sometable]).where(sometable.c.somecolumn == 'value') + # This is a normal Core expression with a string argument - + # we aren't talking about this!! 
+ stmt = select([sometable]).where(sometable.c.somecolumn == 'value') The Core tutorial has long featured an example of the use of this technique, using a :func:`.select` construct where virtually all components of it @@ -660,25 +794,25 @@ So the change here is to encourage the user to qualify textual strings when composing SQL that is partially or fully composed from textual fragments. When composing a select as below:: - stmt = select(["a", "b"]).where("a = b").select_from("sometable") + stmt = select(["a", "b"]).where("a = b").select_from("sometable") The statement is built up normally, with all the same coercions as before. However, one will see the following warnings emitted:: - SAWarning: Textual column expression 'a' should be explicitly declared - with text('a'), or use column('a') for more specificity - (this warning may be suppressed after 10 occurrences) + SAWarning: Textual column expression 'a' should be explicitly declared + with text('a'), or use column('a') for more specificity + (this warning may be suppressed after 10 occurrences) - SAWarning: Textual column expression 'b' should be explicitly declared - with text('b'), or use column('b') for more specificity - (this warning may be suppressed after 10 occurrences) + SAWarning: Textual column expression 'b' should be explicitly declared + with text('b'), or use column('b') for more specificity + (this warning may be suppressed after 10 occurrences) - SAWarning: Textual SQL expression 'a = b' should be explicitly declared - as text('a = b') (this warning may be suppressed after 10 occurrences) + SAWarning: Textual SQL expression 'a = b' should be explicitly declared + as text('a = b') (this warning may be suppressed after 10 occurrences) - SAWarning: Textual SQL FROM expression 'sometable' should be explicitly - declared as text('sometable'), or use table('sometable') for more - specificity (this warning may be suppressed after 10 occurrences) + SAWarning: Textual SQL FROM expression 'sometable' 
should be explicitly + declared as text('sometable'), or use table('sometable') for more + specificity (this warning may be suppressed after 10 occurrences) These warnings attempt to show exactly where the issue is by displaying the parameters as well as where the string was received. @@ -688,14 +822,14 @@ one wishes the warnings to be exceptions, the `Python Warnings Filter <https://docs.python.org/2/library/warnings.html>`_ should be used:: - import warnings - warnings.simplefilter("error") # all warnings raise an exception + import warnings + warnings.simplefilter("error") # all warnings raise an exception Given the above warnings, our statement works just fine, but to get rid of the warnings we would rewrite our statement as follows:: - from sqlalchemy import select, text - stmt = select([ + from sqlalchemy import select, text + stmt = select([ text("a"), text("b") ]).where(text("a = b")).select_from(text("sometable")) @@ -703,10 +837,10 @@ to get rid of the warnings we would rewrite our statement as follows:: and as the warnings suggest, we can give our statement more specificity about the text if we use :func:`.column` and :func:`.table`:: - from sqlalchemy import select, text, column, table + from sqlalchemy import select, text, column, table - stmt = select([column("a"), column("b")]).\ - where(text("a = b")).select_from(table("sometable")) + stmt = select([column("a"), column("b")]).\ + where(text("a = b")).select_from(table("sometable")) Where note also that :func:`.table` and :func:`.column` can now be imported from "sqlalchemy" without the "sql" part. @@ -723,10 +857,10 @@ of this change we have enhanced its functionality. 
When we have a :func:`.select` or :class:`.Query` that refers to some column name or named label, we might want to GROUP BY and/or ORDER BY known columns or labels:: - stmt = select([ - user.c.name, - func.count(user.c.id).label("id_count") - ]).group_by("name").order_by("id_count") + stmt = select([ + user.c.name, + func.count(user.c.id).label("id_count") + ]).group_by("name").order_by("id_count") In the above statement we expect to see "ORDER BY id_count", as opposed to a re-statement of the function. The string argument given is actively @@ -734,24 +868,24 @@ matched to an entry in the columns clause during compilation, so the above statement would produce as we expect, without warnings (though note that the ``"name"`` expression has been resolved to ``users.name``!):: - SELECT users.name, count(users.id) AS id_count - FROM users GROUP BY users.name ORDER BY id_count + SELECT users.name, count(users.id) AS id_count + FROM users GROUP BY users.name ORDER BY id_count However, if we refer to a name that cannot be located, then we get the warning again, as below:: - stmt = select([ + stmt = select([ user.c.name, func.count(user.c.id).label("id_count") ]).order_by("some_label") The output does what we say, but again it warns us:: - SAWarning: Can't resolve label reference 'some_label'; converting to - text() (this warning may be suppressed after 10 occurrences) + SAWarning: Can't resolve label reference 'some_label'; converting to + text() (this warning may be suppressed after 10 occurrences) - SELECT users.name, count(users.id) AS id_count - FROM users ORDER BY some_label + SELECT users.name, count(users.id) AS id_count + FROM users ORDER BY some_label The above behavior applies to all those places where we might want to refer to a so-called "label reference"; ORDER BY and GROUP BY, but also within an @@ -761,7 +895,7 @@ Postgresql syntax). 
We can still specify any arbitrary expression for ORDER BY or others using :func:`.text`:: - stmt = select([users]).order_by(text("some special expression")) + stmt = select([users]).order_by(text("some special expression")) The upshot of the whole change is that SQLAlchemy now would like us to tell it when a string is sent that this string is explicitly @@ -822,7 +956,7 @@ data is needed. A :class:`.Table` can be set up for reflection by passing :paramref:`.Table.autoload_with` alone:: - my_table = Table('my_table', metadata, autoload_with=some_engine) + my_table = Table('my_table', metadata, autoload_with=some_engine) :ticket:`3027` @@ -855,15 +989,43 @@ The :func:`.inspect` method returns a :class:`.PGInspector` object in the case of Postgresql, which includes a new :meth:`.PGInspector.get_enums` method that returns information on all available ``ENUM`` types:: - from sqlalchemy import inspect, create_engine + from sqlalchemy import inspect, create_engine - engine = create_engine("postgresql+psycopg2://host/dbname") - insp = inspect(engine) - print(insp.get_enums()) + engine = create_engine("postgresql+psycopg2://host/dbname") + insp = inspect(engine) + print(insp.get_enums()) .. seealso:: - :meth:`.PGInspector.get_enums` + :meth:`.PGInspector.get_enums` + +.. _feature_2891: + +Postgresql Dialect reflects Materialized Views, Foreign Tables +-------------------------------------------------------------- + +Changes are as follows: + +* the :class:`Table` construct with ``autoload=True`` will now match a name + that exists in the database as a materialized view or foreign table. + +* :meth:`.Inspector.get_view_names` will return plain and materialized view + names. + +* :meth:`.Inspector.get_table_names` does **not** change for Postgresql, it + continues to return only the names of plain tables. 
+ +* A new method :meth:`.PGInspector.get_foreign_table_names` is added which + will return the names of tables that are specifically marked as "foreign" + in the Postgresql schema tables. + +The change to reflection involves adding ``'m'`` and ``'f'`` to the list +of qualifiers we use when querying ``pg_class.relkind``, but this change +is new in 1.0.0 to avoid any backwards-incompatible surprises for those +running 0.9 in production. + +:ticket:`2891` + MySQL internal "no such table" exceptions not passed to event handlers ---------------------------------------------------------------------- @@ -925,6 +1087,26 @@ when using ODBC to avoid this issue entirely. :ticket:`3182` +.. _change_3204: + +SQLite/Oracle have distinct methods for temporary table/view name reporting +--------------------------------------------------------------------------- + +The :meth:`.Inspector.get_table_names` and :meth:`.Inspector.get_view_names` +methods in the case of SQLite/Oracle would also return the names of temporary +tables and views, which is not provided by any other dialect (in the case +of MySQL at least it is not even possible). This logic has been moved +out to two new methods :meth:`.Inspector.get_temp_table_names` and +:meth:`.Inspector.get_temp_view_names`. + +Note that reflection of a specific named temporary table or temporary view, +either by ``Table('name', autoload=True)`` or via methods like +:meth:`.Inspector.get_columns` continues to function for most if not all +dialects. For SQLite specifically, there is a bug fix for UNIQUE constraint +reflection from temp tables as well, which is :ticket:`3203`. + +:ticket:`3204` + .. 
_change_2984: Drizzle Dialect is now an External Dialect diff --git a/doc/build/core/engines.rst b/doc/build/core/engines.rst index fb0320474..17ec9416c 100644 --- a/doc/build/core/engines.rst +++ b/doc/build/core/engines.rst @@ -151,9 +151,14 @@ For a relative file path, this requires three slashes:: # where <path> is relative: engine = create_engine('sqlite:///foo.db') -And for an absolute file path, *four* slashes are used:: +And for an absolute file path, the three slashes are followed by the absolute path:: + #Unix/Mac - 4 initial slashes in total engine = create_engine('sqlite:////absolute/path/to/foo.db') + #Windows + engine = create_engine('sqlite:///C:\\path\\to\\foo.db') + #Windows alternative using raw string + engine = create_engine(r'sqlite:///C:\path\to\foo.db') To use a SQLite ``:memory:`` database, specify an empty URL:: diff --git a/doc/build/dialects/index.rst b/doc/build/dialects/index.rst index da2699b2b..463f00612 100644 --- a/doc/build/dialects/index.rst +++ b/doc/build/dialects/index.rst @@ -50,6 +50,7 @@ Production Ready developed jointly by IBM and SQLAlchemy developers. * `redshift-sqlalchemy <https://pypi.python.org/pypi/redshift-sqlalchemy>`_ - driver for Amazon Redshift, adapts the existing Postgresql/psycopg2 driver. +* `sqlalchemy_exasol <https://github.com/blue-yonder/sqlalchemy_exasol>`_ - driver for EXASolution. * `sqlalchemy-sqlany <https://github.com/sqlanywhere/sqlalchemy-sqlany>`_ - driver for SAP Sybase SQL Anywhere, developed by SAP. * `sqlalchemy-monetdb <https://github.com/gijzelaerr/sqlalchemy-monetdb>`_ - driver for MonetDB. diff --git a/doc/build/orm/extensions/declarative.rst b/doc/build/orm/extensions/declarative.rst index 636bb451b..7d9e634b5 100644 --- a/doc/build/orm/extensions/declarative.rst +++ b/doc/build/orm/extensions/declarative.rst @@ -13,6 +13,7 @@ API Reference .. autofunction:: as_declarative .. autoclass:: declared_attr + :members: .. 
autofunction:: sqlalchemy.ext.declarative.api._declarative_constructor diff --git a/doc/build/orm/inheritance.rst b/doc/build/orm/inheritance.rst index 642f3420c..9f01a3e24 100644 --- a/doc/build/orm/inheritance.rst +++ b/doc/build/orm/inheritance.rst @@ -45,6 +45,12 @@ this column is to act as the **discriminator**, and stores a value which indicates the type of object represented within the row. The column may be of any datatype, though string and integer are the most common. +.. warning:: + + Currently, **only one discriminator column may be set**, typically + on the base-most class in the hierarchy. "Cascading" polymorphic columns + are not yet supported. + The discriminator column is only needed if polymorphic loading is desired, as is usually the case. It is not strictly necessary that it be present directly on the base mapped table, and can instead be defined on a diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py index 8f76336ae..b5a1bc566 100644 --- a/lib/sqlalchemy/dialects/mssql/pymssql.py +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -63,7 +63,7 @@ class MSDialect_pymssql(MSDialect): def _get_server_version_info(self, connection): vers = connection.scalar("select @@version") m = re.match( - r"Microsoft SQL Server.*? - (\d+).(\d+).(\d+).(\d+)", vers) + r"Microsoft .*? - (\d+).(\d+).(\d+).(\d+)", vers) if m: return tuple(int(x) for x in m.group(1, 2, 3, 4)) else: diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 81a9f1a95..837a498fb 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -213,6 +213,21 @@ is reflected and the type is reported as ``DATE``, the time-supporting examining the type of column for use in special Python translations or for migrating schemas to other database backends. 
+Oracle Table Options +------------------------- + +The CREATE TABLE phrase supports the following options with Oracle +in conjunction with the :class:`.Table` construct: + + +* ``ON COMMIT``:: + + Table( + "some_table", metadata, ..., + prefixes=['GLOBAL TEMPORARY'], oracle_on_commit='PRESERVE ROWS') + +.. versionadded:: 1.0.0 + """ import re @@ -784,6 +799,16 @@ class OracleDDLCompiler(compiler.DDLCompiler): return super(OracleDDLCompiler, self).\ visit_create_index(create, include_schema=True) + def post_create_table(self, table): + table_opts = [] + opts = table.dialect_options['oracle'] + + if opts['on_commit']: + on_commit_options = opts['on_commit'].replace("_", " ").upper() + table_opts.append('\n ON COMMIT %s' % on_commit_options) + + return ''.join(table_opts) + class OracleIdentifierPreparer(compiler.IdentifierPreparer): @@ -842,7 +867,10 @@ class OracleDialect(default.DefaultDialect): reflection_options = ('oracle_resolve_synonyms', ) construct_arguments = [ - (sa_schema.Table, {"resolve_synonyms": False}) + (sa_schema.Table, { + "resolve_synonyms": False, + "on_commit": None + }) ] def __init__(self, @@ -1029,7 +1057,21 @@ class OracleDialect(default.DefaultDialect): "WHERE nvl(tablespace_name, 'no tablespace') NOT IN " "('SYSTEM', 'SYSAUX') " "AND OWNER = :owner " - "AND IOT_NAME IS NULL") + "AND IOT_NAME IS NULL " + "AND DURATION IS NULL") + cursor = connection.execute(s, owner=schema) + return [self.normalize_name(row[0]) for row in cursor] + + @reflection.cache + def get_temp_table_names(self, connection, **kw): + schema = self.denormalize_name(self.default_schema_name) + s = sql.text( + "SELECT table_name FROM all_tables " + "WHERE nvl(tablespace_name, 'no tablespace') NOT IN " + "('SYSTEM', 'SYSAUX') " + "AND OWNER = :owner " + "AND IOT_NAME IS NULL " + "AND DURATION IS NOT NULL") cursor = connection.execute(s, owner=schema) return [self.normalize_name(row[0]) for row in cursor] diff --git a/lib/sqlalchemy/dialects/postgresql/base.py 
b/lib/sqlalchemy/dialects/postgresql/base.py index 575d2a6dd..b9a0d461b 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -401,6 +401,7 @@ The value passed to the keyword argument will be simply passed through to the underlying CREATE INDEX command, so it *must* be a valid index type for your version of PostgreSQL. + Special Reflection Options -------------------------- @@ -1679,6 +1680,19 @@ class PGInspector(reflection.Inspector): schema = schema or self.default_schema_name return self.dialect._load_enums(self.bind, schema) + def get_foreign_table_names(self, schema=None): + """Return a list of FOREIGN TABLE names. + + Behavior is similar to that of :meth:`.Inspector.get_table_names`, + except that the list is limited to those tables that report a + ``relkind`` value of ``f``. + + .. versionadded:: 1.0.0 + + """ + schema = schema or self.default_schema_name + return self.dialect._get_foreign_table_names(self.bind, schema) + class CreateEnumType(schema._CreateDropBase): __visit_name__ = "create_enum_type" @@ -2024,7 +2038,7 @@ class PGDialect(default.DefaultDialect): FROM pg_catalog.pg_class c LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace WHERE (%s) - AND c.relname = :table_name AND c.relkind in ('r','v') + AND c.relname = :table_name AND c.relkind in ('r', 'v', 'm', 'f') """ % schema_where_clause # Since we're binding to unicode, table_name and schema_name must be # unicode. 
@@ -2078,6 +2092,24 @@ class PGDialect(default.DefaultDialect): return [row[0] for row in result] @reflection.cache + def _get_foreign_table_names(self, connection, schema=None, **kw): + if schema is not None: + current_schema = schema + else: + current_schema = self.default_schema_name + + result = connection.execute( + sql.text("SELECT relname FROM pg_class c " + "WHERE relkind = 'f' " + "AND '%s' = (select nspname from pg_namespace n " + "where n.oid = c.relnamespace) " % + current_schema, + typemap={'relname': sqltypes.Unicode} + ) + ) + return [row[0] for row in result] + + @reflection.cache def get_view_names(self, connection, schema=None, **kw): if schema is not None: current_schema = schema @@ -2086,7 +2118,7 @@ class PGDialect(default.DefaultDialect): s = """ SELECT relname FROM pg_class c - WHERE relkind = 'v' + WHERE relkind IN ('m', 'v') AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) """ % dict(schema=current_schema) @@ -2448,7 +2480,7 @@ class PGDialect(default.DefaultDialect): pg_attribute a on t.oid=a.attrelid and %s WHERE - t.relkind = 'r' + t.relkind IN ('r', 'v', 'f', 'm') and t.oid = :table_oid and ix.indisprimary = 'f' ORDER BY diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index e6450c97f..9dfd53e22 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -512,12 +512,14 @@ class PGDialect_psycopg2(PGDialect): def is_disconnect(self, e, connection, cursor): if isinstance(e, self.dbapi.Error): # check the "closed" flag. this might not be - # present on old psycopg2 versions + # present on old psycopg2 versions. Also, + # this flag doesn't actually help in a lot of disconnect + # situations, so don't rely on it. if getattr(connection, 'closed', False): return True - # legacy checks based on strings. the "closed" check - # above most likely obviates the need for any of these. 
+ # checks based on strings. in the case that .closed + # didn't cut it, fall back onto these. str_e = str(e).partition("\n")[0] for msg in [ # these error messages from libpq: interfaces/libpq/fe-misc.c @@ -534,8 +536,10 @@ class PGDialect_psycopg2(PGDialect): # not sure where this path is originally from, it may # be obsolete. It really says "losed", not "closed". 'losed the connection unexpectedly', - # this can occur in newer SSL - 'connection has been closed unexpectedly' + # these can occur in newer SSL + 'connection has been closed unexpectedly', + 'SSL SYSCALL error: Bad file descriptor', + 'SSL SYSCALL error: EOF detected', ]: idx = str_e.find(msg) if idx >= 0 and '"' not in str_e[:idx]: diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index af793d275..335b35c94 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -713,10 +713,12 @@ class SQLiteExecutionContext(default.DefaultExecutionContext): return self.execution_options.get("sqlite_raw_colnames", False) def _translate_colname(self, colname): - # adjust for dotted column names. SQLite in the case of UNION may - # store col names as "tablename.colname" in cursor.description + # adjust for dotted column names. SQLite + # in the case of UNION may store col names as + # "tablename.colname", or if using an attached database, + # "database.tablename.colname", in cursor.description if not self._preserve_raw_colnames and "." 
in colname: - return colname.split(".")[1], colname + return colname.split(".")[-1], colname else: return colname, None @@ -829,20 +831,26 @@ class SQLiteDialect(default.DefaultDialect): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) master = '%s.sqlite_master' % qschema - s = ("SELECT name FROM %s " - "WHERE type='table' ORDER BY name") % (master,) - rs = connection.execute(s) else: - try: - s = ("SELECT name FROM " - " (SELECT * FROM sqlite_master UNION ALL " - " SELECT * FROM sqlite_temp_master) " - "WHERE type='table' ORDER BY name") - rs = connection.execute(s) - except exc.DBAPIError: - s = ("SELECT name FROM sqlite_master " - "WHERE type='table' ORDER BY name") - rs = connection.execute(s) + master = "sqlite_master" + s = ("SELECT name FROM %s " + "WHERE type='table' ORDER BY name") % (master,) + rs = connection.execute(s) + return [row[0] for row in rs] + + @reflection.cache + def get_temp_table_names(self, connection, **kw): + s = "SELECT name FROM sqlite_temp_master "\ + "WHERE type='table' ORDER BY name " + rs = connection.execute(s) + + return [row[0] for row in rs] + + @reflection.cache + def get_temp_view_names(self, connection, **kw): + s = "SELECT name FROM sqlite_temp_master "\ + "WHERE type='view' ORDER BY name " + rs = connection.execute(s) return [row[0] for row in rs] @@ -869,20 +877,11 @@ class SQLiteDialect(default.DefaultDialect): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) master = '%s.sqlite_master' % qschema - s = ("SELECT name FROM %s " - "WHERE type='view' ORDER BY name") % (master,) - rs = connection.execute(s) else: - try: - s = ("SELECT name FROM " - " (SELECT * FROM sqlite_master UNION ALL " - " SELECT * FROM sqlite_temp_master) " - "WHERE type='view' ORDER BY name") - rs = connection.execute(s) - except exc.DBAPIError: - s = ("SELECT name FROM sqlite_master " - "WHERE type='view' ORDER BY name") - rs = connection.execute(s) + master = "sqlite_master" + s = 
("SELECT name FROM %s " + "WHERE type='view' ORDER BY name") % (master,) + rs = connection.execute(s) return [row[0] for row in rs] @@ -1097,16 +1096,24 @@ class SQLiteDialect(default.DefaultDialect): @reflection.cache def get_unique_constraints(self, connection, table_name, schema=None, **kw): - UNIQUE_SQL = """ - SELECT sql - FROM - sqlite_master - WHERE - type='table' AND - name=:table_name - """ - c = connection.execute(UNIQUE_SQL, table_name=table_name) - table_data = c.fetchone()[0] + try: + s = ("SELECT sql FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE name = '%s' " + "AND type = 'table'") % table_name + rs = connection.execute(s) + except exc.DBAPIError: + s = ("SELECT sql FROM sqlite_master WHERE name = '%s' " + "AND type = 'table'") % table_name + rs = connection.execute(s) + row = rs.fetchone() + if row is None: + # sqlite won't return the schema for the sqlite_master or + # sqlite_temp_master tables from this query. These tables + # don't have any unique constraints anyway. + return [] + table_data = row[0] UNIQUE_PATTERN = 'CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)' return [ diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index d2cc8890f..e5feda138 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -45,7 +45,7 @@ class Connection(Connectable): """ def __init__(self, engine, connection=None, close_with_result=False, - _branch=False, _execution_options=None, + _branch_from=None, _execution_options=None, _dispatch=None, _has_events=None): """Construct a new Connection. 
@@ -57,48 +57,80 @@ class Connection(Connectable): """ self.engine = engine self.dialect = engine.dialect - self.__connection = connection or engine.raw_connection() - self.__transaction = None - self.should_close_with_result = close_with_result - self.__savepoint_seq = 0 - self.__branch = _branch - self.__invalid = False - self.__can_reconnect = True - if _dispatch: + self.__branch_from = _branch_from + self.__branch = _branch_from is not None + + if _branch_from: + self.__connection = connection + self._execution_options = _execution_options + self._echo = _branch_from._echo + self.should_close_with_result = False self.dispatch = _dispatch - elif _has_events is None: - # if _has_events is sent explicitly as False, - # then don't join the dispatch of the engine; we don't - # want to handle any of the engine's events in that case. - self.dispatch = self.dispatch._join(engine.dispatch) - self._has_events = _has_events or ( - _has_events is None and engine._has_events) - - self._echo = self.engine._should_log_info() - if _execution_options: - self._execution_options =\ - engine._execution_options.union(_execution_options) + self._has_events = _branch_from._has_events else: + self.__connection = connection \ + if connection is not None else engine.raw_connection() + self.__transaction = None + self.__savepoint_seq = 0 + self.should_close_with_result = close_with_result + self.__invalid = False + self.__can_reconnect = True + self._echo = self.engine._should_log_info() + + if _has_events is None: + # if _has_events is sent explicitly as False, + # then don't join the dispatch of the engine; we don't + # want to handle any of the engine's events in that case. 
+ self.dispatch = self.dispatch._join(engine.dispatch) + self._has_events = _has_events or ( + _has_events is None and engine._has_events) + + assert not _execution_options self._execution_options = engine._execution_options if self._has_events or self.engine._has_events: - self.dispatch.engine_connect(self, _branch) + self.dispatch.engine_connect(self, self.__branch) def _branch(self): """Return a new Connection which references this Connection's engine and connection; but does not have close_with_result enabled, and also whose close() method does nothing. - This is used to execute "sub" statements within a single execution, - usually an INSERT statement. + The Core uses this very sparingly, only in the case of + custom SQL default functions that are to be INSERTed as the + primary key of a row where we need to get the value back, so we have + to invoke it distinctly - this is a very uncommon case. + + Userland code accesses _branch() when the connect() or + contextual_connect() methods are called. The branched connection + acts as much as possible like the parent, except that it stays + connected when a close() event occurs. + """ + if self.__branch_from: + return self.__branch_from._branch() + else: + return self.engine._connection_cls( + self.engine, + self.__connection, + _branch_from=self, + _execution_options=self._execution_options, + _has_events=self._has_events, + _dispatch=self.dispatch) + + @property + def _root(self): + """return the 'root' connection. - return self.engine._connection_cls( - self.engine, - self.__connection, - _branch=True, - _has_events=self._has_events, - _dispatch=self.dispatch) + Returns 'self' if this connection is not a branch, else + returns the root connection from which we ultimately branched. + + """ + + if self.__branch_from: + return self.__branch_from + else: + return self def _clone(self): """Create a shallow copy of this Connection. 
@@ -224,7 +256,7 @@ class Connection(Connectable): def invalidated(self): """Return True if this connection was invalidated.""" - return self.__invalid + return self._root.__invalid @property def connection(self): @@ -236,6 +268,9 @@ class Connection(Connectable): return self._revalidate_connection() def _revalidate_connection(self): + if self.__branch_from: + return self.__branch_from._revalidate_connection() + if self.__can_reconnect and self.__invalid: if self.__transaction is not None: raise exc.InvalidRequestError( @@ -343,16 +378,17 @@ class Connection(Connectable): :ref:`pool_connection_invalidation` """ + if self.invalidated: return if self.closed: raise exc.ResourceClosedError("This Connection is closed") - if self._connection_is_valid: - self.__connection.invalidate(exception) - del self.__connection - self.__invalid = True + if self._root._connection_is_valid: + self._root.__connection.invalidate(exception) + del self._root.__connection + self._root.__invalid = True def detach(self): """Detach the underlying DB-API connection from its connection pool. @@ -415,6 +451,8 @@ class Connection(Connectable): :class:`.Engine`. """ + if self.__branch_from: + return self.__branch_from.begin() if self.__transaction is None: self.__transaction = RootTransaction(self) @@ -436,6 +474,9 @@ class Connection(Connectable): See also :meth:`.Connection.begin`, :meth:`.Connection.begin_twophase`. 
""" + if self.__branch_from: + return self.__branch_from.begin_nested() + if self.__transaction is None: self.__transaction = RootTransaction(self) else: @@ -459,6 +500,9 @@ class Connection(Connectable): """ + if self.__branch_from: + return self.__branch_from.begin_twophase(xid=xid) + if self.__transaction is not None: raise exc.InvalidRequestError( "Cannot start a two phase transaction when a transaction " @@ -479,10 +523,11 @@ class Connection(Connectable): def in_transaction(self): """Return True if a transaction is in progress.""" - - return self.__transaction is not None + return self._root.__transaction is not None def _begin_impl(self, transaction): + assert not self.__branch_from + if self._echo: self.engine.logger.info("BEGIN (implicit)") @@ -497,6 +542,8 @@ class Connection(Connectable): self._handle_dbapi_exception(e, None, None, None, None) def _rollback_impl(self): + assert not self.__branch_from + if self._has_events or self.engine._has_events: self.dispatch.rollback(self) @@ -516,6 +563,8 @@ class Connection(Connectable): self.__transaction = None def _commit_impl(self, autocommit=False): + assert not self.__branch_from + if self._has_events or self.engine._has_events: self.dispatch.commit(self) @@ -532,6 +581,8 @@ class Connection(Connectable): self.__transaction = None def _savepoint_impl(self, name=None): + assert not self.__branch_from + if self._has_events or self.engine._has_events: self.dispatch.savepoint(self, name) @@ -543,6 +594,8 @@ class Connection(Connectable): return name def _rollback_to_savepoint_impl(self, name, context): + assert not self.__branch_from + if self._has_events or self.engine._has_events: self.dispatch.rollback_savepoint(self, name, context) @@ -551,6 +604,8 @@ class Connection(Connectable): self.__transaction = context def _release_savepoint_impl(self, name, context): + assert not self.__branch_from + if self._has_events or self.engine._has_events: self.dispatch.release_savepoint(self, name, context) @@ -559,6 +614,8 
@@ class Connection(Connectable): self.__transaction = context def _begin_twophase_impl(self, transaction): + assert not self.__branch_from + if self._echo: self.engine.logger.info("BEGIN TWOPHASE (implicit)") if self._has_events or self.engine._has_events: @@ -571,6 +628,8 @@ class Connection(Connectable): self.connection._reset_agent = transaction def _prepare_twophase_impl(self, xid): + assert not self.__branch_from + if self._has_events or self.engine._has_events: self.dispatch.prepare_twophase(self, xid) @@ -579,6 +638,8 @@ class Connection(Connectable): self.engine.dialect.do_prepare_twophase(self, xid) def _rollback_twophase_impl(self, xid, is_prepared): + assert not self.__branch_from + if self._has_events or self.engine._has_events: self.dispatch.rollback_twophase(self, xid, is_prepared) @@ -595,6 +656,8 @@ class Connection(Connectable): self.__transaction = None def _commit_twophase_impl(self, xid, is_prepared): + assert not self.__branch_from + if self._has_events or self.engine._has_events: self.dispatch.commit_twophase(self, xid, is_prepared) @@ -610,8 +673,8 @@ class Connection(Connectable): self.__transaction = None def _autorollback(self): - if not self.in_transaction(): - self._rollback_impl() + if not self._root.in_transaction(): + self._root._rollback_impl() def close(self): """Close this :class:`.Connection`. @@ -632,13 +695,21 @@ class Connection(Connectable): and will allow no further operations. 
""" + if self.__branch_from: + try: + del self.__connection + except AttributeError: + pass + finally: + self.__can_reconnect = False + return try: conn = self.__connection except AttributeError: pass else: - if not self.__branch: - conn.close() + + conn.close() if conn._reset_agent is self.__transaction: conn._reset_agent = None @@ -993,8 +1064,8 @@ class Connection(Connectable): result.rowcount result.close(_autoclose_connection=False) - if self.__transaction is None and context.should_autocommit: - self._commit_impl(autocommit=True) + if context.should_autocommit and self._root.__transaction is None: + self._root._commit_impl(autocommit=True) if result.closed and self.should_close_with_result: self.close() diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 71df29cac..0ad2efae0 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -308,7 +308,15 @@ class Dialect(object): def get_table_names(self, connection, schema=None, **kw): """Return a list of table names for `schema`.""" - raise NotImplementedError + raise NotImplementedError() + + def get_temp_table_names(self, connection, schema=None, **kw): + """Return a list of temporary table names on the given connection, + if supported by the underlying backend. + + """ + + raise NotImplementedError() def get_view_names(self, connection, schema=None, **kw): """Return a list of all view names available in the database. @@ -319,6 +327,14 @@ class Dialect(object): raise NotImplementedError() + def get_temp_view_names(self, connection, schema=None, **kw): + """Return a list of temporary view names on the given connection, + if supported by the underlying backend. + + """ + + raise NotImplementedError() + def get_view_definition(self, connection, view_name, schema=None, **kw): """Return view definition. 
diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index cf1f2d3dd..838a5bdd2 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -201,6 +201,30 @@ class Inspector(object): tnames = list(topological.sort(tuples, tnames)) return tnames + def get_temp_table_names(self): + """return a list of temporary table names for the current bind. + + This method is unsupported by most dialects; currently + only SQLite implements it. + + .. versionadded:: 1.0.0 + + """ + return self.dialect.get_temp_table_names( + self.bind, info_cache=self.info_cache) + + def get_temp_view_names(self): + """return a list of temporary view names for the current bind. + + This method is unsupported by most dialects; currently + only SQLite implements it. + + .. versionadded:: 1.0.0 + + """ + return self.dialect.get_temp_view_names( + self.bind, info_cache=self.info_cache) + def get_table_options(self, table_name, schema=None, **kw): """Return a dictionary of options specified when the table of the given name was created. 
@@ -465,55 +489,83 @@ class Inspector(object): for col_d in self.get_columns( table_name, schema, **table.dialect_kwargs): found_table = True - orig_name = col_d['name'] - table.dispatch.column_reflect(self, table, col_d) + self._reflect_column( + table, col_d, include_columns, + exclude_columns, cols_by_orig_name) - name = col_d['name'] - if include_columns and name not in include_columns: - continue - if exclude_columns and name in exclude_columns: - continue + if not found_table: + raise exc.NoSuchTableError(table.name) - coltype = col_d['type'] + self._reflect_pk( + table_name, schema, table, cols_by_orig_name, exclude_columns) - col_kw = dict( - (k, col_d[k]) - for k in ['nullable', 'autoincrement', 'quote', 'info', 'key'] - if k in col_d - ) + self._reflect_fk( + table_name, schema, table, cols_by_orig_name, + exclude_columns, reflection_options) - colargs = [] - if col_d.get('default') is not None: - # the "default" value is assumed to be a literal SQL - # expression, so is wrapped in text() so that no quoting - # occurs on re-issuance. - colargs.append( - sa_schema.DefaultClause( - sql.text(col_d['default']), _reflected=True - ) - ) + self._reflect_indexes( + table_name, schema, table, cols_by_orig_name, + include_columns, exclude_columns, reflection_options) - if 'sequence' in col_d: - # TODO: mssql and sybase are using this. 
- seq = col_d['sequence'] - sequence = sa_schema.Sequence(seq['name'], 1, 1) - if 'start' in seq: - sequence.start = seq['start'] - if 'increment' in seq: - sequence.increment = seq['increment'] - colargs.append(sequence) + def _reflect_column( + self, table, col_d, include_columns, + exclude_columns, cols_by_orig_name): - cols_by_orig_name[orig_name] = col = \ - sa_schema.Column(name, coltype, *colargs, **col_kw) + orig_name = col_d['name'] - if col.key in table.primary_key: - col.primary_key = True - table.append_column(col) + table.dispatch.column_reflect(self, table, col_d) - if not found_table: - raise exc.NoSuchTableError(table.name) + # fetch name again as column_reflect is allowed to + # change it + name = col_d['name'] + if (include_columns and name not in include_columns) \ + or (exclude_columns and name in exclude_columns): + return + + coltype = col_d['type'] + col_kw = dict( + (k, col_d[k]) + for k in ['nullable', 'autoincrement', 'quote', 'info', 'key'] + if k in col_d + ) + + colargs = [] + if col_d.get('default') is not None: + # the "default" value is assumed to be a literal SQL + # expression, so is wrapped in text() so that no quoting + # occurs on re-issuance. + colargs.append( + sa_schema.DefaultClause( + sql.text(col_d['default']), _reflected=True + ) + ) + + if 'sequence' in col_d: + self._reflect_col_sequence(col_d, colargs) + + cols_by_orig_name[orig_name] = col = \ + sa_schema.Column(name, coltype, *colargs, **col_kw) + + if col.key in table.primary_key: + col.primary_key = True + table.append_column(col) + + def _reflect_col_sequence(self, col_d, colargs): + if 'sequence' in col_d: + # TODO: mssql and sybase are using this. 
+ seq = col_d['sequence'] + sequence = sa_schema.Sequence(seq['name'], 1, 1) + if 'start' in seq: + sequence.start = seq['start'] + if 'increment' in seq: + sequence.increment = seq['increment'] + colargs.append(sequence) + + def _reflect_pk( + self, table_name, schema, table, + cols_by_orig_name, exclude_columns): pk_cons = self.get_pk_constraint( table_name, schema, **table.dialect_kwargs) if pk_cons: @@ -530,6 +582,9 @@ class Inspector(object): # its column collection table.primary_key._reload(pk_cols) + def _reflect_fk( + self, table_name, schema, table, cols_by_orig_name, + exclude_columns, reflection_options): fkeys = self.get_foreign_keys( table_name, schema, **table.dialect_kwargs) for fkey_d in fkeys: @@ -572,6 +627,10 @@ class Inspector(object): sa_schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True, **options)) + + def _reflect_indexes( + self, table_name, schema, table, cols_by_orig_name, + include_columns, exclude_columns, reflection_options): # Indexes indexes = self.get_indexes(table_name, schema) for index_d in indexes: diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index dba1063cf..be2a82208 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -319,14 +319,12 @@ class _ListenerCollection(RefCollection, _CompoundListener): registry._stored_in_collection_multi(self, other, to_associate) def insert(self, event_key, propagate): - if event_key._listen_fn not in self.listeners: - event_key.prepend_to_list(self, self.listeners) + if event_key.prepend_to_list(self, self.listeners): if propagate: self.propagate.add(event_key._listen_fn) def append(self, event_key, propagate): - if event_key._listen_fn not in self.listeners: - event_key.append_to_list(self, self.listeners) + if event_key.append_to_list(self, self.listeners): if propagate: self.propagate.add(event_key._listen_fn) diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py index 
ba2f671a3..5b422c401 100644 --- a/lib/sqlalchemy/event/registry.py +++ b/lib/sqlalchemy/event/registry.py @@ -71,13 +71,15 @@ def _stored_in_collection(event_key, owner): listen_ref = weakref.ref(event_key._listen_fn) if owner_ref in dispatch_reg: - assert dispatch_reg[owner_ref] == listen_ref - else: - dispatch_reg[owner_ref] = listen_ref + return False + + dispatch_reg[owner_ref] = listen_ref listener_to_key = _collection_to_key[owner_ref] listener_to_key[listen_ref] = key + return True + def _removed_from_collection(event_key, owner): key = event_key._key @@ -180,6 +182,17 @@ class _EventKey(object): def listen(self, *args, **kw): once = kw.pop("once", False) + named = kw.pop("named", False) + + target, identifier, fn = \ + self.dispatch_target, self.identifier, self._listen_fn + + dispatch_descriptor = getattr(target.dispatch, identifier) + + adjusted_fn = dispatch_descriptor._adjust_fn_spec(fn, named) + + self = self.with_wrapper(adjusted_fn) + if once: self.with_wrapper( util.only_once(self._listen_fn)).listen(*args, **kw) @@ -215,9 +228,6 @@ class _EventKey(object): dispatch_descriptor = getattr(target.dispatch, identifier) - fn = dispatch_descriptor._adjust_fn_spec(fn, named) - self = self.with_wrapper(fn) - if insert: dispatch_descriptor.\ for_modify(target.dispatch).insert(self, propagate) @@ -229,18 +239,20 @@ class _EventKey(object): def _listen_fn(self): return self.fn_wrap or self.fn - def append_value_to_list(self, owner, list_, value): - _stored_in_collection(self, owner) - list_.append(value) - def append_to_list(self, owner, list_): - _stored_in_collection(self, owner) - list_.append(self._listen_fn) + if _stored_in_collection(self, owner): + list_.append(self._listen_fn) + return True + else: + return False def remove_from_list(self, owner, list_): _removed_from_collection(self, owner) list_.remove(self._listen_fn) def prepend_to_list(self, owner, list_): - _stored_in_collection(self, owner) - list_.appendleft(self._listen_fn) + if 
_stored_in_collection(self, owner): + list_.appendleft(self._listen_fn) + return True + else: + return False diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py index 1ecec51b6..1ff35b8b0 100644 --- a/lib/sqlalchemy/events.py +++ b/lib/sqlalchemy/events.py @@ -470,7 +470,8 @@ class ConnectionEvents(event.Events): @classmethod def _listen(cls, event_key, retval=False): target, identifier, fn = \ - event_key.dispatch_target, event_key.identifier, event_key.fn + event_key.dispatch_target, event_key.identifier, \ + event_key._listen_fn target._has_events = True diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py index 121285ab3..c11795d37 100644 --- a/lib/sqlalchemy/ext/automap.py +++ b/lib/sqlalchemy/ext/automap.py @@ -243,7 +243,26 @@ follows: one-to-many backref will be created on the referred class referring to this class. -4. The names of the relationships are determined using the +4. If any of the columns that are part of the :class:`.ForeignKeyConstraint` + are not nullable (e.g. ``nullable=False``), a + :paramref:`~.relationship.cascade` keyword argument + of ``all, delete-orphan`` will be added to the keyword arguments to + be passed to the relationship or backref. If the + :class:`.ForeignKeyConstraint` reports that + :paramref:`.ForeignKeyConstraint.ondelete` + is set to ``CASCADE`` for a not null or ``SET NULL`` for a nullable + set of columns, the option :paramref:`~.relationship.passive_deletes` + flag is set to ``True`` in the set of relationship keyword arguments. + Note that not all backends support reflection of ON DELETE. + + .. 
versionadded:: 1.0.0 - automap will detect non-nullable foreign key + constraints when producing a one-to-many relationship and establish + a default cascade of ``all, delete-orphan`` if so; additionally, + if the constraint specifies :paramref:`.ForeignKeyConstraint.ondelete` + of ``CASCADE`` for non-nullable or ``SET NULL`` for nullable columns, + the ``passive_deletes=True`` option is also added. + +5. The names of the relationships are determined using the :paramref:`.AutomapBase.prepare.name_for_scalar_relationship` and :paramref:`.AutomapBase.prepare.name_for_collection_relationship` callable functions. It is important to note that the default relationship @@ -252,18 +271,18 @@ follows: alternate class naming scheme, that's the name from which the relationship name will be derived. -5. The classes are inspected for an existing mapped property matching these +6. The classes are inspected for an existing mapped property matching these names. If one is detected on one side, but none on the other side, :class:`.AutomapBase` attempts to create a relationship on the missing side, then uses the :paramref:`.relationship.back_populates` parameter in order to point the new relationship to the other side. -6. In the usual case where no relationship is on either side, +7. In the usual case where no relationship is on either side, :meth:`.AutomapBase.prepare` produces a :func:`.relationship` on the "many-to-one" side and matches it to the other using the :paramref:`.relationship.backref` parameter. -7. Production of the :func:`.relationship` and optionally the :func:`.backref` +8. 
Production of the :func:`.relationship` and optionally the :func:`.backref` is handed off to the :paramref:`.AutomapBase.prepare.generate_relationship` function, which can be supplied by the end-user in order to augment the arguments passed to :func:`.relationship` or :func:`.backref` or to @@ -877,6 +896,19 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config, constraint ) + o2m_kws = {} + nullable = False not in set([fk.parent.nullable for fk in fks]) + if not nullable: + o2m_kws['cascade'] = "all, delete-orphan" + + if constraint.ondelete and \ + constraint.ondelete.lower() == "cascade": + o2m_kws['passive_deletes'] = True + else: + if constraint.ondelete and \ + constraint.ondelete.lower() == "set null": + o2m_kws['passive_deletes'] = True + create_backref = backref_name not in referred_cfg.properties if relationship_name not in map_config.properties: @@ -885,7 +917,8 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config, automap_base, interfaces.ONETOMANY, backref, backref_name, referred_cls, local_cls, - collection_class=collection_class) + collection_class=collection_class, + **o2m_kws) else: backref_obj = None rel = generate_relationship(automap_base, @@ -916,7 +949,8 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config, fk.parent for fk in constraint.elements], back_populates=relationship_name, - collection_class=collection_class) + collection_class=collection_class, + **o2m_kws) if rel is not None: referred_cfg.properties[backref_name] = rel map_config.properties[ diff --git a/lib/sqlalchemy/ext/declarative/__init__.py b/lib/sqlalchemy/ext/declarative/__init__.py index 3cbc85c0c..2b611252a 100644 --- a/lib/sqlalchemy/ext/declarative/__init__.py +++ b/lib/sqlalchemy/ext/declarative/__init__.py @@ -873,8 +873,7 @@ the method without the need to copy it. 
Columns generated by :class:`~.declared_attr` can also be referenced by ``__mapper_args__`` to a limited degree, currently -by ``polymorphic_on`` and ``version_id_col``, by specifying the -classdecorator itself into the dictionary - the declarative extension +by ``polymorphic_on`` and ``version_id_col``; the declarative extension will resolve them at class construction time:: class MyMixin: @@ -889,7 +888,6 @@ will resolve them at class construction time:: id = Column(Integer, primary_key=True) - Mixing in Relationships ~~~~~~~~~~~~~~~~~~~~~~~ @@ -922,6 +920,7 @@ reference a common target class via many-to-one:: __tablename__ = 'target' id = Column(Integer, primary_key=True) + Using Advanced Relationship Arguments (e.g. ``primaryjoin``, etc.) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -1004,6 +1003,24 @@ requirement so that no reliance on copying is needed:: class Something(SomethingMixin, Base): __tablename__ = "something" +The :func:`.column_property` or other construct may refer +to other columns from the mixin. These are copied ahead of time before +the :class:`.declared_attr` is invoked:: + + class SomethingMixin(object): + x = Column(Integer) + + y = Column(Integer) + + @declared_attr + def x_plus_y(cls): + return column_property(cls.x + cls.y) + + +.. versionchanged:: 1.0.0 mixin columns are copied to the final mapped class + so that :class:`.declared_attr` methods can access the actual column + that will be mapped. + Mixing in Association Proxy and Other Attributes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1087,19 +1104,20 @@ and ``TypeB`` classes. Controlling table inheritance with mixins ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The ``__tablename__`` attribute in conjunction with the hierarchy of -classes involved in a declarative mixin scenario controls what type of -table inheritance, if any, -is configured by the declarative extension. 
+The ``__tablename__`` attribute may be used to provide a function that +will determine the name of the table used for each class in an inheritance +hierarchy, as well as whether a class has its own distinct table. -If the ``__tablename__`` is computed by a mixin, you may need to -control which classes get the computed attribute in order to get the -type of table inheritance you require. +This is achieved using the :class:`.declared_attr` indicator in conjunction +with a method named ``__tablename__()``. Declarative will always +invoke :class:`.declared_attr` for the special names +``__tablename__``, ``__mapper_args__`` and ``__table_args__`` +function **for each mapped class in the hierarchy**. The function therefore +needs to expect to receive each class individually and to provide the +correct answer for each. -For example, if you had a mixin that computes ``__tablename__`` but -where you wanted to use that mixin in a single table inheritance -hierarchy, you can explicitly specify ``__tablename__`` as ``None`` to -indicate that the class should not have a table mapped:: +For example, to create a mixin that gives every class a simple table +name based on class name:: from sqlalchemy.ext.declarative import declared_attr @@ -1118,15 +1136,10 @@ indicate that the class should not have a table mapped:: __mapper_args__ = {'polymorphic_identity': 'engineer'} primary_language = Column(String(50)) -Alternatively, you can make the mixin intelligent enough to only -return a ``__tablename__`` in the event that no table is already -mapped in the inheritance hierarchy. To help with this, a -:func:`~sqlalchemy.ext.declarative.has_inherited_table` helper -function is provided that returns ``True`` if a parent class already -has a mapped table. - -As an example, here's a mixin that will only allow single table -inheritance:: +Alternatively, we can modify our ``__tablename__`` function to return +``None`` for subclasses, using :func:`.has_inherited_table`. 
This has +the effect of those subclasses being mapped with single table inheritance +against the parent:: from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.ext.declarative import has_inherited_table @@ -1147,6 +1160,64 @@ inheritance:: primary_language = Column(String(50)) __mapper_args__ = {'polymorphic_identity': 'engineer'} +.. _mixin_inheritance_columns: + +Mixing in Columns in Inheritance Scenarios +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In contrast to how ``__tablename__`` and other special names are handled when +used with :class:`.declared_attr`, when we mix in columns and properties (e.g. +relationships, column properties, etc.), the function is +invoked for the **base class only** in the hierarchy. Below, only the +``Person`` class will receive a column +called ``id``; the mapping will fail on ``Engineer``, which is not given +a primary key:: + + class HasId(object): + @declared_attr + def id(cls): + return Column('id', Integer, primary_key=True) + + class Person(HasId, Base): + __tablename__ = 'person' + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on': discriminator} + + class Engineer(Person): + __tablename__ = 'engineer' + primary_language = Column(String(50)) + __mapper_args__ = {'polymorphic_identity': 'engineer'} + +It is usually the case in joined-table inheritance that we want distinctly +named columns on each subclass. However in this case, we may want to have +an ``id`` column on every table, and have them refer to each other via +foreign key. 
We can achieve this as a mixin by using the +:attr:`.declared_attr.cascading` modifier, which indicates that the +function should be invoked **for each class in the hierarchy**, just like +it does for ``__tablename__``:: + + class HasId(object): + @declared_attr.cascading + def id(cls): + if has_inherited_table(cls): + return Column('id', + Integer, + ForeignKey('person.id'), primary_key=True) + else: + return Column('id', Integer, primary_key=True) + + class Person(HasId, Base): + __tablename__ = 'person' + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on': discriminator} + + class Engineer(Person): + __tablename__ = 'engineer' + primary_language = Column(String(50)) + __mapper_args__ = {'polymorphic_identity': 'engineer'} + + +.. versionadded:: 1.0.0 added :attr:`.declared_attr.cascading`. Combining Table/Mapper Arguments from Multiple Mixins ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/lib/sqlalchemy/ext/declarative/api.py b/lib/sqlalchemy/ext/declarative/api.py index daf8bffb5..e84b21ad2 100644 --- a/lib/sqlalchemy/ext/declarative/api.py +++ b/lib/sqlalchemy/ext/declarative/api.py @@ -8,12 +8,13 @@ from ...schema import Table, MetaData -from ...orm import synonym as _orm_synonym, mapper,\ +from ...orm import synonym as _orm_synonym, \ comparable_property,\ - interfaces, properties + interfaces, properties, attributes from ...orm.util import polymorphic_union from ...orm.base import _mapper_or_none -from ...util import OrderedDict +from ...util import OrderedDict, hybridmethod, hybridproperty +from ... import util from ... import exc import weakref @@ -21,7 +22,6 @@ from .base import _as_declarative, \ _declarative_constructor,\ _DeferredMapperConfig, _add_attribute from .clsregistry import _class_resolver -from . 
import clsregistry def instrument_declarative(cls, registry, metadata): @@ -157,12 +157,98 @@ class declared_attr(interfaces._MappedAttribute, property): """ - def __init__(self, fget, *arg, **kw): - super(declared_attr, self).__init__(fget, *arg, **kw) + def __init__(self, fget, cascading=False): + super(declared_attr, self).__init__(fget) self.__doc__ = fget.__doc__ + self._cascading = cascading def __get__(desc, self, cls): - return desc.fget(cls) + # use the ClassManager for memoization of values. This is better than + # adding yet another attribute onto the class, or using weakrefs + # here which are slow and take up memory. It also allows us to + # warn for non-mapped use of declared_attr. + + manager = attributes.manager_of_class(cls) + if manager is None: + util.warn( + "Unmanaged access of declarative attribute %s from " + "non-mapped class %s" % + (desc.fget.__name__, cls.__name__)) + return desc.fget(cls) + try: + reg = manager.info['declared_attr_reg'] + except KeyError: + raise exc.InvalidRequestError( + "@declared_attr called outside of the " + "declarative mapping process; is declarative_base() being " + "used correctly?") + + if desc in reg: + return reg[desc] + else: + reg[desc] = obj = desc.fget(cls) + return obj + + @hybridmethod + def _stateful(cls, **kw): + return _stateful_declared_attr(**kw) + + @hybridproperty + def cascading(cls): + """Mark a :class:`.declared_attr` as cascading. + + This is a special-use modifier which indicates that a column + or MapperProperty-based declared attribute should be configured + distinctly per mapped subclass, within a mapped-inheritance scenario. 
+ + Below, both MyClass as well as MySubClass will have a distinct + ``id`` Column object established:: + + class HasSomeAttribute(object): + @declared_attr.cascading + def some_id(cls): + if has_inherited_table(cls): + return Column( + ForeignKey('myclass.id'), primary_key=True) + else: + return Column(Integer, primary_key=True) + + return Column('id', Integer, primary_key=True) + + class MyClass(HasSomeAttribute, Base): + "" + # ... + + class MySubClass(MyClass): + "" + # ... + + The behavior of the above configuration is that ``MySubClass`` + will refer to both its own ``id`` column as well as that of + ``MyClass`` underneath the attribute named ``some_id``. + + .. seealso:: + + :ref:`declarative_inheritance` + + :ref:`mixin_inheritance_columns` + + + """ + return cls._stateful(cascading=True) + + +class _stateful_declared_attr(declared_attr): + def __init__(self, **kw): + self.kw = kw + + def _stateful(self, **kw): + new_kw = self.kw.copy() + new_kw.update(kw) + return _stateful_declared_attr(**new_kw) + + def __call__(self, fn): + return declared_attr(fn, **self.kw) def declarative_base(bind=None, metadata=None, mapper=None, cls=object, @@ -349,9 +435,11 @@ class AbstractConcreteBase(ConcreteBase): ``__declare_last__()`` function, which is essentially a hook for the :meth:`.after_configured` event. - :class:`.AbstractConcreteBase` does not produce a mapped - table for the class itself. Compare to :class:`.ConcreteBase`, - which does. + :class:`.AbstractConcreteBase` does produce a mapped class + for the base class, however it is not persisted to any table; it + is instead mapped directly to the "polymorphic" selectable directly + and is only used for selecting. Compare to :class:`.ConcreteBase`, + which does create a persisted table for the base class. 
Example:: @@ -365,20 +453,72 @@ class AbstractConcreteBase(ConcreteBase): employee_id = Column(Integer, primary_key=True) name = Column(String(50)) manager_data = Column(String(40)) + __mapper_args__ = { - 'polymorphic_identity':'manager', - 'concrete':True} + 'polymorphic_identity':'manager', + 'concrete':True} + + The abstract base class is handled by declarative in a special way; + at class configuration time, it behaves like a declarative mixin + or an ``__abstract__`` base class. Once classes are configured + and mappings are produced, it then gets mapped itself, but + after all of its descendants. This is a very unique system of mapping + not found in any other SQLAlchemy system. + + Using this approach, we can specify columns and properties + that will take place on mapped subclasses, in the way that + we normally do as in :ref:`declarative_mixins`:: + + class Company(Base): + __tablename__ = 'company' + id = Column(Integer, primary_key=True) + + class Employee(AbstractConcreteBase, Base): + employee_id = Column(Integer, primary_key=True) + + @declared_attr + def company_id(cls): + return Column(ForeignKey('company.id')) + + @declared_attr + def company(cls): + return relationship("Company") + + class Manager(Employee): + __tablename__ = 'manager' + + name = Column(String(50)) + manager_data = Column(String(40)) + + __mapper_args__ = { + 'polymorphic_identity':'manager', + 'concrete':True} + + When we make use of our mappings however, both ``Manager`` and + ``Employee`` will have an independently usable ``.company`` attribute:: + + session.query(Employee).filter(Employee.company.has(id=5)) + + .. versionchanged:: 1.0.0 - The mechanics of :class:`.AbstractConcreteBase` + have been reworked to support relationships established directly + on the abstract base, without any special configurational steps. 
+ """ - __abstract__ = True + __no_table__ = True @classmethod def __declare_first__(cls): - if hasattr(cls, '__mapper__'): + cls._sa_decl_prepare_nocascade() + + @classmethod + def _sa_decl_prepare_nocascade(cls): + if getattr(cls, '__mapper__', None): return - clsregistry.add_class(cls.__name__, cls) + to_map = _DeferredMapperConfig.config_for_cls(cls) + # can't rely on 'self_and_descendants' here # since technically an immediate subclass # might not be mapped, but a subclass @@ -392,7 +532,18 @@ class AbstractConcreteBase(ConcreteBase): if mn is not None: mappers.append(mn) pjoin = cls._create_polymorphic_union(mappers) - cls.__mapper__ = m = mapper(cls, pjoin, polymorphic_on=pjoin.c.type) + + to_map.local_table = pjoin + + m_args = to_map.mapper_args_fn or dict + + def mapper_args(): + args = m_args() + args['polymorphic_on'] = pjoin.c.type + return args + to_map.mapper_args_fn = mapper_args + + m = to_map.map() for scls in cls.__subclasses__(): sm = _mapper_or_none(scls) diff --git a/lib/sqlalchemy/ext/declarative/base.py b/lib/sqlalchemy/ext/declarative/base.py index 94baeeb51..291608b6c 100644 --- a/lib/sqlalchemy/ext/declarative/base.py +++ b/lib/sqlalchemy/ext/declarative/base.py @@ -19,6 +19,9 @@ from ... import event from . import clsregistry import collections import weakref +from sqlalchemy.orm import instrumentation + +declared_attr = declarative_props = None def _declared_mapping_info(cls): @@ -32,322 +35,407 @@ def _declared_mapping_info(cls): return None +def _get_immediate_cls_attr(cls, attrname): + """return an attribute of the class that is either present directly + on the class, e.g. not on a superclass, or is from a superclass but + this superclass is a mixin, that is, not a descendant of + the declarative base. + + This is used to detect attributes that indicate something about + a mapped class independently from any mapped classes that it may + inherit from. 
+ + """ + for base in cls.__mro__: + _is_declarative_inherits = hasattr(base, '_decl_class_registry') + if attrname in base.__dict__: + value = getattr(base, attrname) + if (base is cls or + (base in cls.__bases__ and not _is_declarative_inherits)): + return value + else: + return None + + def _as_declarative(cls, classname, dict_): - from .api import declared_attr + global declared_attr, declarative_props + if declared_attr is None: + from .api import declared_attr + declarative_props = (declared_attr, util.classproperty) - # dict_ will be a dictproxy, which we can't write to, and we need to! - dict_ = dict(dict_) + if _get_immediate_cls_attr(cls, '__abstract__'): + return - column_copies = {} - potential_columns = {} + _MapperConfig.setup_mapping(cls, classname, dict_) - mapper_args_fn = None - table_args = inherited_table_args = None - tablename = None - declarative_props = (declared_attr, util.classproperty) +class _MapperConfig(object): - for base in cls.__mro__: - _is_declarative_inherits = hasattr(base, '_decl_class_registry') + @classmethod + def setup_mapping(cls, cls_, classname, dict_): + defer_map = _get_immediate_cls_attr( + cls_, '_sa_decl_prepare_nocascade') or \ + hasattr(cls_, '_sa_decl_prepare') - if '__declare_last__' in base.__dict__: - @event.listens_for(mapper, "after_configured") - def go(): - cls.__declare_last__() - if '__declare_first__' in base.__dict__: - @event.listens_for(mapper, "before_configured") - def go(): - cls.__declare_first__() - if '__abstract__' in base.__dict__ and base.__abstract__: - if (base is cls or - (base in cls.__bases__ and not _is_declarative_inherits)): - return + if defer_map: + cfg_cls = _DeferredMapperConfig + else: + cfg_cls = _MapperConfig + cfg_cls(cls_, classname, dict_) - class_mapped = _declared_mapping_info(base) is not None + def __init__(self, cls_, classname, dict_): - for name, obj in vars(base).items(): - if name == '__mapper_args__': - if not mapper_args_fn and ( - not class_mapped or - 
isinstance(obj, declarative_props) - ): - # don't even invoke __mapper_args__ until - # after we've determined everything about the - # mapped table. - # make a copy of it so a class-level dictionary - # is not overwritten when we update column-based - # arguments. - mapper_args_fn = lambda: dict(cls.__mapper_args__) - elif name == '__tablename__': - if not tablename and ( - not class_mapped or - isinstance(obj, declarative_props) - ): - tablename = cls.__tablename__ - elif name == '__table_args__': - if not table_args and ( - not class_mapped or - isinstance(obj, declarative_props) - ): - table_args = cls.__table_args__ - if not isinstance(table_args, (tuple, dict, type(None))): - raise exc.ArgumentError( - "__table_args__ value must be a tuple, " - "dict, or None") - if base is not cls: - inherited_table_args = True - elif class_mapped: - if isinstance(obj, declarative_props): - util.warn("Regular (i.e. not __special__) " - "attribute '%s.%s' uses @declared_attr, " - "but owning class %s is mapped - " - "not applying to subclass %s." - % (base.__name__, name, base, cls)) - continue - elif base is not cls: - # we're a mixin. - if isinstance(obj, Column): - if getattr(cls, name) is not obj: - # if column has been overridden - # (like by the InstrumentedAttribute of the - # superclass), skip + self.cls = cls_ + + # dict_ will be a dictproxy, which we can't write to, and we need to! 
+ self.dict_ = dict(dict_) + self.classname = classname + self.mapped_table = None + self.properties = util.OrderedDict() + self.declared_columns = set() + self.column_copies = {} + self._setup_declared_events() + + # register up front, so that @declared_attr can memoize + # function evaluations in .info + manager = instrumentation.register_class(self.cls) + manager.info['declared_attr_reg'] = {} + + self._scan_attributes() + + clsregistry.add_class(self.classname, self.cls) + + self._extract_mappable_attributes() + + self._extract_declared_columns() + + self._setup_table() + + self._setup_inheritance() + + self._early_mapping() + + def _early_mapping(self): + self.map() + + def _setup_declared_events(self): + if _get_immediate_cls_attr(self.cls, '__declare_last__'): + @event.listens_for(mapper, "after_configured") + def after_configured(): + self.cls.__declare_last__() + + if _get_immediate_cls_attr(self.cls, '__declare_first__'): + @event.listens_for(mapper, "before_configured") + def before_configured(): + self.cls.__declare_first__() + + def _scan_attributes(self): + cls = self.cls + dict_ = self.dict_ + column_copies = self.column_copies + mapper_args_fn = None + table_args = inherited_table_args = None + tablename = None + + for base in cls.__mro__: + class_mapped = base is not cls and \ + _declared_mapping_info(base) is not None and \ + not _get_immediate_cls_attr(base, '_sa_decl_prepare_nocascade') + + if not class_mapped and base is not cls: + self._produce_column_copies(base) + + for name, obj in vars(base).items(): + if name == '__mapper_args__': + if not mapper_args_fn and ( + not class_mapped or + isinstance(obj, declarative_props) + ): + # don't even invoke __mapper_args__ until + # after we've determined everything about the + # mapped table. + # make a copy of it so a class-level dictionary + # is not overwritten when we update column-based + # arguments. 
+ mapper_args_fn = lambda: dict(cls.__mapper_args__) + elif name == '__tablename__': + if not tablename and ( + not class_mapped or + isinstance(obj, declarative_props) + ): + tablename = cls.__tablename__ + elif name == '__table_args__': + if not table_args and ( + not class_mapped or + isinstance(obj, declarative_props) + ): + table_args = cls.__table_args__ + if not isinstance( + table_args, (tuple, dict, type(None))): + raise exc.ArgumentError( + "__table_args__ value must be a tuple, " + "dict, or None") + if base is not cls: + inherited_table_args = True + elif class_mapped: + if isinstance(obj, declarative_props): + util.warn("Regular (i.e. not __special__) " + "attribute '%s.%s' uses @declared_attr, " + "but owning class %s is mapped - " + "not applying to subclass %s." + % (base.__name__, name, base, cls)) + continue + elif base is not cls: + # we're a mixin, abstract base, or something that is + # acting like that for now. + if isinstance(obj, Column): + # already copied columns to the mapped class. continue - if obj.foreign_keys: + elif isinstance(obj, MapperProperty): raise exc.InvalidRequestError( - "Columns with foreign keys to other columns " - "must be declared as @declared_attr callables " - "on declarative mixin classes. ") - if name not in dict_ and not ( - '__table__' in dict_ and - (obj.name or name) in dict_['__table__'].c - ) and name not in potential_columns: - potential_columns[name] = \ - column_copies[obj] = \ - obj.copy() - column_copies[obj]._creation_order = \ - obj._creation_order - elif isinstance(obj, MapperProperty): + "Mapper properties (i.e. deferred," + "column_property(), relationship(), etc.) 
must " + "be declared as @declared_attr callables " + "on declarative mixin classes.") + elif isinstance(obj, declarative_props): + oldclassprop = isinstance(obj, util.classproperty) + if not oldclassprop and obj._cascading: + dict_[name] = column_copies[obj] = \ + ret = obj.__get__(obj, cls) + else: + if oldclassprop: + util.warn_deprecated( + "Use of sqlalchemy.util.classproperty on " + "declarative classes is deprecated.") + dict_[name] = column_copies[obj] = \ + ret = getattr(cls, name) + if isinstance(ret, (Column, MapperProperty)) and \ + ret.doc is None: + ret.doc = obj.__doc__ + + if inherited_table_args and not tablename: + table_args = None + + self.table_args = table_args + self.tablename = tablename + self.mapper_args_fn = mapper_args_fn + + def _produce_column_copies(self, base): + cls = self.cls + dict_ = self.dict_ + column_copies = self.column_copies + # copy mixin columns to the mapped class + for name, obj in vars(base).items(): + if isinstance(obj, Column): + if getattr(cls, name) is not obj: + # if column has been overridden + # (like by the InstrumentedAttribute of the + # superclass), skip + continue + elif obj.foreign_keys: raise exc.InvalidRequestError( - "Mapper properties (i.e. deferred," - "column_property(), relationship(), etc.) must " - "be declared as @declared_attr callables " - "on declarative mixin classes.") - elif isinstance(obj, declarative_props): - dict_[name] = ret = \ - column_copies[obj] = getattr(cls, name) - if isinstance(ret, (Column, MapperProperty)) and \ - ret.doc is None: - ret.doc = obj.__doc__ - - # apply inherited columns as we should - for k, v in potential_columns.items(): - dict_[k] = v - - if inherited_table_args and not tablename: - table_args = None - - clsregistry.add_class(classname, cls) - our_stuff = util.OrderedDict() - - for k in list(dict_): - - # TODO: improve this ? all dunders ? 
- if k in ('__table__', '__tablename__', '__mapper_args__'): - continue - - value = dict_[k] - if isinstance(value, declarative_props): - value = getattr(cls, k) - - elif isinstance(value, QueryableAttribute) and \ - value.class_ is not cls and \ - value.key != k: - # detect a QueryableAttribute that's already mapped being - # assigned elsewhere in userland, turn into a synonym() - value = synonym(value.key) - setattr(cls, k, value) - - if (isinstance(value, tuple) and len(value) == 1 and - isinstance(value[0], (Column, MapperProperty))): - util.warn("Ignoring declarative-like tuple value of attribute " - "%s: possibly a copy-and-paste error with a comma " - "left at the end of the line?" % k) - continue - if not isinstance(value, (Column, MapperProperty)): - if not k.startswith('__'): - dict_.pop(k) - setattr(cls, k, value) - continue - if k == 'metadata': - raise exc.InvalidRequestError( - "Attribute name 'metadata' is reserved " - "for the MetaData instance when using a " - "declarative base class." - ) - prop = clsregistry._deferred_relationship(cls, value) - our_stuff[k] = prop - - # set up attributes in the order they were created - our_stuff.sort(key=lambda key: our_stuff[key]._creation_order) - - # extract columns from the class dict - declared_columns = set() - name_to_prop_key = collections.defaultdict(set) - for key, c in list(our_stuff.items()): - if isinstance(c, (ColumnProperty, CompositeProperty)): - for col in c.columns: - if isinstance(col, Column) and \ - col.table is None: - _undefer_column_name(key, col) - if not isinstance(c, CompositeProperty): - name_to_prop_key[col.name].add(key) - declared_columns.add(col) - elif isinstance(c, Column): - _undefer_column_name(key, c) - name_to_prop_key[c.name].add(key) - declared_columns.add(c) - # if the column is the same name as the key, - # remove it from the explicit properties dict. 
- # the normal rules for assigning column-based properties - # will take over, including precedence of columns - # in multi-column ColumnProperties. - if key == c.key: - del our_stuff[key] - - for name, keys in name_to_prop_key.items(): - if len(keys) > 1: - util.warn( - "On class %r, Column object %r named directly multiple times, " - "only one will be used: %s" % - (classname, name, (", ".join(sorted(keys)))) - ) + "Columns with foreign keys to other columns " + "must be declared as @declared_attr callables " + "on declarative mixin classes. ") + elif name not in dict_ and not ( + '__table__' in dict_ and + (obj.name or name) in dict_['__table__'].c + ): + column_copies[obj] = copy_ = obj.copy() + copy_._creation_order = obj._creation_order + setattr(cls, name, copy_) + dict_[name] = copy_ - declared_columns = sorted( - declared_columns, key=lambda c: c._creation_order) - table = None + def _extract_mappable_attributes(self): + cls = self.cls + dict_ = self.dict_ - if hasattr(cls, '__table_cls__'): - table_cls = util.unbound_method_to_callable(cls.__table_cls__) - else: - table_cls = Table - - if '__table__' not in dict_: - if tablename is not None: - - args, table_kw = (), {} - if table_args: - if isinstance(table_args, dict): - table_kw = table_args - elif isinstance(table_args, tuple): - if isinstance(table_args[-1], dict): - args, table_kw = table_args[0:-1], table_args[-1] - else: - args = table_args - - autoload = dict_.get('__autoload__') - if autoload: - table_kw['autoload'] = True - - cls.__table__ = table = table_cls( - tablename, cls.metadata, - *(tuple(declared_columns) + tuple(args)), - **table_kw) - else: - table = cls.__table__ - if declared_columns: - for c in declared_columns: - if not table.c.contains_column(c): - raise exc.ArgumentError( - "Can't add additional column %r when " - "specifying __table__" % c.key - ) + our_stuff = self.properties - if hasattr(cls, '__mapper_cls__'): - mapper_cls = 
util.unbound_method_to_callable(cls.__mapper_cls__) - else: - mapper_cls = mapper + for k in list(dict_): - for c in cls.__bases__: - if _declared_mapping_info(c) is not None: - inherits = c - break - else: - inherits = None + if k in ('__table__', '__tablename__', '__mapper_args__'): + continue - if table is None and inherits is None: - raise exc.InvalidRequestError( - "Class %r does not have a __table__ or __tablename__ " - "specified and does not inherit from an existing " - "table-mapped class." % cls - ) - elif inherits: - inherited_mapper = _declared_mapping_info(inherits) - inherited_table = inherited_mapper.local_table - inherited_mapped_table = inherited_mapper.mapped_table - - if table is None: - # single table inheritance. - # ensure no table args - if table_args: - raise exc.ArgumentError( - "Can't place __table_args__ on an inherited class " - "with no table." + value = dict_[k] + if isinstance(value, declarative_props): + value = getattr(cls, k) + + elif isinstance(value, QueryableAttribute) and \ + value.class_ is not cls and \ + value.key != k: + # detect a QueryableAttribute that's already mapped being + # assigned elsewhere in userland, turn into a synonym() + value = synonym(value.key) + setattr(cls, k, value) + + if (isinstance(value, tuple) and len(value) == 1 and + isinstance(value[0], (Column, MapperProperty))): + util.warn("Ignoring declarative-like tuple value of attribute " + "%s: possibly a copy-and-paste error with a comma " + "left at the end of the line?" % k) + continue + elif not isinstance(value, (Column, MapperProperty)): + # using @declared_attr for some object that + # isn't Column/MapperProperty; remove from the dict_ + # and place the evaluated value onto the class. + if not k.startswith('__'): + dict_.pop(k) + setattr(cls, k, value) + continue + # we expect to see the name 'metadata' in some valid cases; + # however at this point we see it's assigned to something trying + # to be mapped, so raise for that.
+ elif k == 'metadata': + raise exc.InvalidRequestError( + "Attribute name 'metadata' is reserved " + "for the MetaData instance when using a " + "declarative base class." + ) + prop = clsregistry._deferred_relationship(cls, value) + our_stuff[k] = prop + + def _extract_declared_columns(self): + our_stuff = self.properties + + # set up attributes in the order they were created + our_stuff.sort(key=lambda key: our_stuff[key]._creation_order) + + # extract columns from the class dict + declared_columns = self.declared_columns + name_to_prop_key = collections.defaultdict(set) + for key, c in list(our_stuff.items()): + if isinstance(c, (ColumnProperty, CompositeProperty)): + for col in c.columns: + if isinstance(col, Column) and \ + col.table is None: + _undefer_column_name(key, col) + if not isinstance(c, CompositeProperty): + name_to_prop_key[col.name].add(key) + declared_columns.add(col) + elif isinstance(c, Column): + _undefer_column_name(key, c) + name_to_prop_key[c.name].add(key) + declared_columns.add(c) + # if the column is the same name as the key, + # remove it from the explicit properties dict. + # the normal rules for assigning column-based properties + # will take over, including precedence of columns + # in multi-column ColumnProperties. + if key == c.key: + del our_stuff[key] + + for name, keys in name_to_prop_key.items(): + if len(keys) > 1: + util.warn( + "On class %r, Column object %r named " + "directly multiple times, " + "only one will be used: %s" % + (self.classname, name, (", ".join(sorted(keys)))) ) - # add any columns declared here to the inherited table. - for c in declared_columns: - if c.primary_key: - raise exc.ArgumentError( - "Can't place primary key columns on an inherited " - "class with no table." 
- ) - if c.name in inherited_table.c: - if inherited_table.c[c.name] is c: - continue - raise exc.ArgumentError( - "Column '%s' on class %s conflicts with " - "existing column '%s'" % - (c, cls, inherited_table.c[c.name]) - ) - inherited_table.append_column(c) - if inherited_mapped_table is not None and \ - inherited_mapped_table is not inherited_table: - inherited_mapped_table._refresh_for_new_column(c) - - defer_map = hasattr(cls, '_sa_decl_prepare') - if defer_map: - cfg_cls = _DeferredMapperConfig - else: - cfg_cls = _MapperConfig - mt = cfg_cls(mapper_cls, - cls, table, - inherits, - declared_columns, - column_copies, - our_stuff, - mapper_args_fn) - if not defer_map: - mt.map() + def _setup_table(self): + cls = self.cls + tablename = self.tablename + table_args = self.table_args + dict_ = self.dict_ + declared_columns = self.declared_columns -class _MapperConfig(object): + declared_columns = self.declared_columns = sorted( + declared_columns, key=lambda c: c._creation_order) + table = None - mapped_table = None - - def __init__(self, mapper_cls, - cls, - table, - inherits, - declared_columns, - column_copies, - properties, mapper_args_fn): - self.mapper_cls = mapper_cls - self.cls = cls + if hasattr(cls, '__table_cls__'): + table_cls = util.unbound_method_to_callable(cls.__table_cls__) + else: + table_cls = Table + + if '__table__' not in dict_: + if tablename is not None: + + args, table_kw = (), {} + if table_args: + if isinstance(table_args, dict): + table_kw = table_args + elif isinstance(table_args, tuple): + if isinstance(table_args[-1], dict): + args, table_kw = table_args[0:-1], table_args[-1] + else: + args = table_args + + autoload = dict_.get('__autoload__') + if autoload: + table_kw['autoload'] = True + + cls.__table__ = table = table_cls( + tablename, cls.metadata, + *(tuple(declared_columns) + tuple(args)), + **table_kw) + else: + table = cls.__table__ + if declared_columns: + for c in declared_columns: + if not table.c.contains_column(c): + 
raise exc.ArgumentError( + "Can't add additional column %r when " + "specifying __table__" % c.key + ) self.local_table = table - self.inherits = inherits - self.properties = properties - self.mapper_args_fn = mapper_args_fn - self.declared_columns = declared_columns - self.column_copies = column_copies + + def _setup_inheritance(self): + table = self.local_table + cls = self.cls + table_args = self.table_args + declared_columns = self.declared_columns + for c in cls.__bases__: + if _declared_mapping_info(c) is not None and \ + not _get_immediate_cls_attr( + c, '_sa_decl_prepare_nocascade'): + self.inherits = c + break + else: + self.inherits = None + + if table is None and self.inherits is None and \ + not _get_immediate_cls_attr(cls, '__no_table__'): + + raise exc.InvalidRequestError( + "Class %r does not have a __table__ or __tablename__ " + "specified and does not inherit from an existing " + "table-mapped class." % cls + ) + elif self.inherits: + inherited_mapper = _declared_mapping_info(self.inherits) + inherited_table = inherited_mapper.local_table + inherited_mapped_table = inherited_mapper.mapped_table + + if table is None: + # single table inheritance. + # ensure no table args + if table_args: + raise exc.ArgumentError( + "Can't place __table_args__ on an inherited class " + "with no table." + ) + # add any columns declared here to the inherited table. + for c in declared_columns: + if c.primary_key: + raise exc.ArgumentError( + "Can't place primary key columns on an inherited " + "class with no table." 
+ ) + if c.name in inherited_table.c: + if inherited_table.c[c.name] is c: + continue + raise exc.ArgumentError( + "Column '%s' on class %s conflicts with " + "existing column '%s'" % + (c, cls, inherited_table.c[c.name]) + ) + inherited_table.append_column(c) + if inherited_mapped_table is not None and \ + inherited_mapped_table is not inherited_table: + inherited_mapped_table._refresh_for_new_column(c) def _prepare_mapper_arguments(self): properties = self.properties @@ -401,20 +489,31 @@ class _MapperConfig(object): properties[k] = [col] + p.columns result_mapper_args = mapper_args.copy() result_mapper_args['properties'] = properties - return result_mapper_args + self.mapper_args = result_mapper_args def map(self): - mapper_args = self._prepare_mapper_arguments() - self.cls.__mapper__ = self.mapper_cls( + self._prepare_mapper_arguments() + if hasattr(self.cls, '__mapper_cls__'): + mapper_cls = util.unbound_method_to_callable( + self.cls.__mapper_cls__) + else: + mapper_cls = mapper + + self.cls.__mapper__ = mp_ = mapper_cls( self.cls, self.local_table, - **mapper_args + **self.mapper_args ) + del mp_.class_manager.info['declared_attr_reg'] + return mp_ class _DeferredMapperConfig(_MapperConfig): _configs = util.OrderedDict() + def _early_mapping(self): + pass + @property def cls(self): return self._cls() @@ -466,7 +565,7 @@ class _DeferredMapperConfig(_MapperConfig): def map(self): self._configs.pop(self._cls, None) - super(_DeferredMapperConfig, self).map() + return super(_DeferredMapperConfig, self).map() def _add_attribute(cls, key, value): diff --git a/lib/sqlalchemy/ext/declarative/clsregistry.py b/lib/sqlalchemy/ext/declarative/clsregistry.py index 4595b857a..3ef63a5ae 100644 --- a/lib/sqlalchemy/ext/declarative/clsregistry.py +++ b/lib/sqlalchemy/ext/declarative/clsregistry.py @@ -103,7 +103,12 @@ class _MultipleClassMarker(object): self.on_remove() def add_item(self, item): - modules = set([cls().__module__ for cls in self.contents]) + # protect against 
class registration race condition against + # asynchronous garbage collection calling _remove_item, + # [ticket:3208] + modules = set([ + cls.__module__ for cls in + [ref() for ref in self.contents] if cls is not None]) if item.__module__ in modules: util.warn( "This declarative base already contains a class with the " diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index 67fda44c4..61155731c 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -119,7 +119,7 @@ start numbering at 1 or some other integer, provide ``count_from=1``. """ -from ..orm.collections import collection +from ..orm.collections import collection, collection_adapter from .. import util __all__ = ['ordering_list'] @@ -319,7 +319,10 @@ class OrderingList(list): def remove(self, entity): super(OrderingList, self).remove(entity) - self._reorder() + + adapter = collection_adapter(self) + if adapter and adapter._referenced_by_owner: + self._reorder() def pop(self, index=-1): entity = super(OrderingList, self).pop(index) diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 1fc0873bd..356a8a3b9 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -589,6 +589,16 @@ class CollectionAdapter(object): "The entity collection being adapted." return self._data() + @property + def _referenced_by_owner(self): + """return True if the owner state still refers to this collection. + + This will return False within a bulk replace operation, + where this collection is the one being replaced. + + """ + return self.owner_state.dict[self._key] is self._data() + @util.memoized_property def attr(self): return self.owner_state.manager[self._key].impl @@ -851,11 +861,24 @@ def _instrument_class(cls): "Can not instrument a built-in type. 
Use a " "subclass, even a trivial one.") + roles, methods = _locate_roles_and_methods(cls) + + _setup_canned_roles(cls, roles, methods) + + _assert_required_roles(cls, roles, methods) + + _set_collection_attributes(cls, roles, methods) + + +def _locate_roles_and_methods(cls): + """search for _sa_instrument_role-decorated methods in + method resolution order, assign to roles. + + """ + roles = {} methods = {} - # search for _sa_instrument_role-decorated methods in - # method resolution order, assign to roles for supercls in cls.__mro__: for name, method in vars(supercls).items(): if not util.callable(method): @@ -880,14 +903,19 @@ def _instrument_class(cls): assert op in ('fire_append_event', 'fire_remove_event') after = op if before: - methods[name] = before[0], before[1], after + methods[name] = before + (after, ) elif after: methods[name] = None, None, after + return roles, methods + - # see if this class has "canned" roles based on a known - # collection type (dict, set, list). Apply those roles - # as needed to the "roles" dictionary, and also - # prepare "decorator" methods +def _setup_canned_roles(cls, roles, methods): + """see if this class has "canned" roles based on a known + collection type (dict, set, list). 
Apply those roles + as needed to the "roles" dictionary, and also + prepare "decorator" methods + + """ collection_type = util.duck_type_collection(cls) if collection_type in __interfaces: canned_roles, decorators = __interfaces[collection_type] @@ -901,8 +929,12 @@ def _instrument_class(cls): not hasattr(fn, '_sa_instrumented')): setattr(cls, method, decorator(fn)) - # ensure all roles are present, and apply implicit instrumentation if - # needed + +def _assert_required_roles(cls, roles, methods): + """ensure all roles are present, and apply implicit instrumentation if + needed + + """ if 'appender' not in roles or not hasattr(cls, roles['appender']): raise sa_exc.ArgumentError( "Type %s must elect an appender method to be " @@ -924,8 +956,12 @@ def _instrument_class(cls): "Type %s must elect an iterator method to be " "a collection class" % cls.__name__) - # apply ad-hoc instrumentation from decorators, class-level defaults - # and implicit role declarations + +def _set_collection_attributes(cls, roles, methods): + """apply ad-hoc instrumentation from decorators, class-level defaults + and implicit role declarations + + """ for method_name, (before, argument, after) in methods.items(): setattr(cls, method_name, _instrument_membership_mutator(getattr(cls, method_name), diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index c50a7b062..9ea0dd834 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -61,7 +61,8 @@ class InstrumentationEvents(event.Events): @classmethod def _listen(cls, event_key, propagate=True, **kw): target, identifier, fn = \ - event_key.dispatch_target, event_key.identifier, event_key.fn + event_key.dispatch_target, event_key.identifier, \ + event_key._listen_fn def listen(target_cls, *arg): listen_cls = target() @@ -192,7 +193,8 @@ class InstanceEvents(event.Events): @classmethod def _listen(cls, event_key, raw=False, propagate=False, **kw): target, identifier, fn = \ - event_key.dispatch_target, 
event_key.identifier, event_key.fn + event_key.dispatch_target, event_key.identifier, \ + event_key._listen_fn if not raw: def wrap(state, *arg, **kw): @@ -498,7 +500,8 @@ class MapperEvents(event.Events): def _listen( cls, event_key, raw=False, retval=False, propagate=False, **kw): target, identifier, fn = \ - event_key.dispatch_target, event_key.identifier, event_key.fn + event_key.dispatch_target, event_key.identifier, \ + event_key._listen_fn if identifier in ("before_configured", "after_configured") and \ target is not mapperlib.Mapper: @@ -1493,7 +1496,8 @@ class AttributeEvents(event.Events): propagate=False): target, identifier, fn = \ - event_key.dispatch_target, event_key.identifier, event_key.fn + event_key.dispatch_target, event_key.identifier, \ + event_key._listen_fn if active_history: target.dispatch._active_history = True diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index a59a38a5b..2ab239f86 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -426,6 +426,12 @@ class Mapper(InspectionAttr): thus persisting the value to the ``discriminator`` column in the database. + .. warning:: + + Currently, **only one discriminator column may be set**, typically + on the base-most class in the hierarchy. "Cascading" polymorphic + columns are not yet supported. + .. seealso:: :ref:`inheritance_toplevel` @@ -1080,6 +1086,9 @@ class Mapper(InspectionAttr): auto-session attachment logic. """ + + # when using declarative as of 1.0, the register_class has + # already happened from within declarative. manager = attributes.manager_of_class(self.class_) if self.non_primary: @@ -1102,18 +1111,14 @@ class Mapper(InspectionAttr): "create a non primary Mapper. clear_mappers() will " "remove *all* current mappers from all classes." % self.class_) - # else: - # a ClassManager may already exist as - # ClassManager.instrument_attribute() creates - # new managers for each subclass if they don't yet exist. 
+ + if manager is None: + manager = instrumentation.register_class(self.class_) _mapper_registry[self] = True self.dispatch.instrument_class(self, self.class_) - if manager is None: - manager = instrumentation.register_class(self.class_) - self.class_manager = manager manager.mapper = self diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 60948293b..7b2ea7977 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1145,7 +1145,8 @@ class Query(object): @_generative() def with_hint(self, selectable, text, dialect_name='*'): - """Add an indexing hint for the given entity or selectable to + """Add an indexing or other executional context + hint for the given entity or selectable to this :class:`.Query`. Functionality is passed straight through to @@ -1153,11 +1154,35 @@ class Query(object): with the addition that ``selectable`` can be a :class:`.Table`, :class:`.Alias`, or ORM entity / mapped class /etc. + + .. seealso:: + + :meth:`.Query.with_statement_hint` + """ - selectable = inspect(selectable).selectable + if selectable is not None: + selectable = inspect(selectable).selectable self._with_hints += ((selectable, text, dialect_name),) + def with_statement_hint(self, text, dialect_name='*'): + """add a statement hint to this :class:`.Select`. + + This method is similar to :meth:`.Select.with_hint` except that + it does not require an individual table, and instead applies to the + statement as a whole. + + This feature calls down into :meth:`.Select.with_statement_hint`. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :meth:`.Query.with_hint` + + """ + return self.with_hint(None, text, dialect_name) + @_generative() def execution_options(self, **kwargs): """ Set non-SQL options which take effect during execution. 
@@ -2591,6 +2616,19 @@ class Query(object): SELECT 1 FROM users WHERE users.name = :name_1 ) AS anon_1 + The EXISTS construct is usually used in the WHERE clause:: + + session.query(User.id).filter(q.exists()).scalar() + + Note that some databases such as SQL Server don't allow an + EXISTS expression to be present in the columns clause of a + SELECT. To select a simple boolean value based on the exists + as a WHERE, use :func:`.literal`:: + + from sqlalchemy import literal + + session.query(literal(True)).filter(q.exists()).scalar() + .. versionadded:: 0.8.1 """ diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 95ff21444..56a33742d 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -2181,7 +2181,7 @@ class JoinCondition(object): elif self._local_remote_pairs or self._remote_side: self._annotate_remote_from_args() elif self._refers_to_parent_table(): - self._annotate_selfref(lambda col: "foreign" in col._annotations) + self._annotate_selfref(lambda col: "foreign" in col._annotations, False) elif self._tables_overlap(): self._annotate_remote_with_overlap() else: @@ -2200,7 +2200,7 @@ class JoinCondition(object): self.secondaryjoin = visitors.replacement_traverse( self.secondaryjoin, {}, repl) - def _annotate_selfref(self, fn): + def _annotate_selfref(self, fn, remote_side_given): """annotate 'remote' in primaryjoin, secondaryjoin when the relationship is detected as self-referential. 
@@ -2215,7 +2215,7 @@ class JoinCondition(object): if fn(binary.right) and not equated: binary.right = binary.right._annotate( {"remote": True}) - else: + elif not remote_side_given: self._warn_non_column_elements() self.primaryjoin = visitors.cloned_traverse( @@ -2240,7 +2240,7 @@ class JoinCondition(object): remote_side = self._remote_side if self._refers_to_parent_table(): - self._annotate_selfref(lambda col: col in remote_side) + self._annotate_selfref(lambda col: col in remote_side, True) else: def repl(element): if element in remote_side: diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index d59012d12..86f00d944 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -24,12 +24,10 @@ To generate user-defined SQL strings, see """ import re -from . import schema, sqltypes, operators, functions, \ - util as sql_util, visitors, elements, selectable, base +from . import schema, sqltypes, operators, functions, visitors, \ + elements, selectable, crud from .. import util, exc -import decimal import itertools -import operator RESERVED_WORDS = set([ 'all', 'analyse', 'analyze', 'and', 'any', 'array', @@ -64,17 +62,6 @@ BIND_TEMPLATES = { 'named': ":%(name)s" } -REQUIRED = util.symbol('REQUIRED', """ -Placeholder for the value within a :class:`.BindParameter` -which is required to be present when the statement is passed -to :meth:`.Connection.execute`. - -This symbol is typically used when a :func:`.expression.insert` -or :func:`.expression.update` statement is compiled without parameter -values present. 
- -""") - OPERATORS = { # binary @@ -725,7 +712,6 @@ class SQLCompiler(Compiled): for c in clauselist.clauses) if s) - def visit_case(self, clause, **kwargs): x = "CASE " if clause.value is not None: @@ -825,7 +811,8 @@ class SQLCompiler(Compiled): text += " GROUP BY " + group_by text += self.order_by_clause(cs, **kwargs) - text += (cs._limit_clause is not None or cs._offset_clause is not None) and \ + text += (cs._limit_clause is not None + or cs._offset_clause is not None) and \ self.limit_clause(cs) or "" if self.ctes and \ @@ -882,15 +869,15 @@ class SQLCompiler(Compiled): isinstance(binary.right, elements.BindParameter): kw['literal_binds'] = True - operator = binary.operator - disp = getattr(self, "visit_%s_binary" % operator.__name__, None) + operator_ = binary.operator + disp = getattr(self, "visit_%s_binary" % operator_.__name__, None) if disp: - return disp(binary, operator, **kw) + return disp(binary, operator_, **kw) else: try: - opstring = OPERATORS[operator] + opstring = OPERATORS[operator_] except KeyError: - raise exc.UnsupportedCompilationError(self, operator) + raise exc.UnsupportedCompilationError(self, operator_) else: return self._generate_generic_binary(binary, opstring, **kw) @@ -972,7 +959,7 @@ class SQLCompiler(Compiled): ' ESCAPE ' + self.render_literal_value(escape, sqltypes.STRINGTYPE) if escape else '' - ) + ) def visit_notlike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) @@ -983,7 +970,7 @@ class SQLCompiler(Compiled): ' ESCAPE ' + self.render_literal_value(escape, sqltypes.STRINGTYPE) if escape else '' - ) + ) def visit_ilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) @@ -994,7 +981,7 @@ class SQLCompiler(Compiled): ' ESCAPE ' + self.render_literal_value(escape, sqltypes.STRINGTYPE) if escape else '' - ) + ) def visit_notilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) @@ -1005,7 +992,7 @@ class 
SQLCompiler(Compiled): ' ESCAPE ' + self.render_literal_value(escape, sqltypes.STRINGTYPE) if escape else '' - ) + ) def visit_between_op_binary(self, binary, operator, **kw): symmetric = binary.modifiers.get("symmetric", False) @@ -1337,6 +1324,9 @@ class SQLCompiler(Compiled): def get_crud_hint_text(self, table, text): return None + def get_statement_hint_text(self, hint_texts): + return " ".join(hint_texts) + def _transform_select_for_nested_joins(self, select): """Rewrite any "a JOIN (b JOIN c)" expression as "a JOIN (select * from b JOIN c) AS anon", to support @@ -1507,29 +1497,7 @@ class SQLCompiler(Compiled): select, transformed_select) return text - correlate_froms = entry['correlate_froms'] - asfrom_froms = entry['asfrom_froms'] - - if asfrom: - froms = select._get_display_froms( - explicit_correlate_froms=correlate_froms.difference( - asfrom_froms), - implicit_correlate_froms=()) - else: - froms = select._get_display_froms( - explicit_correlate_froms=correlate_froms, - implicit_correlate_froms=asfrom_froms) - - new_correlate_froms = set(selectable._from_objects(*froms)) - all_correlate_froms = new_correlate_froms.union(correlate_froms) - - new_entry = { - 'asfrom_froms': new_correlate_froms, - 'iswrapper': iswrapper, - 'correlate_froms': all_correlate_froms, - 'selectable': select, - } - self.stack.append(new_entry) + froms = self._setup_select_stack(select, entry, asfrom, iswrapper) column_clause_args = kwargs.copy() column_clause_args.update({ @@ -1540,18 +1508,11 @@ class SQLCompiler(Compiled): text = "SELECT " # we're off to a good start ! 
if select._hints: - byfrom = dict([ - (from_, hinttext % { - 'name': from_._compiler_dispatch( - self, ashint=True) - }) - for (from_, dialect), hinttext in - select._hints.items() - if dialect in ('*', self.dialect.name) - ]) - hint_text = self.get_select_hint_text(byfrom) + hint_text, byfrom = self._setup_select_hints(select) if hint_text: text += hint_text + " " + else: + byfrom = None if select._prefixes: text += self._generate_prefixes( @@ -1572,6 +1533,70 @@ class SQLCompiler(Compiled): if c is not None ] + text = self._compose_select_body( + text, select, inner_columns, froms, byfrom, kwargs) + + if select._statement_hints: + per_dialect = [ + ht for (dialect_name, ht) + in select._statement_hints + if dialect_name in ('*', self.dialect.name) + ] + if per_dialect: + text += " " + self.get_statement_hint_text(per_dialect) + + if self.ctes and \ + compound_index == 0 and toplevel: + text = self._render_cte_clause() + text + + self.stack.pop(-1) + + if asfrom and parens: + return "(" + text + ")" + else: + return text + + def _setup_select_hints(self, select): + byfrom = dict([ + (from_, hinttext % { + 'name': from_._compiler_dispatch( + self, ashint=True) + }) + for (from_, dialect), hinttext in + select._hints.items() + if dialect in ('*', self.dialect.name) + ]) + hint_text = self.get_select_hint_text(byfrom) + return hint_text, byfrom + + def _setup_select_stack(self, select, entry, asfrom, iswrapper): + correlate_froms = entry['correlate_froms'] + asfrom_froms = entry['asfrom_froms'] + + if asfrom: + froms = select._get_display_froms( + explicit_correlate_froms=correlate_froms.difference( + asfrom_froms), + implicit_correlate_froms=()) + else: + froms = select._get_display_froms( + explicit_correlate_froms=correlate_froms, + implicit_correlate_froms=asfrom_froms) + + new_correlate_froms = set(selectable._from_objects(*froms)) + all_correlate_froms = new_correlate_froms.union(correlate_froms) + + new_entry = { + 'asfrom_froms': new_correlate_froms, + 
'iswrapper': iswrapper, + 'correlate_froms': all_correlate_froms, + 'selectable': select, + } + self.stack.append(new_entry) + return froms + + def _compose_select_body( + self, text, select, inner_columns, froms, byfrom, kwargs): text += ', '.join(inner_columns) if froms: @@ -1615,16 +1640,7 @@ class SQLCompiler(Compiled): if select._for_update_arg is not None: text += self.for_update_clause(select, **kwargs) - if self.ctes and \ - compound_index == 0 and toplevel: - text = self._render_cte_clause() + text - - self.stack.pop(-1) - - if asfrom and parens: - return "(" + text + ")" - else: - return text + return text def _generate_prefixes(self, stmt, prefixes, **kw): clause = " ".join( @@ -1714,9 +1730,9 @@ class SQLCompiler(Compiled): def visit_insert(self, insert_stmt, **kw): self.isinsert = True - colparams = self._get_colparams(insert_stmt, **kw) + crud_params = crud._get_crud_params(self, insert_stmt, **kw) - if not colparams and \ + if not crud_params and \ not self.dialect.supports_default_values and \ not self.dialect.supports_empty_insert: raise exc.CompileError("The '%s' dialect with current database " @@ -1731,9 +1747,9 @@ class SQLCompiler(Compiled): "version settings does not support " "in-place multirow inserts." 
% self.dialect.name) - colparams_single = colparams[0] + crud_params_single = crud_params[0] else: - colparams_single = colparams + crud_params_single = crud_params preparer = self.preparer supports_default_values = self.dialect.supports_default_values @@ -1764,9 +1780,9 @@ class SQLCompiler(Compiled): text += table_text - if colparams_single or not supports_default_values: + if crud_params_single or not supports_default_values: text += " (%s)" % ', '.join([preparer.format_column(c[0]) - for c in colparams_single]) + for c in crud_params_single]) if self.returning or insert_stmt._returning: self.returning = self.returning or insert_stmt._returning @@ -1778,20 +1794,20 @@ class SQLCompiler(Compiled): if insert_stmt.select is not None: text += " %s" % self.process(insert_stmt.select, **kw) - elif not colparams and supports_default_values: + elif not crud_params and supports_default_values: text += " DEFAULT VALUES" elif insert_stmt._has_multi_parameters: text += " VALUES %s" % ( ", ".join( "(%s)" % ( - ', '.join(c[1] for c in colparam_set) + ', '.join(c[1] for c in crud_param_set) ) - for colparam_set in colparams + for crud_param_set in crud_params ) ) else: text += " VALUES (%s)" % \ - ', '.join([c[1] for c in colparams]) + ', '.join([c[1] for c in crud_params]) if self.returning and not self.returning_precedes_values: text += " " + returning_clause @@ -1848,7 +1864,7 @@ class SQLCompiler(Compiled): table_text = self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) - colparams = self._get_colparams(update_stmt, **kw) + crud_params = crud._get_crud_params(self, update_stmt, **kw) if update_stmt._hints: dialect_hints = dict([ @@ -1875,7 +1891,7 @@ class SQLCompiler(Compiled): text += ', '.join( c[0]._compiler_dispatch(self, include_table=include_table) + - '=' + c[1] for c in colparams + '=' + c[1] for c in crud_params ) if self.returning or update_stmt._returning: @@ -1911,380 +1927,9 @@ class SQLCompiler(Compiled): return text - def 
_create_crud_bind_param(self, col, value, required=False, name=None): - if name is None: - name = col.key - bindparam = elements.BindParameter(name, value, - type_=col.type, required=required) - bindparam._is_crud = True - return bindparam._compiler_dispatch(self) - @util.memoized_property def _key_getters_for_crud_column(self): - if self.isupdate and self.statement._extra_froms: - # when extra tables are present, refer to the columns - # in those extra tables as table-qualified, including in - # dictionaries and when rendering bind param names. - # the "main" table of the statement remains unqualified, - # allowing the most compatibility with a non-multi-table - # statement. - _et = set(self.statement._extra_froms) - - def _column_as_key(key): - str_key = elements._column_as_key(key) - if hasattr(key, 'table') and key.table in _et: - return (key.table.name, str_key) - else: - return str_key - - def _getattr_col_key(col): - if col.table in _et: - return (col.table.name, col.key) - else: - return col.key - - def _col_bind_name(col): - if col.table in _et: - return "%s_%s" % (col.table.name, col.key) - else: - return col.key - - else: - _column_as_key = elements._column_as_key - _getattr_col_key = _col_bind_name = operator.attrgetter("key") - - return _column_as_key, _getattr_col_key, _col_bind_name - - def _get_colparams(self, stmt, **kw): - """create a set of tuples representing column/string pairs for use - in an INSERT or UPDATE statement. - - Also generates the Compiled object's postfetch, prefetch, and - returning column collections, used for default handling and ultimately - populating the ResultProxy's prefetch_cols() and postfetch_cols() - collections. 
- - """ - - self.postfetch = [] - self.prefetch = [] - self.returning = [] - - # no parameters in the statement, no parameters in the - # compiled params - return binds for all columns - if self.column_keys is None and stmt.parameters is None: - return [ - (c, self._create_crud_bind_param(c, - None, required=True)) - for c in stmt.table.columns - ] - - if stmt._has_multi_parameters: - stmt_parameters = stmt.parameters[0] - else: - stmt_parameters = stmt.parameters - - # getters - these are normally just column.key, - # but in the case of mysql multi-table update, the rules for - # .key must conditionally take tablename into account - _column_as_key, _getattr_col_key, _col_bind_name = \ - self._key_getters_for_crud_column - - # if we have statement parameters - set defaults in the - # compiled params - if self.column_keys is None: - parameters = {} - else: - parameters = dict((_column_as_key(key), REQUIRED) - for key in self.column_keys - if not stmt_parameters or - key not in stmt_parameters) - - # create a list of column assignment clauses as tuples - values = [] - - if stmt_parameters is not None: - for k, v in stmt_parameters.items(): - colkey = _column_as_key(k) - if colkey is not None: - parameters.setdefault(colkey, v) - else: - # a non-Column expression on the left side; - # add it to values() in an "as-is" state, - # coercing right side to bound param - if elements._is_literal(v): - v = self.process( - elements.BindParameter(None, v, type_=k.type), - **kw) - else: - v = self.process(v.self_group(), **kw) - - values.append((k, v)) - - need_pks = self.isinsert and \ - not self.inline and \ - not stmt._returning and \ - not stmt._has_multi_parameters - - implicit_returning = need_pks and \ - self.dialect.implicit_returning and \ - stmt.table.implicit_returning - - if self.isinsert: - implicit_return_defaults = (implicit_returning and - stmt._return_defaults) - elif self.isupdate: - implicit_return_defaults = (self.dialect.implicit_returning and - 
stmt.table.implicit_returning and - stmt._return_defaults) - else: - implicit_return_defaults = False - - if implicit_return_defaults: - if stmt._return_defaults is True: - implicit_return_defaults = set(stmt.table.c) - else: - implicit_return_defaults = set(stmt._return_defaults) - - postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid - - check_columns = {} - - # special logic that only occurs for multi-table UPDATE - # statements - if self.isupdate and stmt._extra_froms and stmt_parameters: - normalized_params = dict( - (elements._clause_element_as_expr(c), param) - for c, param in stmt_parameters.items() - ) - affected_tables = set() - for t in stmt._extra_froms: - for c in t.c: - if c in normalized_params: - affected_tables.add(t) - check_columns[_getattr_col_key(c)] = c - value = normalized_params[c] - if elements._is_literal(value): - value = self._create_crud_bind_param( - c, value, required=value is REQUIRED, - name=_col_bind_name(c)) - else: - self.postfetch.append(c) - value = self.process(value.self_group(), **kw) - values.append((c, value)) - # determine tables which are actually - # to be updated - process onupdate and - # server_onupdate for these - for t in affected_tables: - for c in t.c: - if c in normalized_params: - continue - elif (c.onupdate is not None and not - c.onupdate.is_sequence): - if c.onupdate.is_clause_element: - values.append( - (c, self.process( - c.onupdate.arg.self_group(), - **kw) - ) - ) - self.postfetch.append(c) - else: - values.append( - (c, self._create_crud_bind_param( - c, None, name=_col_bind_name(c) - ) - ) - ) - self.prefetch.append(c) - elif c.server_onupdate is not None: - self.postfetch.append(c) - - if self.isinsert and stmt.select_names: - # for an insert from select, we can only use names that - # are given, so only select for those names. 
- cols = (stmt.table.c[_column_as_key(name)] - for name in stmt.select_names) - else: - # iterate through all table columns to maintain - # ordering, even for those cols that aren't included - cols = stmt.table.columns - - for c in cols: - col_key = _getattr_col_key(c) - if col_key in parameters and col_key not in check_columns: - value = parameters.pop(col_key) - if elements._is_literal(value): - value = self._create_crud_bind_param( - c, value, required=value is REQUIRED, - name=_col_bind_name(c) - if not stmt._has_multi_parameters - else "%s_0" % _col_bind_name(c) - ) - else: - if isinstance(value, elements.BindParameter) and \ - value.type._isnull: - value = value._clone() - value.type = c.type - - if c.primary_key and implicit_returning: - self.returning.append(c) - value = self.process(value.self_group(), **kw) - elif implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - value = self.process(value.self_group(), **kw) - else: - self.postfetch.append(c) - value = self.process(value.self_group(), **kw) - values.append((c, value)) - - elif self.isinsert: - if c.primary_key and \ - need_pks and \ - ( - implicit_returning or - not postfetch_lastrowid or - c is not stmt.table._autoincrement_column - ): - - if implicit_returning: - if c.default is not None: - if c.default.is_sequence: - if self.dialect.supports_sequences and \ - (not c.default.optional or - not self.dialect.sequences_optional): - proc = self.process(c.default, **kw) - values.append((c, proc)) - self.returning.append(c) - elif c.default.is_clause_element: - values.append( - (c, self.process( - c.default.arg.self_group(), **kw)) - ) - self.returning.append(c) - else: - values.append( - (c, self._create_crud_bind_param(c, None)) - ) - self.prefetch.append(c) - else: - self.returning.append(c) - else: - if ( - (c.default is not None and - (not c.default.is_sequence or - self.dialect.supports_sequences)) or - c is stmt.table._autoincrement_column and - 
(self.dialect.supports_sequences or - self.dialect. - preexecute_autoincrement_sequences) - ): - - values.append( - (c, self._create_crud_bind_param(c, None)) - ) - - self.prefetch.append(c) - - elif c.default is not None: - if c.default.is_sequence: - if self.dialect.supports_sequences and \ - (not c.default.optional or - not self.dialect.sequences_optional): - proc = self.process(c.default, **kw) - values.append((c, proc)) - if implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - elif not c.primary_key: - self.postfetch.append(c) - elif c.default.is_clause_element: - values.append( - (c, self.process( - c.default.arg.self_group(), **kw)) - ) - - if implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - elif not c.primary_key: - # don't add primary key column to postfetch - self.postfetch.append(c) - else: - values.append( - (c, self._create_crud_bind_param(c, None)) - ) - self.prefetch.append(c) - elif c.server_default is not None: - if implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - elif not c.primary_key: - self.postfetch.append(c) - elif implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - - elif self.isupdate: - if c.onupdate is not None and not c.onupdate.is_sequence: - if c.onupdate.is_clause_element: - values.append( - (c, self.process( - c.onupdate.arg.self_group(), **kw)) - ) - if implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - else: - self.postfetch.append(c) - else: - values.append( - (c, self._create_crud_bind_param(c, None)) - ) - self.prefetch.append(c) - elif c.server_onupdate is not None: - if implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - else: - self.postfetch.append(c) - elif implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - - if parameters and 
stmt_parameters: - check = set(parameters).intersection( - _column_as_key(k) for k in stmt.parameters - ).difference(check_columns) - if check: - raise exc.CompileError( - "Unconsumed column names: %s" % - (", ".join("%s" % c for c in check)) - ) - - if stmt._has_multi_parameters: - values_0 = values - values = [values] - - values.extend( - [ - ( - c, - (self._create_crud_bind_param( - c, row[c.key], - name="%s_%d" % (c.key, i + 1) - ) if elements._is_literal(row[c.key]) - else self.process( - row[c.key].self_group(), **kw)) - if c.key in row else param - ) - for (c, param) in values_0 - ] - for i, row in enumerate(stmt.parameters[1:]) - ) - - return values + return crud._key_getters_for_crud_column(self) def visit_delete(self, delete_stmt, **kw): self.stack.append({'correlate_froms': set([delete_stmt.table]), @@ -2468,17 +2113,18 @@ class DDLCompiler(Compiled): constraints.extend([c for c in table._sorted_constraints if c is not table.primary_key]) - return ", \n\t".join(p for p in - (self.process(constraint) - for constraint in constraints - if ( - constraint._create_rule is None or - constraint._create_rule(self)) - and ( - not self.dialect.supports_alter or - not getattr(constraint, 'use_alter', False) - )) if p is not None - ) + return ", \n\t".join( + p for p in + (self.process(constraint) + for constraint in constraints + if ( + constraint._create_rule is None or + constraint._create_rule(self)) + and ( + not self.dialect.supports_alter or + not getattr(constraint, 'use_alter', False) + )) if p is not None + ) def visit_drop_table(self, drop): return "\nDROP TABLE " + self.preparer.format_table(drop.element) diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py new file mode 100644 index 000000000..1c1f661d2 --- /dev/null +++ b/lib/sqlalchemy/sql/crud.py @@ -0,0 +1,473 @@ +# sql/crud.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the 
MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Functions used by compiler.py to determine the parameters rendered +within INSERT and UPDATE statements. + +""" +from .. import util +from .. import exc +from . import elements +import operator + +REQUIRED = util.symbol('REQUIRED', """ +Placeholder for the value within a :class:`.BindParameter` +which is required to be present when the statement is passed +to :meth:`.Connection.execute`. + +This symbol is typically used when a :func:`.expression.insert` +or :func:`.expression.update` statement is compiled without parameter +values present. + +""") + + +def _get_crud_params(compiler, stmt, **kw): + """create a set of tuples representing column/string pairs for use + in an INSERT or UPDATE statement. + + Also generates the Compiled object's postfetch, prefetch, and + returning column collections, used for default handling and ultimately + populating the ResultProxy's prefetch_cols() and postfetch_cols() + collections. + + """ + + compiler.postfetch = [] + compiler.prefetch = [] + compiler.returning = [] + + # no parameters in the statement, no parameters in the + # compiled params - return binds for all columns + if compiler.column_keys is None and stmt.parameters is None: + return [ + (c, _create_bind_param( + compiler, c, None, required=True)) + for c in stmt.table.columns + ] + + if stmt._has_multi_parameters: + stmt_parameters = stmt.parameters[0] + else: + stmt_parameters = stmt.parameters + + # getters - these are normally just column.key, + # but in the case of mysql multi-table update, the rules for + # .key must conditionally take tablename into account + _column_as_key, _getattr_col_key, _col_bind_name = \ + _key_getters_for_crud_column(compiler) + + # if we have statement parameters - set defaults in the + # compiled params + if compiler.column_keys is None: + parameters = {} + else: + parameters = dict((_column_as_key(key), REQUIRED) + for key in compiler.column_keys + if not 
stmt_parameters or + key not in stmt_parameters) + + # create a list of column assignment clauses as tuples + values = [] + + if stmt_parameters is not None: + _get_stmt_parameters_params( + compiler, + parameters, stmt_parameters, _column_as_key, values, kw) + + check_columns = {} + + # special logic that only occurs for multi-table UPDATE + # statements + if compiler.isupdate and stmt._extra_froms and stmt_parameters: + _get_multitable_params( + compiler, stmt, stmt_parameters, check_columns, + _col_bind_name, _getattr_col_key, values, kw) + + if compiler.isinsert and stmt.select_names: + # for an insert from select, we can only use names that + # are given, so only select for those names. + cols = (stmt.table.c[_column_as_key(name)] + for name in stmt.select_names) + else: + # iterate through all table columns to maintain + # ordering, even for those cols that aren't included + cols = stmt.table.columns + + _scan_cols( + compiler, stmt, cols, parameters, + _getattr_col_key, _col_bind_name, check_columns, values, kw) + + if parameters and stmt_parameters: + check = set(parameters).intersection( + _column_as_key(k) for k in stmt.parameters + ).difference(check_columns) + if check: + raise exc.CompileError( + "Unconsumed column names: %s" % + (", ".join("%s" % c for c in check)) + ) + + if stmt._has_multi_parameters: + values = _extend_values_for_multiparams(compiler, stmt, values, kw) + + return values + + +def _create_bind_param(compiler, col, value, required=False, name=None): + if name is None: + name = col.key + bindparam = elements.BindParameter(name, value, + type_=col.type, required=required) + bindparam._is_crud = True + return bindparam._compiler_dispatch(compiler) + +def _key_getters_for_crud_column(compiler): + if compiler.isupdate and compiler.statement._extra_froms: + # when extra tables are present, refer to the columns + # in those extra tables as table-qualified, including in + # dictionaries and when rendering bind param names. 
+ # the "main" table of the statement remains unqualified, + # allowing the most compatibility with a non-multi-table + # statement. + _et = set(compiler.statement._extra_froms) + + def _column_as_key(key): + str_key = elements._column_as_key(key) + if hasattr(key, 'table') and key.table in _et: + return (key.table.name, str_key) + else: + return str_key + + def _getattr_col_key(col): + if col.table in _et: + return (col.table.name, col.key) + else: + return col.key + + def _col_bind_name(col): + if col.table in _et: + return "%s_%s" % (col.table.name, col.key) + else: + return col.key + + else: + _column_as_key = elements._column_as_key + _getattr_col_key = _col_bind_name = operator.attrgetter("key") + + return _column_as_key, _getattr_col_key, _col_bind_name + + +def _scan_cols( + compiler, stmt, cols, parameters, _getattr_col_key, + _col_bind_name, check_columns, values, kw): + + need_pks, implicit_returning, \ + implicit_return_defaults, postfetch_lastrowid = \ + _get_returning_modifiers(compiler, stmt) + + for c in cols: + col_key = _getattr_col_key(c) + if col_key in parameters and col_key not in check_columns: + + _append_param_parameter( + compiler, stmt, c, col_key, parameters, _col_bind_name, + implicit_returning, implicit_return_defaults, values, kw) + + elif compiler.isinsert: + if c.primary_key and \ + need_pks and \ + ( + implicit_returning or + not postfetch_lastrowid or + c is not stmt.table._autoincrement_column + ): + + if implicit_returning: + _append_param_insert_pk_returning( + compiler, stmt, c, values, kw) + else: + _append_param_insert_pk(compiler, stmt, c, values, kw) + + elif c.default is not None: + + _append_param_insert_hasdefault( + compiler, stmt, c, implicit_return_defaults, values, kw) + + elif c.server_default is not None: + if implicit_return_defaults and \ + c in implicit_return_defaults: + compiler.returning.append(c) + elif not c.primary_key: + compiler.postfetch.append(c) + elif implicit_return_defaults and \ + c in 
implicit_return_defaults: + compiler.returning.append(c) + + elif compiler.isupdate: + _append_param_update( + compiler, stmt, c, implicit_return_defaults, values, kw) + + +def _append_param_parameter( + compiler, stmt, c, col_key, parameters, _col_bind_name, + implicit_returning, implicit_return_defaults, values, kw): + value = parameters.pop(col_key) + if elements._is_literal(value): + value = _create_bind_param( + compiler, c, value, required=value is REQUIRED, + name=_col_bind_name(c) + if not stmt._has_multi_parameters + else "%s_0" % _col_bind_name(c) + ) + else: + if isinstance(value, elements.BindParameter) and \ + value.type._isnull: + value = value._clone() + value.type = c.type + + if c.primary_key and implicit_returning: + compiler.returning.append(c) + value = compiler.process(value.self_group(), **kw) + elif implicit_return_defaults and \ + c in implicit_return_defaults: + compiler.returning.append(c) + value = compiler.process(value.self_group(), **kw) + else: + compiler.postfetch.append(c) + value = compiler.process(value.self_group(), **kw) + values.append((c, value)) + + +def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): + if c.default is not None: + if c.default.is_sequence: + if compiler.dialect.supports_sequences and \ + (not c.default.optional or + not compiler.dialect.sequences_optional): + proc = compiler.process(c.default, **kw) + values.append((c, proc)) + compiler.returning.append(c) + elif c.default.is_clause_element: + values.append( + (c, compiler.process( + c.default.arg.self_group(), **kw)) + ) + compiler.returning.append(c) + else: + values.append( + (c, _create_bind_param(compiler, c, None)) + ) + compiler.prefetch.append(c) + else: + compiler.returning.append(c) + + +def _append_param_insert_pk(compiler, stmt, c, values, kw): + if ( + (c.default is not None and + (not c.default.is_sequence or + compiler.dialect.supports_sequences)) or + c is stmt.table._autoincrement_column and + 
(compiler.dialect.supports_sequences or + compiler.dialect. + preexecute_autoincrement_sequences) + ): + values.append( + (c, _create_bind_param(compiler, c, None)) + ) + + compiler.prefetch.append(c) + + +def _append_param_insert_hasdefault( + compiler, stmt, c, implicit_return_defaults, values, kw): + + if c.default.is_sequence: + if compiler.dialect.supports_sequences and \ + (not c.default.optional or + not compiler.dialect.sequences_optional): + proc = compiler.process(c.default, **kw) + values.append((c, proc)) + if implicit_return_defaults and \ + c in implicit_return_defaults: + compiler.returning.append(c) + elif not c.primary_key: + compiler.postfetch.append(c) + elif c.default.is_clause_element: + values.append( + (c, compiler.process( + c.default.arg.self_group(), **kw)) + ) + + if implicit_return_defaults and \ + c in implicit_return_defaults: + compiler.returning.append(c) + elif not c.primary_key: + # don't add primary key column to postfetch + compiler.postfetch.append(c) + else: + values.append( + (c, _create_bind_param(compiler, c, None)) + ) + compiler.prefetch.append(c) + + +def _append_param_update( + compiler, stmt, c, implicit_return_defaults, values, kw): + + if c.onupdate is not None and not c.onupdate.is_sequence: + if c.onupdate.is_clause_element: + values.append( + (c, compiler.process( + c.onupdate.arg.self_group(), **kw)) + ) + if implicit_return_defaults and \ + c in implicit_return_defaults: + compiler.returning.append(c) + else: + compiler.postfetch.append(c) + else: + values.append( + (c, _create_bind_param(compiler, c, None)) + ) + compiler.prefetch.append(c) + elif c.server_onupdate is not None: + if implicit_return_defaults and \ + c in implicit_return_defaults: + compiler.returning.append(c) + else: + compiler.postfetch.append(c) + elif implicit_return_defaults and \ + c in implicit_return_defaults: + compiler.returning.append(c) + + +def _get_multitable_params( + compiler, stmt, stmt_parameters, check_columns, + 
_col_bind_name, _getattr_col_key, values, kw): + + normalized_params = dict( + (elements._clause_element_as_expr(c), param) + for c, param in stmt_parameters.items() + ) + affected_tables = set() + for t in stmt._extra_froms: + for c in t.c: + if c in normalized_params: + affected_tables.add(t) + check_columns[_getattr_col_key(c)] = c + value = normalized_params[c] + if elements._is_literal(value): + value = _create_bind_param( + compiler, c, value, required=value is REQUIRED, + name=_col_bind_name(c)) + else: + compiler.postfetch.append(c) + value = compiler.process(value.self_group(), **kw) + values.append((c, value)) + # determine tables which are actually to be updated - process onupdate + # and server_onupdate for these + for t in affected_tables: + for c in t.c: + if c in normalized_params: + continue + elif (c.onupdate is not None and not + c.onupdate.is_sequence): + if c.onupdate.is_clause_element: + values.append( + (c, compiler.process( + c.onupdate.arg.self_group(), + **kw) + ) + ) + compiler.postfetch.append(c) + else: + values.append( + (c, _create_bind_param( + compiler, c, None, name=_col_bind_name(c) + ) + ) + ) + compiler.prefetch.append(c) + elif c.server_onupdate is not None: + compiler.postfetch.append(c) + + +def _extend_values_for_multiparams(compiler, stmt, values, kw): + values_0 = values + values = [values] + + values.extend( + [ + ( + c, + (_create_bind_param( + compiler, c, row[c.key], + name="%s_%d" % (c.key, i + 1) + ) if elements._is_literal(row[c.key]) + else compiler.process( + row[c.key].self_group(), **kw)) + if c.key in row else param + ) + for (c, param) in values_0 + ] + for i, row in enumerate(stmt.parameters[1:]) + ) + return values + + +def _get_stmt_parameters_params( + compiler, parameters, stmt_parameters, _column_as_key, values, kw): + for k, v in stmt_parameters.items(): + colkey = _column_as_key(k) + if colkey is not None: + parameters.setdefault(colkey, v) + else: + # a non-Column expression on the left side; + # add 
it to values() in an "as-is" state, + # coercing right side to bound param + if elements._is_literal(v): + v = compiler.process( + elements.BindParameter(None, v, type_=k.type), + **kw) + else: + v = compiler.process(v.self_group(), **kw) + + values.append((k, v)) + + +def _get_returning_modifiers(compiler, stmt): + need_pks = compiler.isinsert and \ + not compiler.inline and \ + not stmt._returning and \ + not stmt._has_multi_parameters + + implicit_returning = need_pks and \ + compiler.dialect.implicit_returning and \ + stmt.table.implicit_returning + + if compiler.isinsert: + implicit_return_defaults = (implicit_returning and + stmt._return_defaults) + elif compiler.isupdate: + implicit_return_defaults = (compiler.dialect.implicit_returning and + stmt.table.implicit_returning and + stmt._return_defaults) + else: + implicit_return_defaults = False + + if implicit_return_defaults: + if stmt._return_defaults is True: + implicit_return_defaults = set(stmt.table.c) + else: + implicit_return_defaults = set(stmt._return_defaults) + + postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid + + return need_pks, implicit_returning, \ + implicit_return_defaults, postfetch_lastrowid diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index d9fd37f92..26d7c428e 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -1222,8 +1222,10 @@ class Column(SchemaItem, ColumnClause): existing = getattr(self, 'table', None) if existing is not None and existing is not table: raise exc.ArgumentError( - "Column object already assigned to Table '%s'" % - existing.description) + "Column object '%s' already assigned to Table '%s'" % ( + self.key, + existing.description + )) if self.key in table._columns: col = table._columns.get(self.key) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 9e8cb3bc5..b4df87e54 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ 
-746,6 +746,33 @@ class Join(FromClause): providing a "natural join". """ + constraints = cls._joincond_scan_left_right( + a, a_subset, b, consider_as_foreign_keys) + + if len(constraints) > 1: + cls._joincond_trim_constraints( + a, b, constraints, consider_as_foreign_keys) + + if len(constraints) == 0: + if isinstance(b, FromGrouping): + hint = " Perhaps you meant to convert the right side to a "\ + "subquery using alias()?" + else: + hint = "" + raise exc.NoForeignKeysError( + "Can't find any foreign key relationships " + "between '%s' and '%s'.%s" % + (a.description, b.description, hint)) + + crit = [(x == y) for x, y in list(constraints.values())[0]] + if len(crit) == 1: + return (crit[0]) + else: + return and_(*crit) + + @classmethod + def _joincond_scan_left_right( + cls, a, a_subset, b, consider_as_foreign_keys): constraints = collections.defaultdict(list) for left in (a_subset, a): @@ -780,57 +807,41 @@ class Join(FromClause): if nrte.table_name == b.name: raise else: - # this is totally covered. can't get - # coverage to mark it. continue if col is not None: constraints[fk.constraint].append((col, fk.parent)) if constraints: break + return constraints + @classmethod + def _joincond_trim_constraints( + cls, a, b, constraints, consider_as_foreign_keys): + # more than one constraint matched. narrow down the list + # to include just those FKCs that match exactly to + # "consider_as_foreign_keys". + if consider_as_foreign_keys: + for const in list(constraints): + if set(f.parent for f in const.elements) != set( + consider_as_foreign_keys): + del constraints[const] + + # if still multiple constraints, but + # they all refer to the exact same end result, use it. if len(constraints) > 1: - # more than one constraint matched. narrow down the list - # to include just those FKCs that match exactly to - # "consider_as_foreign_keys". 
- if consider_as_foreign_keys: - for const in list(constraints): - if set(f.parent for f in const.elements) != set( - consider_as_foreign_keys): - del constraints[const] - - # if still multiple constraints, but - # they all refer to the exact same end result, use it. - if len(constraints) > 1: - dedupe = set(tuple(crit) for crit in constraints.values()) - if len(dedupe) == 1: - key = list(constraints)[0] - constraints = {key: constraints[key]} - - if len(constraints) != 1: - raise exc.AmbiguousForeignKeysError( - "Can't determine join between '%s' and '%s'; " - "tables have more than one foreign key " - "constraint relationship between them. " - "Please specify the 'onclause' of this " - "join explicitly." % (a.description, b.description)) - - if len(constraints) == 0: - if isinstance(b, FromGrouping): - hint = " Perhaps you meant to convert the right side to a "\ - "subquery using alias()?" - else: - hint = "" - raise exc.NoForeignKeysError( - "Can't find any foreign key relationships " - "between '%s' and '%s'.%s" % - (a.description, b.description, hint)) - - crit = [(x == y) for x, y in list(constraints.values())[0]] - if len(crit) == 1: - return (crit[0]) - else: - return and_(*crit) + dedupe = set(tuple(crit) for crit in constraints.values()) + if len(dedupe) == 1: + key = list(constraints)[0] + constraints = {key: constraints[key]} + + if len(constraints) != 1: + raise exc.AmbiguousForeignKeysError( + "Can't determine join between '%s' and '%s'; " + "tables have more than one foreign key " + "constraint relationship between them. " + "Please specify the 'onclause' of this " + "join explicitly." % (a.description, b.description)) def select(self, whereclause=None, **kwargs): """Create a :class:`.Select` from this :class:`.Join`. 
@@ -2153,6 +2164,7 @@ class Select(HasPrefixes, GenerativeSelect): _prefixes = () _hints = util.immutabledict() + _statement_hints = () _distinct = False _from_cloned = None _correlate = () @@ -2525,10 +2537,30 @@ class Select(HasPrefixes, GenerativeSelect): return self._get_display_froms() + def with_statement_hint(self, text, dialect_name='*'): + """add a statement hint to this :class:`.Select`. + + This method is similar to :meth:`.Select.with_hint` except that + it does not require an individual table, and instead applies to the + statement as a whole. + + Hints here are specific to the backend database and may include + directives such as isolation levels, file directives, fetch directives, + etc. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :meth:`.Select.with_hint` + + """ + return self.with_hint(None, text, dialect_name) + @_generative def with_hint(self, selectable, text, dialect_name='*'): - """Add an indexing hint for the given selectable to this - :class:`.Select`. + """Add an indexing or other executional context hint for the given + selectable to this :class:`.Select`. The text of the hint is rendered in the appropriate location for the database backend in use, relative @@ -2555,9 +2587,16 @@ class Select(HasPrefixes, GenerativeSelect): mytable, "+ index(%(name)s ix_mytable)", 'oracle').\\ with_hint(mytable, "WITH INDEX ix_mytable", 'sybase') + .. 
seealso:: + + :meth:`.Select.with_statement_hint` + """ - self._hints = self._hints.union( - {(selectable, dialect_name): text}) + if selectable is None: + self._statement_hints += ((dialect_name, text), ) + else: + self._hints = self._hints.union( + {(selectable, dialect_name): text}) @property def type(self): diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 283d89e36..49211f805 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -133,7 +133,7 @@ class compound(object): name, fail._as_string(config), str(ex)))) break else: - raise ex + util.raise_from_cause(ex) def _expect_success(self, config, name='block'): if not self.fails: diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py new file mode 100644 index 000000000..497fcb7e5 --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/bootstrap.py @@ -0,0 +1,44 @@ +""" +Bootstrapper for nose/pytest plugins. + +The entire rationale for this system is to get the modules in plugin/ +imported without importing all of the supporting library, so that we can +set up things for testing before coverage starts. + +The rationale for all of plugin/ being *in* the supporting library in the +first place is so that the testing and plugin suite is available to other +libraries, mainly external SQLAlchemy and Alembic dialects, to make use +of the same test environment and standard suites available to +SQLAlchemy/Alembic themselves without the need to ship/install a separate +package outside of SQLAlchemy. + +NOTE: copied/adapted from SQLAlchemy master for backwards compatibility; +this should be removable when Alembic targets SQLAlchemy 1.0.0. 
+ +""" + +import os +import sys + +bootstrap_file = locals()['bootstrap_file'] +to_bootstrap = locals()['to_bootstrap'] + + +def load_file_as_module(name): + path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name) + if sys.version_info >= (3, 3): + from importlib import machinery + mod = machinery.SourceFileLoader(name, path).load_module() + else: + import imp + mod = imp.load_source(name, path) + return mod + +if to_bootstrap == "pytest": + sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base") + sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin") +elif to_bootstrap == "nose": + sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base") + sys.modules["sqla_noseplugin"] = load_file_as_module("noseplugin") +else: + raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa diff --git a/lib/sqlalchemy/testing/plugin/noseplugin.py b/lib/sqlalchemy/testing/plugin/noseplugin.py index 6ef539142..538087770 100644 --- a/lib/sqlalchemy/testing/plugin/noseplugin.py +++ b/lib/sqlalchemy/testing/plugin/noseplugin.py @@ -12,6 +12,14 @@ way (e.g. as a package-less import). """ +try: + # installed by bootstrap.py + import sqla_plugin_base as plugin_base +except ImportError: + # assume we're a package, use traditional import + from . import plugin_base + + import os import sys @@ -19,16 +27,6 @@ from nose.plugins import Plugin fixtures = None py3k = sys.version_info >= (3, 0) -# no package imports yet! this prevents us from tripping coverage -# too soon. 
-path = os.path.join(os.path.dirname(__file__), "plugin_base.py") -if sys.version_info >= (3, 3): - from importlib import machinery - plugin_base = machinery.SourceFileLoader( - "plugin_base", path).load_module() -else: - import imp - plugin_base = imp.load_source("plugin_base", path) class NoseSQLAlchemy(Plugin): @@ -58,10 +56,10 @@ class NoseSQLAlchemy(Plugin): plugin_base.set_coverage_flag(options.enable_plugin_coverage) + def begin(self): global fixtures - from sqlalchemy.testing import fixtures + from sqlalchemy.testing import fixtures # noqa - def begin(self): plugin_base.post_begin() def describeTest(self, test): @@ -72,19 +70,23 @@ class NoseSQLAlchemy(Plugin): def wantMethod(self, fn): if py3k: + if not hasattr(fn.__self__, 'cls'): + return False cls = fn.__self__.cls else: cls = fn.im_class - print "METH:", fn, "CLS:", cls return plugin_base.want_method(cls, fn) def wantClass(self, cls): return plugin_base.want_class(cls) def beforeTest(self, test): - plugin_base.before_test(test, - test.test.cls.__module__, - test.test.cls, test.test.method.__name__) + if not hasattr(test.test, 'cls'): + return + plugin_base.before_test( + test, + test.test.cls.__module__, + test.test.cls, test.test.method.__name__) def afterTest(self, test): plugin_base.after_test(test) diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 7ba31d3e3..6696427dc 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -31,8 +31,6 @@ if py3k: else: import ConfigParser as configparser -FOLLOWER_IDENT = None - # late imports fixtures = None engines = None @@ -72,8 +70,6 @@ def setup_options(make_option): help="Drop all tables in the target database first") make_option("--backend-only", action="store_true", dest="backend_only", help="Run only tests marked with __backend__") - make_option("--mockpool", action="store_true", dest="mockpool", - help="Use mock pool (asserts only one 
connection used)") make_option("--low-connections", action="store_true", dest="low_connections", help="Use a low number of distinct connections - " @@ -95,14 +91,6 @@ def setup_options(make_option): make_option("--exclude-tag", action="callback", callback=_exclude_tag, type="string", help="Exclude tests with tag <tag>") - make_option("--serverside", action="store_true", - help="Turn on server side cursors for PG") - make_option("--mysql-engine", action="store", - dest="mysql_engine", default=None, - help="Use the specified MySQL storage engine for all tables, " - "default is a db-default/InnoDB combo.") - make_option("--tableopts", action="append", dest="tableopts", default=[], - help="Add a dialect-specific table option, key=value") make_option("--write-profiles", action="store_true", dest="write_profiles", default=False, help="Write/update profiling data.") @@ -115,8 +103,8 @@ def configure_follower(follower_ident): database creation. """ - global FOLLOWER_IDENT - FOLLOWER_IDENT = follower_ident + from sqlalchemy.testing import provision + provision.FOLLOWER_IDENT = follower_ident def memoize_important_follower_config(dict_): @@ -177,12 +165,14 @@ def post_begin(): global util, fixtures, engines, exclusions, \ assertions, warnings, profiling,\ config, testing - from sqlalchemy import testing - from sqlalchemy.testing import fixtures, engines, exclusions, \ - assertions, warnings, profiling, config - from sqlalchemy import util + from sqlalchemy import testing # noqa + from sqlalchemy.testing import fixtures, engines, exclusions # noqa + from sqlalchemy.testing import assertions, warnings, profiling # noqa + from sqlalchemy.testing import config # noqa + from sqlalchemy import util # noqa warnings.setup_filters() + def _log(opt_str, value, parser): global logging if not logging: @@ -234,12 +224,6 @@ def _setup_options(opt, file_config): @pre -def _server_side_cursors(options, file_config): - if options.serverside: - db_opts['server_side_cursors'] = True - - -@pre 
def _monkeypatch_cdecimal(options, file_config): if options.cdecimal: import cdecimal @@ -250,7 +234,7 @@ def _monkeypatch_cdecimal(options, file_config): def _engine_uri(options, file_config): from sqlalchemy.testing import config from sqlalchemy import testing - from sqlalchemy.testing.plugin import provision + from sqlalchemy.testing import provision if options.dburi: db_urls = list(options.dburi) @@ -273,20 +257,13 @@ def _engine_uri(options, file_config): for db_url in db_urls: cfg = provision.setup_config( - db_url, db_opts, options, file_config, FOLLOWER_IDENT) + db_url, db_opts, options, file_config, provision.FOLLOWER_IDENT) if not config._current: cfg.set_as_current(cfg, testing) @post -def _engine_pool(options, file_config): - if options.mockpool: - from sqlalchemy import pool - db_opts['poolclass'] = pool.AssertionPool - - -@post def _requirements(options, file_config): requirement_cls = file_config.get('sqla_testing', "requirement_cls") @@ -369,19 +346,6 @@ def _prep_testing_database(options, file_config): @post -def _set_table_options(options, file_config): - from sqlalchemy.testing import schema - - table_options = schema.table_options - for spec in options.tableopts: - key, value = spec.split('=') - table_options[key] = value - - if options.mysql_engine: - table_options['mysql_engine'] = options.mysql_engine - - -@post def _reverse_topological(options, file_config): if options.reversetop: from sqlalchemy.orm.util import randomize_unitofwork diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 005942913..4bbc8ed9a 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -1,7 +1,13 @@ +try: + # installed by bootstrap.py + import sqla_plugin_base as plugin_base +except ImportError: + # assume we're a package, use traditional import + from . import plugin_base + import pytest import argparse import inspect -from . 
import plugin_base import collections import itertools @@ -42,6 +48,8 @@ def pytest_configure(config): plugin_base.set_coverage_flag(bool(getattr(config.option, "cov_source", False))) + +def pytest_sessionstart(session): plugin_base.post_begin() if has_xdist: @@ -54,11 +62,11 @@ if has_xdist: plugin_base.memoize_important_follower_config(node.slaveinput) node.slaveinput["follower_ident"] = "test_%s" % next(_follower_count) - from . import provision + from sqlalchemy.testing import provision provision.create_follower_db(node.slaveinput["follower_ident"]) def pytest_testnodedown(node, error): - from . import provision + from sqlalchemy.testing import provision provision.drop_follower_db(node.slaveinput["follower_ident"]) diff --git a/lib/sqlalchemy/testing/plugin/provision.py b/lib/sqlalchemy/testing/provision.py index c6b9030f5..0bcdad959 100644 --- a/lib/sqlalchemy/testing/plugin/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -1,8 +1,10 @@ from sqlalchemy.engine import url as sa_url from sqlalchemy import text from sqlalchemy.util import compat -from .. import config, engines -import os +from . 
import config, engines + + +FOLLOWER_IDENT = None class register(object): diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index a04bcbbdd..da3e3128a 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -314,6 +314,20 @@ class SuiteRequirements(Requirements): return exclusions.open() @property + def temp_table_reflection(self): + return exclusions.open() + + @property + def temp_table_names(self): + """target dialect supports listing of temporary table names""" + return exclusions.closed() + + @property + def temporary_views(self): + """target database supports temporary views""" + return exclusions.closed() + + @property def index_reflection(self): return exclusions.open() diff --git a/lib/sqlalchemy/testing/runner.py b/lib/sqlalchemy/testing/runner.py index df254520b..23d7a0a91 100644 --- a/lib/sqlalchemy/testing/runner.py +++ b/lib/sqlalchemy/testing/runner.py @@ -30,7 +30,7 @@ SQLAlchemy itself is possible. 
""" -from sqlalchemy.testing.plugin.noseplugin import NoseSQLAlchemy +from .plugin.noseplugin import NoseSQLAlchemy import nose diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 575a38db9..60db9eb47 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -95,6 +95,39 @@ class ComponentReflectionTest(fixtures.TablesTest): cls.define_index(metadata, users) if testing.requires.view_column_reflection.enabled: cls.define_views(metadata, schema) + if not schema and testing.requires.temp_table_reflection.enabled: + cls.define_temp_tables(metadata) + + @classmethod + def define_temp_tables(cls, metadata): + # cheat a bit, we should fix this with some dialect-level + # temp table fixture + if testing.against("oracle"): + kw = { + 'prefixes': ["GLOBAL TEMPORARY"], + 'oracle_on_commit': 'PRESERVE ROWS' + } + else: + kw = { + 'prefixes': ["TEMPORARY"], + } + + user_tmp = Table( + "user_tmp", metadata, + Column("id", sa.INT, primary_key=True), + Column('name', sa.VARCHAR(50)), + Column('foo', sa.INT), + sa.UniqueConstraint('name', name='user_tmp_uq'), + sa.Index("user_tmp_ix", "foo"), + **kw + ) + if testing.requires.view_reflection.enabled and \ + testing.requires.temporary_views.enabled: + event.listen( + user_tmp, "after_create", + DDL("create temporary view user_tmp_v as " + "select * from user_tmp") + ) @classmethod def define_index(cls, metadata, users): @@ -147,6 +180,7 @@ class ComponentReflectionTest(fixtures.TablesTest): users, addresses, dingalings = self.tables.users, \ self.tables.email_addresses, self.tables.dingalings insp = inspect(meta.bind) + if table_type == 'view': table_names = insp.get_view_names(schema) table_names.sort() @@ -162,6 +196,20 @@ class ComponentReflectionTest(fixtures.TablesTest): answer = ['dingalings', 'email_addresses', 'users'] eq_(sorted(table_names), answer) + @testing.requires.temp_table_names + def 
test_get_temp_table_names(self): + insp = inspect(testing.db) + temp_table_names = insp.get_temp_table_names() + eq_(sorted(temp_table_names), ['user_tmp']) + + @testing.requires.view_reflection + @testing.requires.temp_table_names + @testing.requires.temporary_views + def test_get_temp_view_names(self): + insp = inspect(self.metadata.bind) + temp_table_names = insp.get_temp_view_names() + eq_(sorted(temp_table_names), ['user_tmp_v']) + @testing.requires.table_reflection def test_get_table_names(self): self._test_get_table_names() @@ -294,6 +342,28 @@ class ComponentReflectionTest(fixtures.TablesTest): def test_get_columns_with_schema(self): self._test_get_columns(schema=testing.config.test_schema) + @testing.requires.temp_table_reflection + def test_get_temp_table_columns(self): + meta = MetaData(testing.db) + user_tmp = self.tables.user_tmp + insp = inspect(meta.bind) + cols = insp.get_columns('user_tmp') + self.assert_(len(cols) > 0, len(cols)) + + for i, col in enumerate(user_tmp.columns): + eq_(col.name, cols[i]['name']) + + @testing.requires.temp_table_reflection + @testing.requires.view_column_reflection + @testing.requires.temporary_views + def test_get_temp_view_columns(self): + insp = inspect(self.metadata.bind) + cols = insp.get_columns('user_tmp_v') + eq_( + [col['name'] for col in cols], + ['id', 'name', 'foo'] + ) + @testing.requires.view_column_reflection def test_get_view_columns(self): self._test_get_columns(table_type='view') @@ -426,6 +496,26 @@ class ComponentReflectionTest(fixtures.TablesTest): def test_get_unique_constraints(self): self._test_get_unique_constraints() + @testing.requires.temp_table_reflection + @testing.requires.unique_constraint_reflection + def test_get_temp_table_unique_constraints(self): + insp = inspect(self.metadata.bind) + eq_( + insp.get_unique_constraints('user_tmp'), + [{'column_names': ['name'], 'name': 'user_tmp_uq'}] + ) + + @testing.requires.temp_table_reflection + def test_get_temp_table_indexes(self): + insp = 
inspect(self.metadata.bind) + indexes = insp.get_indexes('user_tmp') + eq_( + # TODO: we need to add better filtering for indexes/uq constraints + # that are doubled up + [idx for idx in indexes if idx['name'] == 'user_tmp_ix'], + [{'unique': False, 'column_names': ['foo'], 'name': 'user_tmp_ix'}] + ) + @testing.requires.unique_constraint_reflection @testing.requires.schemas def test_get_unique_constraints_with_schema(self): diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index c963b18c3..dfed5b90a 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -33,7 +33,8 @@ from .langhelpers import iterate_attributes, class_hierarchy, \ duck_type_collection, assert_arg_type, symbol, dictlike_iteritems,\ classproperty, set_creation_order, warn_exception, warn, NoneType,\ constructor_copy, methods_equivalent, chop_traceback, asint,\ - generic_repr, counter, PluginLoader, hybridmethod, safe_reraise,\ + generic_repr, counter, PluginLoader, hybridproperty, hybridmethod, \ + safe_reraise,\ get_callable_argspec, only_once, attrsetter, ellipses_string, \ warn_limited diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 76f85f605..95369783d 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1090,10 +1090,23 @@ class classproperty(property): return desc.fget(cls) +class hybridproperty(object): + def __init__(self, func): + self.func = func + + def __get__(self, instance, owner): + if instance is None: + clsval = self.func(owner) + clsval.__doc__ = self.func.__doc__ + return clsval + else: + return self.func(instance) + + class hybridmethod(object): """Decorate a function as cls- or instance- level.""" - def __init__(self, func, expr=None): + def __init__(self, func): self.func = func def __get__(self, instance, owner): @@ -26,6 +26,13 @@ profile_file=test/profiles.txt # create database link test_link connect to scott identified by tiger using 
'xe'; oracle_db_link = test_link +# host name of a postgres database that has the postgres_fdw extension. +# to create this run: +# CREATE EXTENSION postgres_fdw; +# GRANT USAGE ON FOREIGN DATA WRAPPER postgres_fdw TO public; +# this can be localhost to create a loopback foreign table +# postgres_test_db_link = localhost + [db] default=sqlite:///:memory: diff --git a/sqla_nose.py b/sqla_nose.py index f89a1dce0..b977f4bf5 100755 --- a/sqla_nose.py +++ b/sqla_nose.py @@ -8,22 +8,25 @@ installs SQLAlchemy's testing plugin into the local environment. """ import sys import nose -import warnings +import os -from os import path for pth in ['./lib']: - sys.path.insert(0, path.join(path.dirname(path.abspath(__file__)), pth)) + sys.path.insert( + 0, os.path.join(os.path.dirname(os.path.abspath(__file__)), pth)) -# installing without importing SQLAlchemy, so that coverage includes -# SQLAlchemy itself. -path = "lib/sqlalchemy/testing/plugin/noseplugin.py" -if sys.version_info >= (3, 3): - from importlib import machinery - noseplugin = machinery.SourceFileLoader("noseplugin", path).load_module() -else: - import imp - noseplugin = imp.load_source("noseplugin", path) +# use bootstrapping so that test plugins are loaded +# without touching the main library before coverage starts +bootstrap_file = os.path.join( + os.path.dirname(__file__), "lib", "sqlalchemy", + "testing", "plugin", "bootstrap.py" +) +with open(bootstrap_file) as f: + code = compile(f.read(), "bootstrap.py", 'exec') + to_bootstrap = "nose" + exec(code, globals(), locals()) -nose.main(addplugins=[noseplugin.NoseSQLAlchemy()]) + +from noseplugin import NoseSQLAlchemy +nose.main(addplugins=[NoseSQLAlchemy()]) diff --git a/test/base/test_events.py b/test/base/test_events.py index 30b728cd3..89379961e 100644 --- a/test/base/test_events.py +++ b/test/base/test_events.py @@ -192,7 +192,7 @@ class EventsTest(fixtures.TestBase): class NamedCallTest(fixtures.TestBase): - def setUp(self): + def _fixture(self): class 
TargetEventsOne(event.Events): def event_one(self, x, y): pass @@ -205,48 +205,104 @@ class NamedCallTest(fixtures.TestBase): class TargetOne(object): dispatch = event.dispatcher(TargetEventsOne) - self.TargetOne = TargetOne + return TargetOne - def tearDown(self): - event.base._remove_dispatcher( - self.TargetOne.__dict__['dispatch'].events) + def _wrapped_fixture(self): + class TargetEvents(event.Events): + @classmethod + def _listen(cls, event_key): + fn = event_key._listen_fn + + def adapt(*args): + fn(*["adapted %s" % arg for arg in args]) + event_key = event_key.with_wrapper(adapt) + + event_key.base_listen() + + def event_one(self, x, y): + pass + + def event_five(self, x, y, z, q): + pass + + class Target(object): + dispatch = event.dispatcher(TargetEvents) + return Target def test_kw_accept(self): + TargetOne = self._fixture() + canary = Mock() - @event.listens_for(self.TargetOne, "event_one", named=True) + @event.listens_for(TargetOne, "event_one", named=True) def handler1(**kw): canary(kw) - self.TargetOne().dispatch.event_one(4, 5) + TargetOne().dispatch.event_one(4, 5) eq_( canary.mock_calls, [call({"x": 4, "y": 5})] ) + def test_kw_accept_wrapped(self): + TargetOne = self._wrapped_fixture() + + canary = Mock() + + @event.listens_for(TargetOne, "event_one", named=True) + def handler1(**kw): + canary(kw) + + TargetOne().dispatch.event_one(4, 5) + + eq_( + canary.mock_calls, + [call({'y': 'adapted 5', 'x': 'adapted 4'})] + ) + def test_partial_kw_accept(self): + TargetOne = self._fixture() + canary = Mock() - @event.listens_for(self.TargetOne, "event_five", named=True) + @event.listens_for(TargetOne, "event_five", named=True) def handler1(z, y, **kw): canary(z, y, kw) - self.TargetOne().dispatch.event_five(4, 5, 6, 7) + TargetOne().dispatch.event_five(4, 5, 6, 7) eq_( canary.mock_calls, [call(6, 5, {"x": 4, "q": 7})] ) + def test_partial_kw_accept_wrapped(self): + TargetOne = self._wrapped_fixture() + + canary = Mock() + + @event.listens_for(TargetOne, 
"event_five", named=True) + def handler1(z, y, **kw): + canary(z, y, kw) + + TargetOne().dispatch.event_five(4, 5, 6, 7) + + eq_( + canary.mock_calls, + [call('adapted 6', 'adapted 5', + {'q': 'adapted 7', 'x': 'adapted 4'})] + ) + def test_kw_accept_plus_kw(self): + TargetOne = self._fixture() canary = Mock() - @event.listens_for(self.TargetOne, "event_two", named=True) + @event.listens_for(TargetOne, "event_two", named=True) def handler1(**kw): canary(kw) - self.TargetOne().dispatch.event_two(4, 5, z=8, q=5) + TargetOne().dispatch.event_two(4, 5, z=8, q=5) eq_( canary.mock_calls, @@ -996,6 +1052,25 @@ class RemovalTest(fixtures.TestBase): dispatch = event.dispatcher(TargetEvents) return Target + def _wrapped_fixture(self): + class TargetEvents(event.Events): + @classmethod + def _listen(cls, event_key): + fn = event_key._listen_fn + + def adapt(value): + fn("adapted " + value) + event_key = event_key.with_wrapper(adapt) + + event_key.base_listen() + + def event_one(self, x): + pass + + class Target(object): + dispatch = event.dispatcher(TargetEvents) + return Target + def test_clslevel(self): Target = self._fixture() @@ -1194,3 +1269,71 @@ class RemovalTest(fixtures.TestBase): "deque mutated during iteration", t1.dispatch.event_one ) + + def test_remove_plain_named(self): + Target = self._fixture() + + listen_one = Mock() + t1 = Target() + event.listen(t1, "event_one", listen_one, named=True) + t1.dispatch.event_one("t1") + + eq_(listen_one.mock_calls, [call(x="t1")]) + event.remove(t1, "event_one", listen_one) + t1.dispatch.event_one("t2") + + eq_(listen_one.mock_calls, [call(x="t1")]) + + def test_remove_wrapped_named(self): + Target = self._wrapped_fixture() + + listen_one = Mock() + t1 = Target() + event.listen(t1, "event_one", listen_one, named=True) + t1.dispatch.event_one("t1") + + eq_(listen_one.mock_calls, [call(x="adapted t1")]) + event.remove(t1, "event_one", listen_one) + t1.dispatch.event_one("t2") + + eq_(listen_one.mock_calls, [call(x="adapted 
t1")]) + + def test_double_event_nonwrapped(self): + Target = self._fixture() + + listen_one = Mock() + t1 = Target() + event.listen(t1, "event_one", listen_one) + event.listen(t1, "event_one", listen_one) + + t1.dispatch.event_one("t1") + + # doubles are eliminated + eq_(listen_one.mock_calls, [call("t1")]) + + # only one remove needed + event.remove(t1, "event_one", listen_one) + t1.dispatch.event_one("t2") + + eq_(listen_one.mock_calls, [call("t1")]) + + def test_double_event_wrapped(self): + # this is issue #3199 + Target = self._wrapped_fixture() + + listen_one = Mock() + t1 = Target() + + event.listen(t1, "event_one", listen_one) + event.listen(t1, "event_one", listen_one) + + t1.dispatch.event_one("t1") + + # doubles are eliminated + eq_(listen_one.mock_calls, [call("adapted t1")]) + + # only one remove needed + event.remove(t1, "event_one", listen_one) + t1.dispatch.event_one("t2") + + eq_(listen_one.mock_calls, [call("adapted t1")]) diff --git a/test/conftest.py b/test/conftest.py index 1dd442309..c697085ee 100755 --- a/test/conftest.py +++ b/test/conftest.py @@ -7,9 +7,23 @@ installs SQLAlchemy's testing plugin into the local environment. 
""" import sys +import os -from os import path for pth in ['../lib']: - sys.path.insert(0, path.join(path.dirname(path.abspath(__file__)), pth)) + sys.path.insert( + 0, + os.path.join(os.path.dirname(os.path.abspath(__file__)), pth)) -from sqlalchemy.testing.plugin.pytestplugin import * + +# use bootstrapping so that test plugins are loaded +# without touching the main library before coverage starts +bootstrap_file = os.path.join( + os.path.dirname(__file__), "..", "lib", "sqlalchemy", + "testing", "plugin", "bootstrap.py" +) + +with open(bootstrap_file) as f: + code = compile(f.read(), "bootstrap.py", 'exec') + to_bootstrap = "pytest" + exec(code, globals(), locals()) + from pytestplugin import * # noqa diff --git a/test/dialect/mssql/test_engine.py b/test/dialect/mssql/test_engine.py index 8ac9c6c16..4b4780d43 100644 --- a/test/dialect/mssql/test_engine.py +++ b/test/dialect/mssql/test_engine.py @@ -7,6 +7,8 @@ from sqlalchemy.engine import url from sqlalchemy.testing import fixtures from sqlalchemy import testing from sqlalchemy.testing import assert_raises_message, assert_warnings +from sqlalchemy.testing.mock import Mock + class ParseConnectTest(fixtures.TestBase): @@ -167,3 +169,21 @@ class ParseConnectTest(fixtures.TestBase): assert_raises_message(exc.SAWarning, 'Unrecognized server version info', engine.connect) + + +class VersionDetectionTest(fixtures.TestBase): + def test_pymssql_version(self): + dialect = pymssql.MSDialect_pymssql() + + for vers in [ + "Microsoft SQL Server Blah - 11.0.9216.62", + "Microsoft SQL Server (XYZ) - 11.0.9216.62 \n" + "Jul 18 2014 22:00:21 \nCopyright (c) Microsoft Corporation", + "Microsoft SQL Azure (RTM) - 11.0.9216.62 \n" + "Jul 18 2014 22:00:21 \nCopyright (c) Microsoft Corporation" + ]: + conn = Mock(scalar=Mock(return_value=vers)) + eq_( + dialect._get_server_version_info(conn), + (11, 0, 9216, 62) + )
\ No newline at end of file diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index bab41b0f7..b8b9be3de 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -13,8 +13,123 @@ import sqlalchemy as sa from sqlalchemy.dialects.postgresql import base as postgresql -class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): +class ForeignTableReflectionTest(fixtures.TablesTest, AssertsExecutionResults): + """Test reflection on foreign tables""" + + __requires__ = 'postgresql_test_dblink', + __only_on__ = 'postgresql >= 9.3' + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + from sqlalchemy.testing import config + dblink = config.file_config.get( + 'sqla_testing', 'postgres_test_db_link') + + testtable = Table( + 'testtable', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30))) + + for ddl in [ + "CREATE SERVER test_server FOREIGN DATA WRAPPER postgres_fdw " + "OPTIONS (dbname 'test', host '%s')" % dblink, + "CREATE USER MAPPING FOR public \ + SERVER test_server options (user 'scott', password 'tiger')", + "CREATE FOREIGN TABLE test_foreigntable ( " + " id INT, " + " data VARCHAR(30) " + ") SERVER test_server OPTIONS (table_name 'testtable')", + ]: + sa.event.listen(metadata, "after_create", sa.DDL(ddl)) + + for ddl in [ + 'DROP FOREIGN TABLE test_foreigntable', + 'DROP USER MAPPING FOR public SERVER test_server', + "DROP SERVER test_server" + ]: + sa.event.listen(metadata, "before_drop", sa.DDL(ddl)) + + def test_foreign_table_is_reflected(self): + metadata = MetaData(testing.db) + table = Table('test_foreigntable', metadata, autoload=True) + eq_(set(table.columns.keys()), set(['id', 'data']), + "Columns of reflected foreign table didn't equal expected columns") + def test_get_foreign_table_names(self): + inspector = inspect(testing.db) + with testing.db.connect() as conn: + ft_names = 
inspector.get_foreign_table_names() + eq_(ft_names, ['test_foreigntable']) + + def test_get_table_names_no_foreign(self): + inspector = inspect(testing.db) + with testing.db.connect() as conn: + names = inspector.get_table_names() + eq_(names, ['testtable']) + + +class MaterialiedViewReflectionTest( + fixtures.TablesTest, AssertsExecutionResults): + """Test reflection on materialized views""" + + __only_on__ = 'postgresql >= 9.3' + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + testtable = Table( + 'testtable', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30))) + + # insert data before we create the view + @sa.event.listens_for(testtable, "after_create") + def insert_data(target, connection, **kw): + connection.execute( + target.insert(), + {"id": 89, "data": 'd1'} + ) + + materialized_view = sa.DDL( + "CREATE MATERIALIZED VIEW test_mview AS " + "SELECT * FROM testtable") + + plain_view = sa.DDL( + "CREATE VIEW test_regview AS " + "SELECT * FROM testtable") + + sa.event.listen(testtable, 'after_create', plain_view) + sa.event.listen(testtable, 'after_create', materialized_view) + sa.event.listen( + testtable, 'before_drop', + sa.DDL("DROP MATERIALIZED VIEW test_mview") + ) + sa.event.listen( + testtable, 'before_drop', + sa.DDL("DROP VIEW test_regview") + ) + + def test_mview_is_reflected(self): + metadata = MetaData(testing.db) + table = Table('test_mview', metadata, autoload=True) + eq_(set(table.columns.keys()), set(['id', 'data']), + "Columns of reflected mview didn't equal expected columns") + + def test_mview_select(self): + metadata = MetaData(testing.db) + table = Table('test_mview', metadata, autoload=True) + eq_( + table.select().execute().fetchall(), + [(89, 'd1',)] + ) + + def test_get_view_names(self): + insp = inspect(testing.db) + eq_(set(insp.get_view_names()), set(['test_mview', 'test_regview'])) + + +class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): """Test PostgreSQL 
domains""" __only_on__ = 'postgresql > 8.3' diff --git a/test/dialect/test_oracle.py b/test/dialect/test_oracle.py index 187042036..36eacf864 100644 --- a/test/dialect/test_oracle.py +++ b/test/dialect/test_oracle.py @@ -648,6 +648,23 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "CREATE INDEX bar ON foo (x > 5)" ) + def test_table_options(self): + m = MetaData() + + t = Table( + 'foo', m, + Column('x', Integer), + prefixes=["GLOBAL TEMPORARY"], + oracle_on_commit="PRESERVE ROWS" + ) + + self.assert_compile( + schema.CreateTable(t), + "CREATE GLOBAL TEMPORARY TABLE " + "foo (x INTEGER) ON COMMIT PRESERVE ROWS" + ) + + class CompatFlagsTest(fixtures.TestBase, AssertsCompiledSQL): def _dialect(self, server_version, **kw): diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index e77a03980..124208dbe 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -11,7 +11,7 @@ from sqlalchemy import Table, select, bindparam, Column,\ UniqueConstraint from sqlalchemy.types import Integer, String, Boolean, DateTime, Date, Time from sqlalchemy import types as sqltypes -from sqlalchemy import event +from sqlalchemy import event, inspect from sqlalchemy.util import u, ue from sqlalchemy import exc, sql, schema, pool, util from sqlalchemy.dialects.sqlite import base as sqlite, \ @@ -480,57 +480,6 @@ class DialectTest(fixtures.TestBase, AssertsExecutionResults): assert u('méil') in result.keys() assert ue('\u6e2c\u8a66') in result.keys() - def test_attached_as_schema(self): - cx = testing.db.connect() - try: - cx.execute('ATTACH DATABASE ":memory:" AS test_schema') - dialect = cx.dialect - assert dialect.get_table_names(cx, 'test_schema') == [] - meta = MetaData(cx) - Table('created', meta, Column('id', Integer), - schema='test_schema') - alt_master = Table('sqlite_master', meta, autoload=True, - schema='test_schema') - meta.create_all(cx) - eq_(dialect.get_table_names(cx, 'test_schema'), ['created']) - assert 
len(alt_master.c) > 0 - meta.clear() - reflected = Table('created', meta, autoload=True, - schema='test_schema') - assert len(reflected.c) == 1 - cx.execute(reflected.insert(), dict(id=1)) - r = cx.execute(reflected.select()).fetchall() - assert list(r) == [(1, )] - cx.execute(reflected.update(), dict(id=2)) - r = cx.execute(reflected.select()).fetchall() - assert list(r) == [(2, )] - cx.execute(reflected.delete(reflected.c.id == 2)) - r = cx.execute(reflected.select()).fetchall() - assert list(r) == [] - - # note that sqlite_master is cleared, above - - meta.drop_all() - assert dialect.get_table_names(cx, 'test_schema') == [] - finally: - cx.execute('DETACH DATABASE test_schema') - - @testing.exclude('sqlite', '<', (2, 6), 'no database support') - def test_temp_table_reflection(self): - cx = testing.db.connect() - try: - cx.execute('CREATE TEMPORARY TABLE tempy (id INT)') - assert 'tempy' in cx.dialect.get_table_names(cx, None) - meta = MetaData(cx) - tempy = Table('tempy', meta, autoload=True) - assert len(tempy.c) == 1 - meta.drop_all() - except: - try: - cx.execute('DROP TABLE tempy') - except exc.DBAPIError: - pass - raise def test_file_path_is_absolute(self): d = pysqlite_dialect.dialect() @@ -549,7 +498,6 @@ class DialectTest(fixtures.TestBase, AssertsExecutionResults): e = create_engine('sqlite+pysqlite:///foo.db') assert e.pool.__class__ is pool.NullPool - def test_dont_reflect_autoindex(self): meta = MetaData(testing.db) t = Table('foo', meta, Column('bar', String, primary_key=True)) @@ -575,6 +523,125 @@ class DialectTest(fixtures.TestBase, AssertsExecutionResults): finally: meta.drop_all() + def test_get_unique_constraints(self): + meta = MetaData(testing.db) + t1 = Table('foo', meta, Column('f', Integer), + UniqueConstraint('f', name='foo_f')) + t2 = Table('bar', meta, Column('b', Integer), + UniqueConstraint('b', name='bar_b'), + prefixes=['TEMPORARY']) + meta.create_all() + from sqlalchemy.engine.reflection import Inspector + try: + inspector = 
Inspector(testing.db) + eq_(inspector.get_unique_constraints('foo'), + [{'column_names': [u'f'], 'name': u'foo_f'}]) + eq_(inspector.get_unique_constraints('bar'), + [{'column_names': [u'b'], 'name': u'bar_b'}]) + finally: + meta.drop_all() + + +class AttachedMemoryDBTest(fixtures.TestBase): + __only_on__ = 'sqlite' + + dbname = None + + def setUp(self): + self.conn = conn = testing.db.connect() + if self.dbname is None: + dbname = ':memory:' + else: + dbname = self.dbname + conn.execute('ATTACH DATABASE "%s" AS test_schema' % dbname) + self.metadata = MetaData() + + def tearDown(self): + self.metadata.drop_all(self.conn) + self.conn.execute('DETACH DATABASE test_schema') + if self.dbname: + os.remove(self.dbname) + + def _fixture(self): + meta = self.metadata + ct = Table( + 'created', meta, + Column('id', Integer), + Column('name', String), + schema='test_schema') + + meta.create_all(self.conn) + return ct + + def test_no_tables(self): + insp = inspect(self.conn) + eq_(insp.get_table_names("test_schema"), []) + + def test_table_names_present(self): + self._fixture() + insp = inspect(self.conn) + eq_(insp.get_table_names("test_schema"), ["created"]) + + def test_table_names_system(self): + self._fixture() + insp = inspect(self.conn) + eq_(insp.get_table_names("test_schema"), ["created"]) + + def test_reflect_system_table(self): + meta = MetaData(self.conn) + alt_master = Table( + 'sqlite_master', meta, autoload=True, + autoload_with=self.conn, + schema='test_schema') + assert len(alt_master.c) > 0 + + def test_reflect_user_table(self): + self._fixture() + + m2 = MetaData() + c2 = Table('created', m2, autoload=True, autoload_with=self.conn) + eq_(len(c2.c), 2) + + def test_crud(self): + ct = self._fixture() + + self.conn.execute(ct.insert(), {'id': 1, 'name': 'foo'}) + eq_( + self.conn.execute(ct.select()).fetchall(), + [(1, 'foo')] + ) + + self.conn.execute(ct.update(), {'id': 2, 'name': 'bar'}) + eq_( + self.conn.execute(ct.select()).fetchall(), + [(2, 'bar')] + 
) + self.conn.execute(ct.delete()) + eq_( + self.conn.execute(ct.select()).fetchall(), + [] + ) + + def test_col_targeting(self): + ct = self._fixture() + + self.conn.execute(ct.insert(), {'id': 1, 'name': 'foo'}) + row = self.conn.execute(ct.select()).first() + eq_(row['id'], 1) + eq_(row['name'], 'foo') + + def test_col_targeting_union(self): + ct = self._fixture() + + self.conn.execute(ct.insert(), {'id': 1, 'name': 'foo'}) + row = self.conn.execute(ct.select().union(ct.select())).first() + eq_(row['id'], 1) + eq_(row['name'], 'foo') + + +class AttachedFileDBTest(AttachedMemoryDBTest): + dbname = 'attached_db.db' + class SQLTest(fixtures.TestBase, AssertsCompiledSQL): diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index d8e1c655e..219a145c6 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -982,6 +982,17 @@ class ExecutionOptionsTest(fixtures.TestBase): eq_(c1._execution_options, {"foo": "bar"}) eq_(c2._execution_options, {"foo": "bar", "bat": "hoho"}) + def test_branched_connection_execution_options(self): + engine = testing_engine("sqlite://") + + conn = engine.connect() + c2 = conn.execution_options(foo="bar") + c2_branch = c2.connect() + eq_( + c2_branch._execution_options, + {"foo": "bar"} + ) + class AlternateResultProxyTest(fixtures.TestBase): __requires__ = ('sqlite', ) @@ -1440,6 +1451,48 @@ class EngineEventsTest(fixtures.TestBase): 'begin', 'execute', 'cursor_execute', 'commit', ]) + def test_transactional_named(self): + canary = [] + + def tracker(name): + def go(*args, **kw): + canary.append((name, set(kw))) + return go + + engine = engines.testing_engine() + event.listen(engine, 'before_execute', tracker('execute'), named=True) + event.listen( + engine, 'before_cursor_execute', + tracker('cursor_execute'), named=True) + event.listen(engine, 'begin', tracker('begin'), named=True) + event.listen(engine, 'commit', tracker('commit'), named=True) + event.listen(engine, 'rollback', tracker('rollback'), 
named=True) + + conn = engine.connect() + trans = conn.begin() + conn.execute(select([1])) + trans.rollback() + trans = conn.begin() + conn.execute(select([1])) + trans.commit() + + eq_( + canary, [ + ('begin', set(['conn', ])), + ('execute', set([ + 'conn', 'clauseelement', 'multiparams', 'params'])), + ('cursor_execute', set([ + 'conn', 'cursor', 'executemany', + 'statement', 'parameters', 'context'])), + ('rollback', set(['conn', ])), ('begin', set(['conn', ])), + ('execute', set([ + 'conn', 'clauseelement', 'multiparams', 'params'])), + ('cursor_execute', set([ + 'conn', 'cursor', 'executemany', 'statement', + 'parameters', 'context'])), + ('commit', set(['conn', ]))] + ) + @testing.requires.savepoints @testing.requires.two_phase_transactions def test_transactional_advanced(self): diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index c82cca5a1..4500ada6a 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -8,7 +8,7 @@ from sqlalchemy import testing from sqlalchemy.testing import engines from sqlalchemy.testing import fixtures from sqlalchemy.testing.engines import testing_engine -from sqlalchemy.testing.mock import Mock, call +from sqlalchemy.testing.mock import Mock, call, patch class MockError(Exception): @@ -504,6 +504,54 @@ class RealReconnectTest(fixtures.TestBase): # pool isn't replaced assert self.engine.pool is p2 + def test_branched_invalidate_branch_to_parent(self): + c1 = self.engine.connect() + + with patch.object(self.engine.pool, "logger") as logger: + c1_branch = c1.connect() + eq_(c1_branch.execute(select([1])).scalar(), 1) + + self.engine.test_shutdown() + + _assert_invalidated(c1_branch.execute, select([1])) + assert c1.invalidated + assert c1_branch.invalidated + + c1_branch._revalidate_connection() + assert not c1.invalidated + assert not c1_branch.invalidated + + assert "Invalidate connection" in logger.mock_calls[0][1][0] + + def test_branched_invalidate_parent_to_branch(self): + c1 
= self.engine.connect() + + c1_branch = c1.connect() + eq_(c1_branch.execute(select([1])).scalar(), 1) + + self.engine.test_shutdown() + + _assert_invalidated(c1.execute, select([1])) + assert c1.invalidated + assert c1_branch.invalidated + + c1._revalidate_connection() + assert not c1.invalidated + assert not c1_branch.invalidated + + def test_branch_invalidate_state(self): + c1 = self.engine.connect() + + c1_branch = c1.connect() + + eq_(c1_branch.execute(select([1])).scalar(), 1) + + self.engine.test_shutdown() + + _assert_invalidated(c1_branch.execute, select([1])) + assert not c1_branch.closed + assert not c1_branch._connection_is_valid + def test_ensure_is_disconnect_gets_connection(self): def is_disconnect(e, conn, cursor): # connection is still present diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py index 8a5303642..b3b17e75a 100644 --- a/test/engine/test_transaction.py +++ b/test/engine/test_transaction.py @@ -133,6 +133,91 @@ class TransactionTest(fixtures.TestBase): finally: connection.close() + def test_branch_nested_rollback(self): + connection = testing.db.connect() + try: + connection.begin() + branched = connection.connect() + assert branched.in_transaction() + branched.execute(users.insert(), user_id=1, user_name='user1') + nested = branched.begin() + branched.execute(users.insert(), user_id=2, user_name='user2') + nested.rollback() + assert not connection.in_transaction() + eq_(connection.scalar("select count(*) from query_users"), 0) + + finally: + connection.close() + + def test_branch_autorollback(self): + connection = testing.db.connect() + try: + branched = connection.connect() + branched.execute(users.insert(), user_id=1, user_name='user1') + try: + branched.execute(users.insert(), user_id=1, user_name='user1') + except exc.DBAPIError: + pass + finally: + connection.close() + + def test_branch_orig_rollback(self): + connection = testing.db.connect() + try: + branched = connection.connect() + 
branched.execute(users.insert(), user_id=1, user_name='user1') + nested = branched.begin() + assert branched.in_transaction() + branched.execute(users.insert(), user_id=2, user_name='user2') + nested.rollback() + eq_(connection.scalar("select count(*) from query_users"), 1) + + finally: + connection.close() + + def test_branch_autocommit(self): + connection = testing.db.connect() + try: + branched = connection.connect() + branched.execute(users.insert(), user_id=1, user_name='user1') + finally: + connection.close() + eq_(testing.db.scalar("select count(*) from query_users"), 1) + + @testing.requires.savepoints + def test_branch_savepoint_rollback(self): + connection = testing.db.connect() + try: + trans = connection.begin() + branched = connection.connect() + assert branched.in_transaction() + branched.execute(users.insert(), user_id=1, user_name='user1') + nested = branched.begin_nested() + branched.execute(users.insert(), user_id=2, user_name='user2') + nested.rollback() + assert connection.in_transaction() + trans.commit() + eq_(connection.scalar("select count(*) from query_users"), 1) + + finally: + connection.close() + + @testing.requires.two_phase_transactions + def test_branch_twophase_rollback(self): + connection = testing.db.connect() + try: + branched = connection.connect() + assert not branched.in_transaction() + branched.execute(users.insert(), user_id=1, user_name='user1') + nested = branched.begin_twophase() + branched.execute(users.insert(), user_id=2, user_name='user2') + nested.rollback() + assert not connection.in_transaction() + eq_(connection.scalar("select count(*) from query_users"), 1) + + finally: + connection.close() + def test_retains_through_options(self): connection = testing.db.connect() try: @@ -1126,139 +1211,6 @@ class TLTransactionTest(fixtures.TestBase): order_by(users.c.user_id)).fetchall(), [(1, ), (2, )]) -counters = None - - -class ForUpdateTest(fixtures.TestBase): - __requires__ = 'ad_hoc_engines', - __backend__ = True - - 
@classmethod - def setup_class(cls): - global counters, metadata - metadata = MetaData() - counters = Table('forupdate_counters', metadata, - Column('counter_id', INT, primary_key=True), - Column('counter_value', INT), - test_needs_acid=True) - counters.create(testing.db) - - def teardown(self): - testing.db.execute(counters.delete()).close() - - @classmethod - def teardown_class(cls): - counters.drop(testing.db) - - def increment(self, count, errors, update_style=True, delay=0.005): - con = testing.db.connect() - sel = counters.select(for_update=update_style, - whereclause=counters.c.counter_id == 1) - for i in range(count): - trans = con.begin() - try: - existing = con.execute(sel).first() - incr = existing['counter_value'] + 1 - time.sleep(delay) - con.execute(counters.update(counters.c.counter_id == 1, - values={'counter_value': incr})) - time.sleep(delay) - readback = con.execute(sel).first() - if readback['counter_value'] != incr: - raise AssertionError('Got %s post-update, expected ' - '%s' % (readback['counter_value'], incr)) - trans.commit() - except Exception as e: - trans.rollback() - errors.append(e) - break - con.close() - - @testing.crashes('mssql', 'FIXME: unknown') - @testing.crashes('firebird', 'FIXME: unknown') - @testing.crashes('sybase', 'FIXME: unknown') - @testing.requires.independent_connections - def test_queued_update(self): - """Test SELECT FOR UPDATE with concurrent modifications. - - Runs concurrent modifications on a single row in the users - table, with each mutator trying to increment a value stored in - user_name. 
- - """ - - db = testing.db - db.execute(counters.insert(), counter_id=1, counter_value=0) - iterations, thread_count = 10, 5 - threads, errors = [], [] - for i in range(thread_count): - thrd = threading.Thread(target=self.increment, - args=(iterations, ), - kwargs={'errors': errors, - 'update_style': True}) - thrd.start() - threads.append(thrd) - for thrd in threads: - thrd.join() - assert not errors - sel = counters.select(whereclause=counters.c.counter_id == 1) - final = db.execute(sel).first() - eq_(final['counter_value'], iterations * thread_count) - - def overlap(self, ids, errors, update_style): - - sel = counters.select(for_update=update_style, - whereclause=counters.c.counter_id.in_(ids)) - con = testing.db.connect() - trans = con.begin() - try: - rows = con.execute(sel).fetchall() - time.sleep(0.50) - trans.commit() - except Exception as e: - trans.rollback() - errors.append(e) - con.close() - - def _threaded_overlap(self, thread_count, groups, update_style=True, pool=5): - db = testing.db - for cid in range(pool - 1): - db.execute(counters.insert(), counter_id=cid + 1, - counter_value=0) - errors, threads = [], [] - for i in range(thread_count): - thrd = threading.Thread(target=self.overlap, - args=(groups.pop(0), errors, - update_style)) - time.sleep(0.20) # give the previous thread a chance to start - # to ensure it gets a lock - thrd.start() - threads.append(thrd) - for thrd in threads: - thrd.join() - return errors - - @testing.crashes('mssql', 'FIXME: unknown') - @testing.crashes('firebird', 'FIXME: unknown') - @testing.crashes('sybase', 'FIXME: unknown') - @testing.requires.independent_connections - def test_queued_select(self): - """Simple SELECT FOR UPDATE conflict test""" - - errors = self._threaded_overlap(2, [(1, 2, 3), (3, 4, 5)]) - assert not errors - - @testing.crashes('mssql', 'FIXME: unknown') - @testing.fails_on('mysql', 'No support for NOWAIT') - @testing.crashes('firebird', 'FIXME: unknown') - @testing.crashes('sybase', 'FIXME: 
unknown') - @testing.requires.independent_connections - def test_nowait_select(self): - """Simple SELECT FOR UPDATE NOWAIT conflict test""" - - errors = self._threaded_overlap(2, [(1, 2, 3), (3, 4, 5)], - update_style='nowait') - assert errors class IsolationLevelTest(fixtures.TestBase): __requires__ = ('isolation_level', 'ad_hoc_engines') diff --git a/test/ext/declarative/test_basic.py b/test/ext/declarative/test_basic.py index e2c2af679..3fac39cac 100644 --- a/test/ext/declarative/test_basic.py +++ b/test/ext/declarative/test_basic.py @@ -1,6 +1,6 @@ from sqlalchemy.testing import eq_, assert_raises, \ - assert_raises_message, is_ + assert_raises_message from sqlalchemy.ext import declarative as decl from sqlalchemy import exc import sqlalchemy as sa @@ -10,21 +10,21 @@ from sqlalchemy import MetaData, Integer, String, ForeignKey, \ from sqlalchemy.testing.schema import Table, Column from sqlalchemy.orm import relationship, create_session, class_mapper, \ joinedload, configure_mappers, backref, clear_mappers, \ - deferred, column_property, composite,\ - Session, properties -from sqlalchemy.testing import eq_ -from sqlalchemy.util import classproperty, with_metaclass -from sqlalchemy.ext.declarative import declared_attr, AbstractConcreteBase, \ - ConcreteBase, synonym_for + column_property, composite, Session, properties +from sqlalchemy.util import with_metaclass +from sqlalchemy.ext.declarative import declared_attr, synonym_for from sqlalchemy.testing import fixtures -from sqlalchemy.testing.util import gc_collect Base = None +User = Address = None + + class DeclarativeTestBase(fixtures.TestBase, - testing.AssertsExecutionResults, - testing.AssertsCompiledSQL): + testing.AssertsExecutionResults, + testing.AssertsCompiledSQL): __dialect__ = 'default' + def setup(self): global Base Base = decl.declarative_base(testing.db) @@ -34,13 +34,15 @@ class DeclarativeTestBase(fixtures.TestBase, clear_mappers() Base.metadata.drop_all() + class 
DeclarativeTest(DeclarativeTestBase): + def test_basic(self): class User(Base, fixtures.ComparableEntity): __tablename__ = 'users' id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) + test_needs_autoincrement=True) name = Column('name', String(50)) addresses = relationship("Address", backref="user") @@ -48,7 +50,7 @@ class DeclarativeTest(DeclarativeTestBase): __tablename__ = 'addresses' id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + test_needs_autoincrement=True) email = Column(String(50), key='_email') user_id = Column('user_id', Integer, ForeignKey('users.id'), key='_user_id') @@ -82,7 +84,7 @@ class DeclarativeTest(DeclarativeTestBase): __tablename__ = 'users' id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) + test_needs_autoincrement=True) name = Column('name', String(50)) addresses = relationship(util.u("Address"), backref="user") @@ -90,7 +92,7 @@ class DeclarativeTest(DeclarativeTestBase): __tablename__ = 'addresses' id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + test_needs_autoincrement=True) email = Column(String(50), key='_email') user_id = Column('user_id', Integer, ForeignKey('users.id'), key='_user_id') @@ -120,8 +122,10 @@ class DeclarativeTest(DeclarativeTestBase): __table_args__ = () def test_cant_add_columns(self): - t = Table('t', Base.metadata, Column('id', Integer, - primary_key=True), Column('data', String)) + t = Table( + 't', Base.metadata, + Column('id', Integer, primary_key=True), + Column('data', String)) def go(): class User(Base): @@ -158,7 +162,6 @@ class DeclarativeTest(DeclarativeTestBase): go ) - def test_column_repeated_under_prop(self): def go(): class Foo(Base): @@ -180,6 +183,7 @@ class DeclarativeTest(DeclarativeTestBase): class A(Base): __tablename__ = 'a' id = Column(Integer, primary_key=True) + class B(Base): __tablename__ = 'b' id = Column(Integer, primary_key=True) @@ -196,6 +200,7 @@ class 
DeclarativeTest(DeclarativeTestBase): class A(Base): __tablename__ = 'a' id = Column(Integer, primary_key=True) + class B(Base): __tablename__ = 'b' id = Column(Integer, primary_key=True) @@ -213,11 +218,12 @@ class DeclarativeTest(DeclarativeTestBase): # metaclass to mock the way zope.interface breaks getattr() class BrokenMeta(type): + def __getattribute__(self, attr): if attr == 'xyzzy': raise AttributeError('xyzzy') else: - return object.__getattribute__(self,attr) + return object.__getattribute__(self, attr) # even though this class has an xyzzy attribute, getattr(cls,"xyzzy") # fails @@ -225,13 +231,13 @@ class DeclarativeTest(DeclarativeTestBase): xyzzy = "magic" # _as_declarative() inspects obj.__class__.__bases__ - class User(BrokenParent,fixtures.ComparableEntity): + class User(BrokenParent, fixtures.ComparableEntity): __tablename__ = 'users' id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) + test_needs_autoincrement=True) name = Column('name', String(50)) - decl.instrument_declarative(User,{},Base.metadata) + decl.instrument_declarative(User, {}, Base.metadata) def test_reserved_identifiers(self): def go1(): @@ -285,29 +291,28 @@ class DeclarativeTest(DeclarativeTestBase): email = Column('email', String(50)) user_id = Column('user_id', Integer, ForeignKey('users.id')) user = relationship("User", primaryjoin=user_id == User.id, - backref="addresses") + backref="addresses") assert mapperlib.Mapper._new_mappers is True - u = User() + u = User() # noqa assert User.addresses assert mapperlib.Mapper._new_mappers is False def test_string_dependency_resolution(self): - from sqlalchemy.sql import desc - class User(Base, fixtures.ComparableEntity): __tablename__ = 'users' id = Column(Integer, primary_key=True, test_needs_autoincrement=True) name = Column(String(50)) - addresses = relationship('Address', - order_by='desc(Address.email)', - primaryjoin='User.id==Address.user_id', - foreign_keys='[Address.user_id]', - 
backref=backref('user', - primaryjoin='User.id==Address.user_id', - foreign_keys='[Address.user_id]')) + addresses = relationship( + 'Address', + order_by='desc(Address.email)', + primaryjoin='User.id==Address.user_id', + foreign_keys='[Address.user_id]', + backref=backref('user', + primaryjoin='User.id==Address.user_id', + foreign_keys='[Address.user_id]')) class Address(Base, fixtures.ComparableEntity): @@ -319,14 +324,17 @@ class DeclarativeTest(DeclarativeTestBase): Base.metadata.create_all() sess = create_session() - u1 = User(name='ed', addresses=[Address(email='abc'), - Address(email='def'), Address(email='xyz')]) + u1 = User( + name='ed', addresses=[ + Address(email='abc'), + Address(email='def'), Address(email='xyz')]) sess.add(u1) sess.flush() sess.expunge_all() eq_(sess.query(User).filter(User.name == 'ed').one(), - User(name='ed', addresses=[Address(email='xyz'), - Address(email='def'), Address(email='abc')])) + User(name='ed', addresses=[ + Address(email='xyz'), + Address(email='def'), Address(email='abc')])) class Foo(Base, fixtures.ComparableEntity): @@ -340,7 +348,6 @@ class DeclarativeTest(DeclarativeTestBase): "ColumnProperty", configure_mappers) def test_string_dependency_resolution_synonym(self): - from sqlalchemy.sql import desc class User(Base, fixtures.ComparableEntity): @@ -416,12 +423,13 @@ class DeclarativeTest(DeclarativeTestBase): id = Column(Integer, primary_key=True) b_id = Column(ForeignKey('b.id')) - d = relationship("D", - secondary="join(B, D, B.d_id == D.id)." - "join(C, C.d_id == D.id)", - primaryjoin="and_(A.b_id == B.id, A.id == C.a_id)", - secondaryjoin="D.id == B.d_id", - ) + d = relationship( + "D", + secondary="join(B, D, B.d_id == D.id)." 
+ "join(C, C.d_id == D.id)", + primaryjoin="and_(A.b_id == B.id, A.id == C.a_id)", + secondaryjoin="D.id == B.d_id", + ) class B(Base): __tablename__ = 'b' @@ -444,9 +452,9 @@ class DeclarativeTest(DeclarativeTestBase): self.assert_compile( s.query(A).join(A.d), "SELECT a.id AS a_id, a.b_id AS a_b_id FROM a JOIN " - "(b AS b_1 JOIN d AS d_1 ON b_1.d_id = d_1.id " - "JOIN c AS c_1 ON c_1.d_id = d_1.id) ON a.b_id = b_1.id " - "AND a.id = c_1.a_id JOIN d ON d.id = b_1.d_id", + "(b AS b_1 JOIN d AS d_1 ON b_1.d_id = d_1.id " + "JOIN c AS c_1 ON c_1.d_id = d_1.id) ON a.b_id = b_1.id " + "AND a.id = c_1.a_id JOIN d ON d.id = b_1.d_id", ) def test_string_dependency_resolution_no_table(self): @@ -474,6 +482,7 @@ class DeclarativeTest(DeclarativeTestBase): id = Column(Integer, primary_key=True, test_needs_autoincrement=True) name = Column(String(50)) + class Address(Base, fixtures.ComparableEntity): __tablename__ = 'addresses' @@ -481,7 +490,8 @@ class DeclarativeTest(DeclarativeTestBase): test_needs_autoincrement=True) email = Column(String(50)) user_id = Column(Integer) - user = relationship("User", + user = relationship( + "User", primaryjoin="remote(User.id)==foreign(Address.user_id)" ) @@ -497,9 +507,9 @@ class DeclarativeTest(DeclarativeTestBase): __tablename__ = 'users' id = Column(Integer, primary_key=True) - addresses = relationship('Address', - primaryjoin='User.id==Address.user_id.prop.columns[' - '0]') + addresses = relationship( + 'Address', + primaryjoin='User.id==Address.user_id.prop.columns[0]') class Address(Base, fixtures.ComparableEntity): @@ -516,9 +526,10 @@ class DeclarativeTest(DeclarativeTestBase): __tablename__ = 'users' id = Column(Integer, primary_key=True) - addresses = relationship('%s.Address' % __name__, - primaryjoin='%s.User.id==%s.Address.user_id.prop.columns[' - '0]' % (__name__, __name__)) + addresses = relationship( + '%s.Address' % __name__, + primaryjoin='%s.User.id==%s.Address.user_id.prop.columns[0]' + % (__name__, __name__)) class 
Address(Base, fixtures.ComparableEntity): @@ -538,8 +549,8 @@ class DeclarativeTest(DeclarativeTestBase): id = Column(Integer, primary_key=True) name = Column(String(50)) addresses = relationship('Address', - primaryjoin='User.id==Address.user_id', - backref='user') + primaryjoin='User.id==Address.user_id', + backref='user') class Address(Base, fixtures.ComparableEntity): @@ -571,10 +582,11 @@ class DeclarativeTest(DeclarativeTestBase): id = Column(Integer, primary_key=True) name = Column(String(50)) - user_to_prop = Table('user_to_prop', Base.metadata, - Column('user_id', Integer, - ForeignKey('users.id')), Column('prop_id', - Integer, ForeignKey('props.id'))) + user_to_prop = Table( + 'user_to_prop', Base.metadata, + Column('user_id', Integer, ForeignKey('users.id')), + Column('prop_id', Integer, ForeignKey('props.id'))) + configure_mappers() assert class_mapper(User).get_property('props').secondary \ is user_to_prop @@ -585,27 +597,29 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base): __tablename__ = 'users' - __table_args__ = {'schema':'fooschema'} + __table_args__ = {'schema': 'fooschema'} id = Column(Integer, primary_key=True) name = Column(String(50)) - props = relationship('Prop', secondary='fooschema.user_to_prop', - primaryjoin='User.id==fooschema.user_to_prop.c.user_id', - secondaryjoin='fooschema.user_to_prop.c.prop_id==Prop.id', - backref='users') + props = relationship( + 'Prop', secondary='fooschema.user_to_prop', + primaryjoin='User.id==fooschema.user_to_prop.c.user_id', + secondaryjoin='fooschema.user_to_prop.c.prop_id==Prop.id', + backref='users') class Prop(Base): __tablename__ = 'props' - __table_args__ = {'schema':'fooschema'} + __table_args__ = {'schema': 'fooschema'} id = Column(Integer, primary_key=True) name = Column(String(50)) - user_to_prop = Table('user_to_prop', Base.metadata, - Column('user_id', Integer, ForeignKey('fooschema.users.id')), - Column('prop_id',Integer, ForeignKey('fooschema.props.id')), - 
schema='fooschema') + user_to_prop = Table( + 'user_to_prop', Base.metadata, + Column('user_id', Integer, ForeignKey('fooschema.users.id')), + Column('prop_id', Integer, ForeignKey('fooschema.props.id')), + schema='fooschema') configure_mappers() assert class_mapper(User).get_property('props').secondary \ @@ -618,9 +632,11 @@ class DeclarativeTest(DeclarativeTestBase): __tablename__ = 'parent' id = Column(Integer, primary_key=True) name = Column(String) - children = relationship("Child", - primaryjoin="Parent.name==remote(foreign(func.lower(Child.name_upper)))" - ) + children = relationship( + "Child", + primaryjoin="Parent.name==" + "remote(foreign(func.lower(Child.name_upper)))" + ) class Child(Base): __tablename__ = 'child' @@ -667,8 +683,8 @@ class DeclarativeTest(DeclarativeTestBase): test_needs_autoincrement=True) name = Column(String(50)) addresses = relationship('Address', order_by=Address.email, - foreign_keys=Address.user_id, - remote_side=Address.user_id) + foreign_keys=Address.user_id, + remote_side=Address.user_id) # get the mapper for User. User mapper will compile, # "addresses" relationship will call upon Address.user_id for @@ -681,14 +697,16 @@ class DeclarativeTest(DeclarativeTestBase): class_mapper(User) Base.metadata.create_all() sess = create_session() - u1 = User(name='ed', addresses=[Address(email='abc'), - Address(email='xyz'), Address(email='def')]) + u1 = User(name='ed', addresses=[ + Address(email='abc'), + Address(email='xyz'), Address(email='def')]) sess.add(u1) sess.flush() sess.expunge_all() eq_(sess.query(User).filter(User.name == 'ed').one(), - User(name='ed', addresses=[Address(email='abc'), - Address(email='def'), Address(email='xyz')])) + User(name='ed', addresses=[ + Address(email='abc'), + Address(email='def'), Address(email='xyz')])) def test_nice_dependency_error(self): @@ -726,14 +744,16 @@ class DeclarativeTest(DeclarativeTestBase): # the exception is preserved. Remains the # same through repeated calls. 
for i in range(3): - assert_raises_message(sa.exc.InvalidRequestError, - "^One or more mappers failed to initialize - " - "can't proceed with initialization of other " - "mappers. Original exception was: When initializing.*", - configure_mappers) + assert_raises_message( + sa.exc.InvalidRequestError, + "^One or more mappers failed to initialize - " + "can't proceed with initialization of other " + "mappers. Original exception was: When initializing.*", + configure_mappers) def test_custom_base(self): class MyBase(object): + def foobar(self): return "foobar" Base = decl.declarative_base(cls=MyBase) @@ -761,7 +781,7 @@ class DeclarativeTest(DeclarativeTestBase): Base.metadata.create_all() configure_mappers() assert class_mapper(Detail).get_property('master' - ).strategy.use_get + ).strategy.use_get m1 = Master() d1 = Detail(master=m1) sess = create_session() @@ -821,13 +841,15 @@ class DeclarativeTest(DeclarativeTestBase): eq_(Address.__table__.c['_email'].name, 'email') eq_(Address.__table__.c['_user_id'].name, 'user_id') u1 = User(name='u1', addresses=[Address(email='one'), - Address(email='two')]) + Address(email='two')]) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [User(name='u1', - addresses=[Address(email='one'), Address(email='two')])]) + eq_(sess.query(User).all(), [ + User( + name='u1', + addresses=[Address(email='one'), Address(email='two')])]) a1 = sess.query(Address).filter(Address.email == 'two').one() eq_(a1, Address(email='two')) eq_(a1.user, User(name='u1')) @@ -842,7 +864,8 @@ class DeclarativeTest(DeclarativeTestBase): class ASub(A): brap = A.data assert ASub.brap.property is A.data.property - assert isinstance(ASub.brap.original_property, properties.SynonymProperty) + assert isinstance( + ASub.brap.original_property, properties.SynonymProperty) def test_alt_name_attr_subclass_relationship_inline(self): # [ticket:2900] @@ -857,10 +880,12 @@ class DeclarativeTest(DeclarativeTestBase): id = 
Column('id', Integer, primary_key=True) configure_mappers() + class ASub(A): brap = A.b assert ASub.brap.property is A.b.property - assert isinstance(ASub.brap.original_property, properties.SynonymProperty) + assert isinstance( + ASub.brap.original_property, properties.SynonymProperty) ASub(brap=B()) def test_alt_name_attr_subclass_column_attrset(self): @@ -881,6 +906,7 @@ class DeclarativeTest(DeclarativeTestBase): b_id = Column(Integer, ForeignKey('b.id')) b = relationship("B", backref="as_") A.brap = A.b + class B(Base): __tablename__ = 'b' id = Column('id', Integer, primary_key=True) @@ -889,7 +915,6 @@ class DeclarativeTest(DeclarativeTestBase): assert isinstance(A.brap.original_property, properties.SynonymProperty) A(brap=B()) - def test_eager_order_by(self): class Address(Base, fixtures.ComparableEntity): @@ -910,14 +935,14 @@ class DeclarativeTest(DeclarativeTestBase): Base.metadata.create_all() u1 = User(name='u1', addresses=[Address(email='two'), - Address(email='one')]) + Address(email='one')]) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() eq_(sess.query(User).options(joinedload(User.addresses)).all(), [User(name='u1', addresses=[Address(email='one'), - Address(email='two')])]) + Address(email='two')])]) def test_order_by_multi(self): @@ -936,17 +961,17 @@ class DeclarativeTest(DeclarativeTestBase): test_needs_autoincrement=True) name = Column('name', String(50)) addresses = relationship('Address', - order_by=(Address.email, Address.id)) + order_by=(Address.email, Address.id)) Base.metadata.create_all() u1 = User(name='u1', addresses=[Address(email='two'), - Address(email='one')]) + Address(email='one')]) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() u = sess.query(User).filter(User.name == 'u1').one() - a = u.addresses + u.addresses def test_as_declarative(self): @@ -971,13 +996,15 @@ class DeclarativeTest(DeclarativeTestBase): decl.instrument_declarative(Address, reg, Base.metadata) 
Base.metadata.create_all() u1 = User(name='u1', addresses=[Address(email='one'), - Address(email='two')]) + Address(email='two')]) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [User(name='u1', - addresses=[Address(email='one'), Address(email='two')])]) + eq_(sess.query(User).all(), [ + User( + name='u1', + addresses=[Address(email='one'), Address(email='two')])]) def test_custom_mapper_attribute(self): @@ -1045,7 +1072,7 @@ class DeclarativeTest(DeclarativeTestBase): __tablename__ = 'foo' __table_args__ = ForeignKeyConstraint(['id'], ['foo.id' - ]) + ]) id = Column('id', Integer, primary_key=True) assert_raises_message(sa.exc.ArgumentError, '__table_args__ value must be a tuple, ', err) @@ -1107,17 +1134,18 @@ class DeclarativeTest(DeclarativeTestBase): User.address_count = \ sa.orm.column_property(sa.select([sa.func.count(Address.id)]). - where(Address.user_id - == User.id).as_scalar()) + where(Address.user_id + == User.id).as_scalar()) Base.metadata.create_all() u1 = User(name='u1', addresses=[Address(email='one'), - Address(email='two')]) + Address(email='two')]) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [User(name='u1', address_count=2, - addresses=[Address(email='one'), Address(email='two')])]) + eq_(sess.query(User).all(), [ + User(name='u1', address_count=2, + addresses=[Address(email='one'), Address(email='two')])]) def test_useless_declared_attr(self): class Address(Base, fixtures.ComparableEntity): @@ -1140,23 +1168,26 @@ class DeclarativeTest(DeclarativeTestBase): def address_count(cls): # this doesn't really gain us anything. but if # one is used, lets have it function as expected... - return sa.orm.column_property(sa.select([sa.func.count(Address.id)]). - where(Address.user_id == cls.id)) + return sa.orm.column_property( + sa.select([sa.func.count(Address.id)]). 
+ where(Address.user_id == cls.id)) Base.metadata.create_all() u1 = User(name='u1', addresses=[Address(email='one'), - Address(email='two')]) + Address(email='two')]) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [User(name='u1', address_count=2, - addresses=[Address(email='one'), Address(email='two')])]) + eq_(sess.query(User).all(), [ + User(name='u1', address_count=2, + addresses=[Address(email='one'), Address(email='two')])]) def test_declared_on_base_class(self): class MyBase(Base): __tablename__ = 'foo' id = Column(Integer, primary_key=True) + @declared_attr def somecol(cls): return Column(Integer) @@ -1213,18 +1244,19 @@ class DeclarativeTest(DeclarativeTestBase): adr_count = \ sa.orm.column_property( sa.select([sa.func.count(Address.id)], - Address.user_id == id).as_scalar()) + Address.user_id == id).as_scalar()) addresses = relationship(Address) Base.metadata.create_all() u1 = User(name='u1', addresses=[Address(email='one'), - Address(email='two')]) + Address(email='two')]) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [User(name='u1', adr_count=2, - addresses=[Address(email='one'), Address(email='two')])]) + eq_(sess.query(User).all(), [ + User(name='u1', adr_count=2, + addresses=[Address(email='one'), Address(email='two')])]) def test_column_properties_2(self): @@ -1248,7 +1280,7 @@ class DeclarativeTest(DeclarativeTestBase): eq_(set(User.__table__.c.keys()), set(['id', 'name'])) eq_(set(Address.__table__.c.keys()), set(['id', 'email', - 'user_id'])) + 'user_id'])) def test_deferred(self): @@ -1274,86 +1306,91 @@ class DeclarativeTest(DeclarativeTestBase): def test_composite_inline(self): class AddressComposite(fixtures.ComparableEntity): + def __init__(self, street, state): self.street = street self.state = state + def __composite_values__(self): return [self.street, self.state] class User(Base, fixtures.ComparableEntity): __tablename__ = 'user' id 
= Column(Integer, primary_key=True, - test_needs_autoincrement=True) + test_needs_autoincrement=True) address = composite(AddressComposite, - Column('street', String(50)), - Column('state', String(2)), - ) + Column('street', String(50)), + Column('state', String(2)), + ) Base.metadata.create_all() sess = Session() sess.add(User( - address=AddressComposite('123 anywhere street', - 'MD') - )) + address=AddressComposite('123 anywhere street', + 'MD') + )) sess.commit() eq_( sess.query(User).all(), [User(address=AddressComposite('123 anywhere street', - 'MD'))] + 'MD'))] ) def test_composite_separate(self): class AddressComposite(fixtures.ComparableEntity): + def __init__(self, street, state): self.street = street self.state = state + def __composite_values__(self): return [self.street, self.state] class User(Base, fixtures.ComparableEntity): __tablename__ = 'user' id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + test_needs_autoincrement=True) street = Column(String(50)) state = Column(String(2)) address = composite(AddressComposite, - street, state) + street, state) Base.metadata.create_all() sess = Session() sess.add(User( - address=AddressComposite('123 anywhere street', - 'MD') - )) + address=AddressComposite('123 anywhere street', + 'MD') + )) sess.commit() eq_( sess.query(User).all(), [User(address=AddressComposite('123 anywhere street', - 'MD'))] + 'MD'))] ) def test_mapping_to_join(self): users = Table('users', Base.metadata, - Column('id', Integer, primary_key=True) - ) + Column('id', Integer, primary_key=True) + ) addresses = Table('addresses', Base.metadata, - Column('id', Integer, primary_key=True), - Column('user_id', Integer, ForeignKey('users.id')) - ) + Column('id', Integer, primary_key=True), + Column('user_id', Integer, ForeignKey('users.id')) + ) usersaddresses = sa.join(users, addresses, users.c.id == addresses.c.user_id) + class User(Base): __table__ = usersaddresses - __table_args__ = {'primary_key':[users.c.id]} + 
__table_args__ = {'primary_key': [users.c.id]} # need to use column_property for now user_id = column_property(users.c.id, addresses.c.user_id) address_id = addresses.c.id assert User.__mapper__.get_property('user_id').columns[0] \ - is users.c.id + is users.c.id assert User.__mapper__.get_property('user_id').columns[1] \ - is addresses.c.user_id + is addresses.c.user_id def test_synonym_inline(self): @@ -1372,7 +1409,7 @@ class DeclarativeTest(DeclarativeTestBase): name = sa.orm.synonym('_name', descriptor=property(_get_name, - _set_name)) + _set_name)) Base.metadata.create_all() sess = create_session() @@ -1381,7 +1418,7 @@ class DeclarativeTest(DeclarativeTestBase): sess.add(u1) sess.flush() eq_(sess.query(User).filter(User.name == 'SOMENAME someuser' - ).one(), u1) + ).one(), u1) def test_synonym_no_descriptor(self): from sqlalchemy.orm.properties import ColumnProperty @@ -1434,7 +1471,7 @@ class DeclarativeTest(DeclarativeTestBase): sess.add(u1) sess.flush() eq_(sess.query(User).filter(User.name == 'SOMENAME someuser' - ).one(), u1) + ).one(), u1) def test_reentrant_compile_via_foreignkey(self): @@ -1465,13 +1502,14 @@ class DeclarativeTest(DeclarativeTestBase): ) Base.metadata.create_all() u1 = User(name='u1', addresses=[Address(email='one'), - Address(email='two')]) + Address(email='two')]) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [User(name='u1', - addresses=[Address(email='one'), Address(email='two')])]) + eq_(sess.query(User).all(), [ + User(name='u1', + addresses=[Address(email='one'), Address(email='two')])]) def test_relationship_reference(self): @@ -1490,21 +1528,22 @@ class DeclarativeTest(DeclarativeTestBase): test_needs_autoincrement=True) name = Column('name', String(50)) addresses = relationship('Address', backref='user', - primaryjoin=id == Address.user_id) + primaryjoin=id == Address.user_id) User.address_count = \ sa.orm.column_property(sa.select([sa.func.count(Address.id)]). 
- where(Address.user_id - == User.id).as_scalar()) + where(Address.user_id + == User.id).as_scalar()) Base.metadata.create_all() u1 = User(name='u1', addresses=[Address(email='one'), - Address(email='two')]) + Address(email='two')]) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [User(name='u1', address_count=2, - addresses=[Address(email='one'), Address(email='two')])]) + eq_(sess.query(User).all(), [ + User(name='u1', address_count=2, + addresses=[Address(email='one'), Address(email='two')])]) def test_pk_with_fk_init(self): @@ -1526,9 +1565,11 @@ class DeclarativeTest(DeclarativeTestBase): def test_with_explicit_autoloaded(self): meta = MetaData(testing.db) - t1 = Table('t1', meta, Column('id', String(50), + t1 = Table( + 't1', meta, + Column('id', String(50), primary_key=True, test_needs_autoincrement=True), - Column('data', String(50))) + Column('data', String(50))) meta.create_all() try: @@ -1541,7 +1582,7 @@ class DeclarativeTest(DeclarativeTestBase): sess.add(m) sess.flush() eq_(t1.select().execute().fetchall(), [('someid', 'somedata' - )]) + )]) finally: meta.drop_all() @@ -1584,7 +1625,7 @@ class DeclarativeTest(DeclarativeTestBase): op, other, **kw - ): + ): return op(self.upperself, other, **kw) class User(Base, fixtures.ComparableEntity): @@ -1612,7 +1653,7 @@ class DeclarativeTest(DeclarativeTestBase): eq_(rt, u1) sess.expunge_all() rt = sess.query(User).filter(User.uc_name.startswith('SOMEUSE' - )).one() + )).one() eq_(rt, u1) def test_duplicate_classes_in_base(self): @@ -1631,7 +1672,6 @@ class DeclarativeTest(DeclarativeTestBase): ) - def _produce_test(inline, stringbased): class ExplicitJoinTest(fixtures.MappedTest): @@ -1657,35 +1697,43 @@ def _produce_test(inline, stringbased): user_id = Column(Integer, ForeignKey('users.id')) if inline: if stringbased: - user = relationship('User', - primaryjoin='User.id==Address.user_id', - backref='addresses') + user = relationship( + 'User', + 
primaryjoin='User.id==Address.user_id', + backref='addresses') else: user = relationship(User, primaryjoin=User.id - == user_id, backref='addresses') + == user_id, backref='addresses') if not inline: configure_mappers() if stringbased: - Address.user = relationship('User', - primaryjoin='User.id==Address.user_id', - backref='addresses') + Address.user = relationship( + 'User', + primaryjoin='User.id==Address.user_id', + backref='addresses') else: - Address.user = relationship(User, - primaryjoin=User.id == Address.user_id, - backref='addresses') + Address.user = relationship( + User, + primaryjoin=User.id == Address.user_id, + backref='addresses') @classmethod def insert_data(cls): - params = [dict(list(zip(('id', 'name'), column_values))) - for column_values in [(7, 'jack'), (8, 'ed'), (9, - 'fred'), (10, 'chuck')]] + params = [ + dict(list(zip(('id', 'name'), column_values))) + for column_values in [ + (7, 'jack'), (8, 'ed'), + (9, 'fred'), (10, 'chuck')]] + User.__table__.insert().execute(params) - Address.__table__.insert().execute([dict(list(zip(('id', - 'user_id', 'email'), column_values))) - for column_values in [(1, 7, 'jack@bean.com'), (2, - 8, 'ed@wood.com'), (3, 8, 'ed@bettyboop.com'), (4, - 8, 'ed@lala.com'), (5, 9, 'fred@fred.com')]]) + Address.__table__.insert().execute([ + dict(list(zip(('id', 'user_id', 'email'), column_values))) + for column_values in [ + (1, 7, 'jack@bean.com'), + (2, 8, 'ed@wood.com'), + (3, 8, 'ed@bettyboop.com'), + (4, 8, 'ed@lala.com'), (5, 9, 'fred@fred.com')]]) def test_aliased_join(self): @@ -1699,13 +1747,14 @@ def _produce_test(inline, stringbased): sess = create_session() eq_(sess.query(User).join(User.addresses, - aliased=True).filter(Address.email == 'ed@wood.com' - ).filter(User.addresses.any(Address.email - == 'jack@bean.com')).all(), []) - - ExplicitJoinTest.__name__ = 'ExplicitJoinTest%s%s' % (inline - and 'Inline' or 'Separate', stringbased and 'String' - or 'Literal') + aliased=True).filter( + Address.email == 
'ed@wood.com').filter( + User.addresses.any(Address.email == 'jack@bean.com')).all(), + []) + + ExplicitJoinTest.__name__ = 'ExplicitJoinTest%s%s' % ( + inline and 'Inline' or 'Separate', + stringbased and 'String' or 'Literal') return ExplicitJoinTest for inline in True, False: @@ -1713,4 +1762,3 @@ for inline in True, False: testclass = _produce_test(inline, stringbased) exec('%s = testclass' % testclass.__name__) del testclass - diff --git a/test/ext/declarative/test_clsregistry.py b/test/ext/declarative/test_clsregistry.py index e78a1abbe..535fd00b3 100644 --- a/test/ext/declarative/test_clsregistry.py +++ b/test/ext/declarative/test_clsregistry.py @@ -5,7 +5,9 @@ from sqlalchemy import exc, MetaData from sqlalchemy.ext.declarative import clsregistry import weakref + class MockClass(object): + def __init__(self, base, name): self._decl_class_registry = base tokens = name.split(".") @@ -183,7 +185,7 @@ class ClsRegistryTest(fixtures.TestBase): f1 = MockClass(base, "foo.bar.Foo") clsregistry.add_class("Foo", f1) reg = base['_sa_module_registry'] - mod_entry = reg['foo']['bar'] + mod_entry = reg['foo']['bar'] # noqa resolver = clsregistry._resolver(f1, MockProp()) resolver = resolver("foo") assert_raises_message( @@ -232,4 +234,3 @@ class ClsRegistryTest(fixtures.TestBase): del f4 gc_collect() assert 'single' not in reg - diff --git a/test/ext/declarative/test_inheritance.py b/test/ext/declarative/test_inheritance.py index edff4421e..5a99c9c5a 100644 --- a/test/ext/declarative/test_inheritance.py +++ b/test/ext/declarative/test_inheritance.py @@ -10,12 +10,14 @@ from sqlalchemy.orm import relationship, create_session, class_mapper, \ configure_mappers, clear_mappers, \ polymorphic_union, deferred, Session from sqlalchemy.ext.declarative import declared_attr, AbstractConcreteBase, \ - ConcreteBase, has_inherited_table -from sqlalchemy.testing import fixtures + ConcreteBase, has_inherited_table +from sqlalchemy.testing import fixtures, mock Base = None + class 
DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): + def setup(self): global Base Base = decl.declarative_base(testing.db) @@ -25,6 +27,7 @@ class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): clear_mappers() Base.metadata.drop_all() + class DeclarativeInheritanceTest(DeclarativeTestBase): def test_we_must_copy_mapper_args(self): @@ -65,7 +68,6 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): assert class_mapper(Person).version_id_col == 'a' assert class_mapper(Person).include_properties == set(['id', 'a', 'b']) - def test_custom_join_condition(self): class Foo(Base): @@ -123,21 +125,23 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): Base.metadata.create_all() sess = create_session() - c1 = Company(name='MegaCorp, Inc.', - employees=[Engineer(name='dilbert', - primary_language='java'), Engineer(name='wally', - primary_language='c++'), Manager(name='dogbert', - golf_swing='fore!')]) + c1 = Company( + name='MegaCorp, Inc.', + employees=[ + Engineer(name='dilbert', primary_language='java'), + Engineer(name='wally', primary_language='c++'), + Manager(name='dogbert', golf_swing='fore!')]) + c2 = Company(name='Elbonia, Inc.', employees=[Engineer(name='vlad', - primary_language='cobol')]) + primary_language='cobol')]) sess.add(c1) sess.add(c2) sess.flush() sess.expunge_all() eq_(sess.query(Company).filter(Company.employees.of_type(Engineer). - any(Engineer.primary_language - == 'cobol')).first(), c2) + any(Engineer.primary_language + == 'cobol')).first(), c2) # ensure that the Manager mapper was compiled with the Manager id # column as higher priority. this ensures that "Manager.id" @@ -145,8 +149,8 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): # table (reversed from 0.6's behavior.) 
eq_( - Manager.id.property.columns, - [Manager.__table__.c.id, Person.__table__.c.id] + Manager.id.property.columns, + [Manager.__table__.c.id, Person.__table__.c.id] ) # assert that the "id" column is available without a second @@ -157,13 +161,13 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): def go(): assert sess.query(Manager).filter(Manager.name == 'dogbert' - ).one().id + ).one().id self.assert_sql_count(testing.db, go, 1) sess.expunge_all() def go(): assert sess.query(Person).filter(Manager.name == 'dogbert' - ).one().id + ).one().id self.assert_sql_count(testing.db, go, 1) @@ -186,7 +190,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): primary_key=True) Engineer.primary_language = Column('primary_language', - String(50)) + String(50)) Base.metadata.create_all() sess = create_session() e1 = Engineer(primary_language='java', name='dilbert') @@ -194,7 +198,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): sess.flush() sess.expunge_all() eq_(sess.query(Person).first(), - Engineer(primary_language='java', name='dilbert')) + Engineer(primary_language='java', name='dilbert')) def test_add_parentcol_after_the_fact(self): @@ -258,8 +262,8 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): sess.add(e1) sess.flush() sess.expunge_all() - eq_(sess.query(Person).first(), Admin(primary_language='java', - name='dilbert', workstation='foo')) + eq_(sess.query(Person).first(), + Admin(primary_language='java', name='dilbert', workstation='foo')) def test_subclass_mixin(self): @@ -331,26 +335,25 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): class PlanBooking(Booking): __tablename__ = 'plan_booking' id = Column(Integer, ForeignKey(Booking.id), - primary_key=True) + primary_key=True) # referencing PlanBooking.id gives us the column # on plan_booking, not booking class FeatureBooking(Booking): __tablename__ = 'feature_booking' id = Column(Integer, ForeignKey(Booking.id), - primary_key=True) + primary_key=True) plan_booking_id = 
Column(Integer, - ForeignKey(PlanBooking.id)) + ForeignKey(PlanBooking.id)) plan_booking = relationship(PlanBooking, - backref='feature_bookings') + backref='feature_bookings') assert FeatureBooking.__table__.c.plan_booking_id.\ - references(PlanBooking.__table__.c.id) + references(PlanBooking.__table__.c.id) assert FeatureBooking.__table__.c.id.\ - references(Booking.__table__.c.id) - + references(Booking.__table__.c.id) def test_single_colsonbase(self): """test single inheritance where all the columns are on the base @@ -387,23 +390,26 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): Base.metadata.create_all() sess = create_session() - c1 = Company(name='MegaCorp, Inc.', - employees=[Engineer(name='dilbert', - primary_language='java'), Engineer(name='wally', - primary_language='c++'), Manager(name='dogbert', - golf_swing='fore!')]) + c1 = Company( + name='MegaCorp, Inc.', + employees=[ + Engineer(name='dilbert', primary_language='java'), + Engineer(name='wally', primary_language='c++'), + Manager(name='dogbert', golf_swing='fore!')]) + c2 = Company(name='Elbonia, Inc.', employees=[Engineer(name='vlad', - primary_language='cobol')]) + primary_language='cobol')]) sess.add(c1) sess.add(c2) sess.flush() sess.expunge_all() eq_(sess.query(Person).filter(Engineer.primary_language - == 'cobol').first(), Engineer(name='vlad')) + == 'cobol').first(), + Engineer(name='vlad')) eq_(sess.query(Company).filter(Company.employees.of_type(Engineer). 
- any(Engineer.primary_language - == 'cobol')).first(), c2) + any(Engineer.primary_language + == 'cobol')).first(), c2) def test_single_colsonsub(self): """test single inheritance where the columns are local to their @@ -470,15 +476,17 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): sess.flush() sess.expunge_all() eq_(sess.query(Person).filter(Engineer.primary_language - == 'cobol').first(), Engineer(name='vlad')) + == 'cobol').first(), + Engineer(name='vlad')) eq_(sess.query(Company).filter(Company.employees.of_type(Engineer). - any(Engineer.primary_language - == 'cobol')).first(), c2) + any(Engineer.primary_language + == 'cobol')).first(), c2) eq_(sess.query(Engineer).filter_by(primary_language='cobol' - ).one(), Engineer(name='vlad', primary_language='cobol')) + ).one(), + Engineer(name='vlad', primary_language='cobol')) @testing.skip_if(lambda: testing.against('oracle'), - "Test has an empty insert in it at the moment") + "Test has an empty insert in it at the moment") def test_columns_single_inheritance_conflict_resolution(self): """Test that a declared_attr can return the existing column and it will be ignored. this allows conditional columns to be added. 
@@ -491,25 +499,29 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): id = Column(Integer, primary_key=True) class Engineer(Person): + """single table inheritance""" @declared_attr def target_id(cls): - return cls.__table__.c.get('target_id', - Column(Integer, ForeignKey('other.id')) - ) + return cls.__table__.c.get( + 'target_id', + Column(Integer, ForeignKey('other.id'))) + @declared_attr def target(cls): return relationship("Other") class Manager(Person): + """single table inheritance""" @declared_attr def target_id(cls): - return cls.__table__.c.get('target_id', - Column(Integer, ForeignKey('other.id')) - ) + return cls.__table__.c.get( + 'target_id', + Column(Integer, ForeignKey('other.id'))) + @declared_attr def target(cls): return relationship("Other") @@ -534,11 +546,10 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): Engineer(target=o1), Manager(target=o2), Manager(target=o1) - ]) + ]) session.commit() eq_(session.query(Engineer).first().target, o1) - def test_joined_from_single(self): class Company(Base, fixtures.ComparableEntity): @@ -595,12 +606,13 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): sess.expunge_all() eq_(sess.query(Person).with_polymorphic(Engineer). filter(Engineer.primary_language - == 'cobol').first(), Engineer(name='vlad')) + == 'cobol').first(), Engineer(name='vlad')) eq_(sess.query(Company).filter(Company.employees.of_type(Engineer). 
- any(Engineer.primary_language - == 'cobol')).first(), c2) + any(Engineer.primary_language + == 'cobol')).first(), c2) eq_(sess.query(Engineer).filter_by(primary_language='cobol' - ).one(), Engineer(name='vlad', primary_language='cobol')) + ).one(), + Engineer(name='vlad', primary_language='cobol')) def test_single_from_joined_colsonsub(self): class Person(Base, fixtures.ComparableEntity): @@ -661,7 +673,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): eq_(sess.query(Person).all(), [Person(name='ratbert')]) sess.expunge_all() person = sess.query(Person).filter(Person.name == 'ratbert' - ).one() + ).one() assert 'name' not in person.__dict__ def test_single_fksonsub(self): @@ -683,7 +695,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): __mapper_args__ = {'polymorphic_identity': 'engineer'} primary_language_id = Column(Integer, - ForeignKey('languages.id')) + ForeignKey('languages.id')) primary_language = relationship('Language') class Language(Base, fixtures.ComparableEntity): @@ -706,19 +718,19 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): sess.expunge_all() eq_(sess.query(Person).filter(Engineer.primary_language.has( Language.name - == 'cobol')).first(), Engineer(name='vlad', - primary_language=Language(name='cobol'))) + == 'cobol')).first(), + Engineer(name='vlad', primary_language=Language(name='cobol'))) eq_(sess.query(Engineer).filter(Engineer.primary_language.has( Language.name - == 'cobol')).one(), Engineer(name='vlad', - primary_language=Language(name='cobol'))) + == 'cobol')).one(), + Engineer(name='vlad', primary_language=Language(name='cobol'))) eq_(sess.query(Person).join(Engineer.primary_language).order_by( Language.name).all(), [Engineer(name='vlad', - primary_language=Language(name='cobol')), - Engineer(name='wally', primary_language=Language(name='cpp' - )), Engineer(name='dilbert', - primary_language=Language(name='java'))]) + primary_language=Language(name='cobol')), + Engineer(name='wally', 
primary_language=Language(name='cpp' + )), + Engineer(name='dilbert', primary_language=Language(name='java'))]) def test_single_three_levels(self): @@ -810,11 +822,11 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): __mapper_args__ = {'polymorphic_identity': 'engineer'} primary_language = Column('primary_language', - String(50)) + String(50)) foo_bar = Column(Integer, primary_key=True) assert_raises_message(sa.exc.ArgumentError, - 'place primary key', go) + 'place primary key', go) def test_single_no_table_args(self): @@ -832,7 +844,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): __mapper_args__ = {'polymorphic_identity': 'engineer'} primary_language = Column('primary_language', - String(50)) + String(50)) # this should be on the Person class, as this is single # table inheritance, which is why we test that this @@ -849,6 +861,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): __tablename__ = "a" id = Column(Integer, primary_key=True) a_1 = A + class A(a_1): __tablename__ = 'b' id = Column(Integer(), ForeignKey(a_1.id), primary_key=True) @@ -857,6 +870,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): class OverlapColPrecedenceTest(DeclarativeTestBase): + """test #1892 cases when declarative does column precedence.""" def _run_test(self, Engineer, e_id, p_id): @@ -895,7 +909,7 @@ class OverlapColPrecedenceTest(DeclarativeTestBase): class Engineer(Person): __tablename__ = 'engineer' id = Column("eid", Integer, ForeignKey('person.id'), - primary_key=True) + primary_key=True) self._run_test(Engineer, "eid", "id") @@ -907,15 +921,18 @@ class OverlapColPrecedenceTest(DeclarativeTestBase): class Engineer(Person): __tablename__ = 'engineer' id = Column("eid", Integer, ForeignKey('person.pid'), - primary_key=True) + primary_key=True) self._run_test(Engineer, "eid", "pid") from test.orm.test_events import _RemoveListeners + + class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): + def _roundtrip(self, Employee, Manager, 
Engineer, Boss, - polymorphic=True, explicit_type=False): + polymorphic=True, explicit_type=False): Base.metadata.create_all() sess = create_session() e1 = Engineer(name='dilbert', primary_language='java') @@ -932,7 +949,7 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): assert_raises_message( AttributeError, "does not implement attribute .?'type' " - "at the instance level.", + "at the instance level.", getattr, obj, "type" ) else: @@ -946,37 +963,38 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): if polymorphic: eq_(sess.query(Employee).order_by(Employee.name).all(), [Engineer(name='dilbert'), Manager(name='dogbert'), - Boss(name='pointy haired'), Engineer(name='vlad'), Engineer(name='wally')]) + Boss(name='pointy haired'), + Engineer(name='vlad'), Engineer(name='wally')]) else: eq_(sess.query(Engineer).order_by(Engineer.name).all(), [Engineer(name='dilbert'), Engineer(name='vlad'), - Engineer(name='wally')]) + Engineer(name='wally')]) eq_(sess.query(Manager).all(), [Manager(name='dogbert')]) eq_(sess.query(Boss).all(), [Boss(name='pointy haired')]) - def test_explicit(self): - engineers = Table('engineers', Base.metadata, Column('id', - Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('primary_language', String(50))) + engineers = Table( + 'engineers', Base.metadata, + Column('id', + Integer, primary_key=True, test_needs_autoincrement=True), + Column('name', String(50)), + Column('primary_language', String(50))) managers = Table('managers', Base.metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('golf_swing', String(50)) - ) + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('name', String(50)), + Column('golf_swing', String(50)) + ) boss = Table('boss', Base.metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), 
- Column('golf_swing', String(50)) - ) + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('name', String(50)), + Column('golf_swing', String(50)) + ) punion = polymorphic_union({ - 'engineer': engineers, - 'manager': managers, - 'boss': boss}, 'type', 'punion') + 'engineer': engineers, + 'manager': managers, + 'boss': boss}, 'type', 'punion') class Employee(Base, fixtures.ComparableEntity): @@ -1047,31 +1065,31 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): class Manager(Employee): __tablename__ = 'manager' employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + test_needs_autoincrement=True) name = Column(String(50)) golf_swing = Column(String(40)) __mapper_args__ = { - 'polymorphic_identity': 'manager', - 'concrete': True} + 'polymorphic_identity': 'manager', + 'concrete': True} class Boss(Manager): __tablename__ = 'boss' employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + test_needs_autoincrement=True) name = Column(String(50)) golf_swing = Column(String(40)) __mapper_args__ = { - 'polymorphic_identity': 'boss', - 'concrete': True} + 'polymorphic_identity': 'boss', + 'concrete': True} class Engineer(Employee): __tablename__ = 'engineer' employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + test_needs_autoincrement=True) name = Column(String(50)) primary_language = Column(String(40)) __mapper_args__ = {'polymorphic_identity': 'engineer', - 'concrete': True} + 'concrete': True} self._roundtrip(Employee, Manager, Engineer, Boss) @@ -1079,42 +1097,42 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): class Employee(ConcreteBase, Base, fixtures.ComparableEntity): __tablename__ = 'employee' employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + test_needs_autoincrement=True) name = Column(String(50)) __mapper_args__ = { - 'polymorphic_identity': 'employee', - 'concrete': True} + 
'polymorphic_identity': 'employee', + 'concrete': True} + class Manager(Employee): __tablename__ = 'manager' employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + test_needs_autoincrement=True) name = Column(String(50)) golf_swing = Column(String(40)) __mapper_args__ = { - 'polymorphic_identity': 'manager', - 'concrete': True} + 'polymorphic_identity': 'manager', + 'concrete': True} class Boss(Manager): __tablename__ = 'boss' employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + test_needs_autoincrement=True) name = Column(String(50)) golf_swing = Column(String(40)) __mapper_args__ = { - 'polymorphic_identity': 'boss', - 'concrete': True} + 'polymorphic_identity': 'boss', + 'concrete': True} class Engineer(Employee): __tablename__ = 'engineer' employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + test_needs_autoincrement=True) name = Column(String(50)) primary_language = Column(String(40)) __mapper_args__ = {'polymorphic_identity': 'engineer', - 'concrete': True} + 'concrete': True} self._roundtrip(Employee, Manager, Engineer, Boss) - def test_has_inherited_table_doesnt_consider_base(self): class A(Base): __tablename__ = 'a' @@ -1140,7 +1158,7 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): ret = { 'polymorphic_identity': 'default', 'polymorphic_on': cls.type, - } + } else: ret = {'polymorphic_identity': cls.__name__} return ret @@ -1161,7 +1179,7 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): class Manager(Employee): __tablename__ = 'manager' employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + test_needs_autoincrement=True) name = Column(String(50)) golf_swing = Column(String(40)) @@ -1170,13 +1188,13 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): return "manager" __mapper_args__ = { - 'polymorphic_identity': "manager", - 'concrete': True} + 'polymorphic_identity': "manager", + 'concrete': True} 
class Boss(Manager): __tablename__ = 'boss' employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + test_needs_autoincrement=True) name = Column(String(50)) golf_swing = Column(String(40)) @@ -1185,13 +1203,13 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): return "boss" __mapper_args__ = { - 'polymorphic_identity': "boss", - 'concrete': True} + 'polymorphic_identity': "boss", + 'concrete': True} class Engineer(Employee): __tablename__ = 'engineer' employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + test_needs_autoincrement=True) name = Column(String(50)) primary_language = Column(String(40)) @@ -1199,26 +1217,30 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): def type(self): return "engineer" __mapper_args__ = {'polymorphic_identity': "engineer", - 'concrete': True} + 'concrete': True} self._roundtrip(Employee, Manager, Engineer, Boss, explicit_type=True) -class ConcreteExtensionConfigTest(_RemoveListeners, testing.AssertsCompiledSQL, DeclarativeTestBase): + +class ConcreteExtensionConfigTest( + _RemoveListeners, testing.AssertsCompiledSQL, DeclarativeTestBase): __dialect__ = 'default' def test_classreg_setup(self): class A(Base, fixtures.ComparableEntity): __tablename__ = 'a' - id = Column(Integer, primary_key=True, test_needs_autoincrement=True) + id = Column(Integer, + primary_key=True, test_needs_autoincrement=True) data = Column(String(50)) collection = relationship("BC", primaryjoin="BC.a_id == A.id", - collection_class=set) + collection_class=set) class BC(AbstractConcreteBase, Base, fixtures.ComparableEntity): pass class B(BC): __tablename__ = 'b' - id = Column(Integer, primary_key=True, test_needs_autoincrement=True) + id = Column(Integer, + primary_key=True, test_needs_autoincrement=True) a_id = Column(Integer, ForeignKey('a.id')) data = Column(String(50)) @@ -1230,7 +1252,8 @@ class ConcreteExtensionConfigTest(_RemoveListeners, testing.AssertsCompiledSQL, class 
C(BC): __tablename__ = 'c' - id = Column(Integer, primary_key=True, test_needs_autoincrement=True) + id = Column(Integer, + primary_key=True, test_needs_autoincrement=True) a_id = Column(Integer, ForeignKey('a.id')) data = Column(String(50)) c_data = Column(String(50)) @@ -1274,8 +1297,94 @@ class ConcreteExtensionConfigTest(_RemoveListeners, testing.AssertsCompiledSQL, sess.query(A).join(A.collection), "SELECT a.id AS a_id, a.data AS a_data FROM a JOIN " "(SELECT c.id AS id, c.a_id AS a_id, c.data AS data, " - "c.c_data AS c_data, CAST(NULL AS VARCHAR(50)) AS b_data, " - "'c' AS type FROM c UNION ALL SELECT b.id AS id, b.a_id AS a_id, " - "b.data AS data, CAST(NULL AS VARCHAR(50)) AS c_data, " - "b.b_data AS b_data, 'b' AS type FROM b) AS pjoin ON pjoin.a_id = a.id" + "c.c_data AS c_data, CAST(NULL AS VARCHAR(50)) AS b_data, " + "'c' AS type FROM c UNION ALL SELECT b.id AS id, b.a_id AS a_id, " + "b.data AS data, CAST(NULL AS VARCHAR(50)) AS c_data, " + "b.b_data AS b_data, 'b' AS type FROM b) AS pjoin " + "ON pjoin.a_id = a.id" ) + + def test_prop_on_base(self): + """test [ticket:2670] """ + + counter = mock.Mock() + + class Something(Base): + __tablename__ = 'something' + id = Column(Integer, primary_key=True) + + class AbstractConcreteAbstraction(AbstractConcreteBase, Base): + id = Column(Integer, primary_key=True) + x = Column(Integer) + y = Column(Integer) + + @declared_attr + def something_id(cls): + return Column(ForeignKey(Something.id)) + + @declared_attr + def something(cls): + counter(cls, "something") + return relationship("Something") + + @declared_attr + def something_else(cls): + counter(cls, "something_else") + return relationship("Something") + + class ConcreteConcreteAbstraction(AbstractConcreteAbstraction): + __tablename__ = 'cca' + __mapper_args__ = { + 'polymorphic_identity': 'ccb', + 'concrete': True} + + # concrete is mapped, the abstract base is not (yet) + assert ConcreteConcreteAbstraction.__mapper__ + assert not 
hasattr(AbstractConcreteAbstraction, '__mapper__') + + session = Session() + self.assert_compile( + session.query(ConcreteConcreteAbstraction).filter( + ConcreteConcreteAbstraction.something.has(id=1)), + "SELECT cca.id AS cca_id, cca.x AS cca_x, cca.y AS cca_y, " + "cca.something_id AS cca_something_id FROM cca WHERE EXISTS " + "(SELECT 1 FROM something WHERE something.id = cca.something_id " + "AND something.id = :id_1)" + ) + + # now it is + assert AbstractConcreteAbstraction.__mapper__ + + self.assert_compile( + session.query(ConcreteConcreteAbstraction).filter( + ConcreteConcreteAbstraction.something_else.has(id=1)), + "SELECT cca.id AS cca_id, cca.x AS cca_x, cca.y AS cca_y, " + "cca.something_id AS cca_something_id FROM cca WHERE EXISTS " + "(SELECT 1 FROM something WHERE something.id = cca.something_id " + "AND something.id = :id_1)" + ) + + self.assert_compile( + session.query(AbstractConcreteAbstraction).filter( + AbstractConcreteAbstraction.something.has(id=1)), + "SELECT pjoin.id AS pjoin_id, pjoin.x AS pjoin_x, " + "pjoin.y AS pjoin_y, pjoin.something_id AS pjoin_something_id, " + "pjoin.type AS pjoin_type FROM " + "(SELECT cca.id AS id, cca.x AS x, cca.y AS y, " + "cca.something_id AS something_id, 'ccb' AS type FROM cca) " + "AS pjoin WHERE EXISTS (SELECT 1 FROM something " + "WHERE something.id = pjoin.something_id AND something.id = :id_1)" + ) + + self.assert_compile( + session.query(AbstractConcreteAbstraction).filter( + AbstractConcreteAbstraction.something_else.has(id=1)), + "SELECT pjoin.id AS pjoin_id, pjoin.x AS pjoin_x, " + "pjoin.y AS pjoin_y, pjoin.something_id AS pjoin_something_id, " + "pjoin.type AS pjoin_type FROM " + "(SELECT cca.id AS id, cca.x AS x, cca.y AS y, " + "cca.something_id AS something_id, 'ccb' AS type FROM cca) " + "AS pjoin WHERE EXISTS (SELECT 1 FROM something " + "WHERE something.id = pjoin.something_id AND something.id = :id_1)" + ) + diff --git a/test/ext/declarative/test_mixin.py 
b/test/ext/declarative/test_mixin.py index d3c2ff982..db86927a1 100644 --- a/test/ext/declarative/test_mixin.py +++ b/test/ext/declarative/test_mixin.py @@ -3,19 +3,21 @@ from sqlalchemy.testing import eq_, assert_raises, \ from sqlalchemy.ext import declarative as decl import sqlalchemy as sa from sqlalchemy import testing -from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy import Integer, String, ForeignKey, select, func from sqlalchemy.testing.schema import Table, Column from sqlalchemy.orm import relationship, create_session, class_mapper, \ configure_mappers, clear_mappers, \ - deferred, column_property, \ - Session + deferred, column_property, Session, base as orm_base from sqlalchemy.util import classproperty from sqlalchemy.ext.declarative import declared_attr -from sqlalchemy.testing import fixtures +from sqlalchemy.testing import fixtures, mock +from sqlalchemy.testing.util import gc_collect Base = None + class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): + def setup(self): global Base Base = decl.declarative_base(testing.db) @@ -25,6 +27,7 @@ class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): clear_mappers() Base.metadata.drop_all() + class DeclarativeMixinTest(DeclarativeTestBase): def test_simple(self): @@ -157,6 +160,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): def test_table_name_inherited(self): class MyMixin: + @declared_attr def __tablename__(cls): return cls.__name__.lower() @@ -169,6 +173,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): def test_classproperty_still_works(self): class MyMixin(object): + @classproperty def __tablename__(cls): return cls.__name__.lower() @@ -182,6 +187,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): def test_table_name_not_inherited(self): class MyMixin: + @declared_attr def __tablename__(cls): return cls.__name__.lower() @@ -195,11 +201,13 @@ class DeclarativeMixinTest(DeclarativeTestBase): def 
test_table_name_inheritance_order(self): class MyMixin1: + @declared_attr def __tablename__(cls): return cls.__name__.lower() + '1' class MyMixin2: + @declared_attr def __tablename__(cls): return cls.__name__.lower() + '2' @@ -212,6 +220,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): def test_table_name_dependent_on_subclass(self): class MyHistoryMixin: + @declared_attr def __tablename__(cls): return cls.parent_name + '_changelog' @@ -236,6 +245,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): def test_table_args_inherited_descriptor(self): class MyMixin: + @declared_attr def __table_args__(cls): return {'info': cls.__name__} @@ -289,7 +299,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): assert Specific.bar.prop is General.bar.prop @testing.skip_if(lambda: testing.against('oracle'), - "Test has an empty insert in it at the moment") + "Test has an empty insert in it at the moment") def test_columns_single_inheritance_conflict_resolution(self): """Test that a declared_attr can return the existing column and it will be ignored. this allows conditional columns to be added. 
@@ -302,20 +312,24 @@ class DeclarativeMixinTest(DeclarativeTestBase): id = Column(Integer, primary_key=True) class Mixin(object): + @declared_attr def target_id(cls): - return cls.__table__.c.get('target_id', - Column(Integer, ForeignKey('other.id')) - ) + return cls.__table__.c.get( + 'target_id', + Column(Integer, ForeignKey('other.id')) + ) @declared_attr def target(cls): return relationship("Other") class Engineer(Mixin, Person): + """single table inheritance""" class Manager(Mixin, Person): + """single table inheritance""" class Other(Base): @@ -338,11 +352,10 @@ class DeclarativeMixinTest(DeclarativeTestBase): Engineer(target=o1), Manager(target=o2), Manager(target=o1) - ]) + ]) session.commit() eq_(session.query(Engineer).first().target, o1) - def test_columns_joined_table_inheritance(self): """Test a column on a mixin with an alternate attribute name, mapped to a superclass and joined-table inheritance subclass. @@ -428,6 +441,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): def test_mapper_args_declared_attr(self): class ComputedMapperArgs: + @declared_attr def __mapper_args__(cls): if cls.__name__ == 'Person': @@ -454,6 +468,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): # ComputedMapperArgs on both classes for no apparent reason. 
class ComputedMapperArgs: + @declared_attr def __mapper_args__(cls): if cls.__name__ == 'Person': @@ -612,7 +627,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): @declared_attr def __table_args__(cls): - return {'mysql_engine':'InnoDB'} + return {'mysql_engine': 'InnoDB'} @declared_attr def __mapper_args__(cls): @@ -640,13 +655,14 @@ class DeclarativeMixinTest(DeclarativeTestBase): """test the @declared_attr approach from a custom base.""" class Base(object): + @declared_attr def __tablename__(cls): return cls.__name__.lower() @declared_attr def __table_args__(cls): - return {'mysql_engine':'InnoDB'} + return {'mysql_engine': 'InnoDB'} @declared_attr def id(self): @@ -714,7 +730,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): eq_(Generic.__table__.name, 'generic') eq_(Specific.__table__.name, 'specific') eq_(list(Generic.__table__.c.keys()), ['timestamp', 'id', - 'python_type']) + 'python_type']) eq_(list(Specific.__table__.c.keys()), ['id']) eq_(Generic.__table__.kwargs, {'mysql_engine': 'InnoDB'}) eq_(Specific.__table__.kwargs, {'mysql_engine': 'InnoDB'}) @@ -749,7 +765,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): eq_(BaseType.__table__.name, 'basetype') eq_(list(BaseType.__table__.c.keys()), ['timestamp', 'type', 'id', - 'value']) + 'value']) eq_(BaseType.__table__.kwargs, {'mysql_engine': 'InnoDB'}) assert Single.__table__ is BaseType.__table__ eq_(Joined.__table__.name, 'joined') @@ -851,7 +867,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): @declared_attr def __tablename__(cls): if decl.has_inherited_table(cls) and TableNameMixin \ - not in cls.__bases__: + not in cls.__bases__: return None return cls.__name__.lower() @@ -900,9 +916,9 @@ class DeclarativeMixinTest(DeclarativeTestBase): class Model(Base, ColumnMixin): - __table__ = Table('foo', Base.metadata, Column('data', - Integer), Column('id', Integer, - primary_key=True)) + __table__ = Table('foo', Base.metadata, + Column('data', Integer), + Column('id', Integer, 
primary_key=True)) model_col = Model.__table__.c.data mixin_col = ColumnMixin.data @@ -920,8 +936,8 @@ class DeclarativeMixinTest(DeclarativeTestBase): class Model(Base, ColumnMixin): __table__ = Table('foo', Base.metadata, - Column('data',Integer), - Column('id', Integer,primary_key=True)) + Column('data', Integer), + Column('id', Integer, primary_key=True)) foo = relationship("Dest") assert_raises_message(sa.exc.ArgumentError, @@ -942,9 +958,9 @@ class DeclarativeMixinTest(DeclarativeTestBase): class Model(Base, ColumnMixin): __table__ = Table('foo', Base.metadata, - Column('data',Integer), - Column('tada', Integer), - Column('id', Integer,primary_key=True)) + Column('data', Integer), + Column('tada', Integer), + Column('id', Integer, primary_key=True)) foo = relationship("Dest") assert_raises_message(sa.exc.ArgumentError, @@ -959,9 +975,9 @@ class DeclarativeMixinTest(DeclarativeTestBase): class Model(Base, ColumnMixin): - __table__ = Table('foo', Base.metadata, Column('data', - Integer), Column('id', Integer, - primary_key=True)) + __table__ = Table('foo', Base.metadata, + Column('data', Integer), + Column('id', Integer, primary_key=True)) model_col = Model.__table__.c.data mixin_col = ColumnMixin.data @@ -987,10 +1003,11 @@ class DeclarativeMixinTest(DeclarativeTestBase): __tablename__ = 'model' eq_(list(Model.__table__.c.keys()), ['col1', 'col3', 'col2', 'col4', - 'id']) + 'id']) def test_honor_class_mro_one(self): class HasXMixin(object): + @declared_attr def x(self): return Column(Integer) @@ -1007,6 +1024,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): def test_honor_class_mro_two(self): class HasXMixin(object): + @declared_attr def x(self): return Column(Integer) @@ -1014,6 +1032,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): class Parent(HasXMixin, Base): __tablename__ = 'parent' id = Column(Integer, primary_key=True) + def x(self): return "hi" @@ -1025,6 +1044,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): def 
test_arbitrary_attrs_one(self): class HasMixin(object): + @declared_attr def some_attr(cls): return cls.__name__ + "SOME ATTR" @@ -1043,8 +1063,9 @@ class DeclarativeMixinTest(DeclarativeTestBase): __tablename__ = 'filter_a' id = Column(Integer(), primary_key=True) parent_id = Column(Integer(), - ForeignKey('type_a.id')) + ForeignKey('type_a.id')) filter = Column(String()) + def __init__(self, filter_, **kw): self.filter = filter_ @@ -1052,16 +1073,18 @@ class DeclarativeMixinTest(DeclarativeTestBase): __tablename__ = 'filter_b' id = Column(Integer(), primary_key=True) parent_id = Column(Integer(), - ForeignKey('type_b.id')) + ForeignKey('type_b.id')) filter = Column(String()) + def __init__(self, filter_, **kw): self.filter = filter_ class FilterMixin(object): + @declared_attr def _filters(cls): return relationship(cls.filter_class, - cascade='all,delete,delete-orphan') + cascade='all,delete,delete-orphan') @declared_attr def filters(cls): @@ -1080,6 +1103,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): TypeA(filters=['foo']) TypeB(filters=['foo']) + class DeclarativeMixinPropertyTest(DeclarativeTestBase): def test_column_property(self): @@ -1118,9 +1142,9 @@ class DeclarativeMixinPropertyTest(DeclarativeTestBase): sess.add_all([m1, m2]) sess.flush() eq_(sess.query(MyModel).filter(MyModel.prop_hoho == 'foo' - ).one(), m1) + ).one(), m1) eq_(sess.query(MyOtherModel).filter(MyOtherModel.prop_hoho - == 'bar').one(), m2) + == 'bar').one(), m2) def test_doc(self): """test documentation transfer. 
@@ -1198,7 +1222,6 @@ class DeclarativeMixinPropertyTest(DeclarativeTestBase): ModelTwo.__table__.c.version_id ) - def test_deferred(self): class MyMixin(object): @@ -1235,8 +1258,8 @@ class DeclarativeMixinPropertyTest(DeclarativeTestBase): @declared_attr def target(cls): return relationship('Target', - primaryjoin='Target.id==%s.target_id' - % cls.__name__) + primaryjoin='Target.id==%s.target_id' + % cls.__name__) else: @declared_attr @@ -1279,7 +1302,199 @@ class DeclarativeMixinPropertyTest(DeclarativeTestBase): self._test_relationship(True) +class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL): + __dialect__ = 'default' + + def test_singleton_behavior_within_decl(self): + counter = mock.Mock() + + class Mixin(object): + @declared_attr + def my_prop(cls): + counter(cls) + return Column('x', Integer) + + class A(Base, Mixin): + __tablename__ = 'a' + id = Column(Integer, primary_key=True) + + @declared_attr + def my_other_prop(cls): + return column_property(cls.my_prop + 5) + + eq_(counter.mock_calls, [mock.call(A)]) + + class B(Base, Mixin): + __tablename__ = 'b' + id = Column(Integer, primary_key=True) + + @declared_attr + def my_other_prop(cls): + return column_property(cls.my_prop + 5) + + eq_( + counter.mock_calls, + [mock.call(A), mock.call(B)]) + + # this is why we need singleton-per-class behavior. We get + # an un-bound "x" column otherwise here, because my_prop() generates + # multiple columns. 
+ a_col = A.my_other_prop.__clause_element__().element.left + b_col = B.my_other_prop.__clause_element__().element.left + is_(a_col.table, A.__table__) + is_(b_col.table, B.__table__) + is_(a_col, A.__table__.c.x) + is_(b_col, B.__table__.c.x) + + s = Session() + self.assert_compile( + s.query(A), + "SELECT a.x AS a_x, a.x + :x_1 AS anon_1, a.id AS a_id FROM a" + ) + self.assert_compile( + s.query(B), + "SELECT b.x AS b_x, b.x + :x_1 AS anon_1, b.id AS b_id FROM b" + ) + + + def test_singleton_gc(self): + counter = mock.Mock() + + class Mixin(object): + @declared_attr + def my_prop(cls): + counter(cls.__name__) + return Column('x', Integer) + + class A(Base, Mixin): + __tablename__ = 'b' + id = Column(Integer, primary_key=True) + + @declared_attr + def my_other_prop(cls): + return column_property(cls.my_prop + 5) + + eq_(counter.mock_calls, [mock.call("A")]) + del A + gc_collect() + assert "A" not in Base._decl_class_registry + + def test_can_we_access_the_mixin_straight(self): + class Mixin(object): + @declared_attr + def my_prop(cls): + return Column('x', Integer) + + assert_raises_message( + sa.exc.SAWarning, + "Unmanaged access of declarative attribute my_prop " + "from non-mapped class Mixin", + getattr, Mixin, "my_prop" + ) + + def test_property_noncascade(self): + counter = mock.Mock() + + class Mixin(object): + @declared_attr + def my_prop(cls): + counter(cls) + return column_property(cls.x + 2) + + class A(Base, Mixin): + __tablename__ = 'a' + + id = Column(Integer, primary_key=True) + x = Column(Integer) + + class B(A): + pass + + eq_(counter.mock_calls, [mock.call(A)]) + + def test_property_cascade(self): + counter = mock.Mock() + + class Mixin(object): + @declared_attr.cascading + def my_prop(cls): + counter(cls) + return column_property(cls.x + 2) + + class A(Base, Mixin): + __tablename__ = 'a' + + id = Column(Integer, primary_key=True) + x = Column(Integer) + + class B(A): + pass + + eq_(counter.mock_calls, [mock.call(A), mock.call(B)]) + + def 
test_column_pre_map(self): + counter = mock.Mock() + + class Mixin(object): + @declared_attr + def my_col(cls): + counter(cls) + assert not orm_base._mapper_or_none(cls) + return Column('x', Integer) + + class A(Base, Mixin): + __tablename__ = 'a' + + id = Column(Integer, primary_key=True) + + eq_(counter.mock_calls, [mock.call(A)]) + + def test_mixin_attr_refers_to_column_copies(self): + # this @declared_attr can refer to User.id + # freely because we now do the "copy column" operation + # before the declared_attr is invoked. + + counter = mock.Mock() + + class HasAddressCount(object): + id = Column(Integer, primary_key=True) + + @declared_attr + def address_count(cls): + counter(cls.id) + return column_property( + select([func.count(Address.id)]). + where(Address.user_id == cls.id). + as_scalar() + ) + + class Address(Base): + __tablename__ = 'address' + id = Column(Integer, primary_key=True) + user_id = Column(ForeignKey('user.id')) + + class User(Base, HasAddressCount): + __tablename__ = 'user' + + eq_( + counter.mock_calls, + [mock.call(User.id)] + ) + + sess = Session() + self.assert_compile( + sess.query(User).having(User.address_count > 5), + 'SELECT (SELECT count(address.id) AS ' + 'count_1 FROM address WHERE address.user_id = "user".id) ' + 'AS anon_1, "user".id AS user_id FROM "user" ' + 'HAVING (SELECT count(address.id) AS ' + 'count_1 FROM address WHERE address.user_id = "user".id) ' + '> :param_1' + ) + + class AbstractTest(DeclarativeTestBase): + def test_abstract_boolean(self): class A(Base): diff --git a/test/ext/declarative/test_reflection.py b/test/ext/declarative/test_reflection.py index f4bda6995..c7f7bc05d 100644 --- a/test/ext/declarative/test_reflection.py +++ b/test/ext/declarative/test_reflection.py @@ -1,7 +1,7 @@ from sqlalchemy.testing import eq_, assert_raises from sqlalchemy.ext import declarative as decl from sqlalchemy import testing -from sqlalchemy import MetaData, Integer, String, ForeignKey +from sqlalchemy import Integer, 
String, ForeignKey from sqlalchemy.testing.schema import Table, Column from sqlalchemy.orm import relationship, create_session, \ clear_mappers, \ @@ -10,6 +10,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing.util import gc_collect from sqlalchemy.ext.declarative.base import _DeferredMapperConfig + class DeclarativeReflectionBase(fixtures.TablesTest): __requires__ = 'reflectable_autoincrement', @@ -21,13 +22,14 @@ class DeclarativeReflectionBase(fixtures.TablesTest): super(DeclarativeReflectionBase, self).teardown() clear_mappers() + class DeclarativeReflectionTest(DeclarativeReflectionBase): @classmethod def define_tables(cls, metadata): Table('users', metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), + Column('id', Integer, + primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), test_needs_fk=True) Table( 'addresses', @@ -37,7 +39,7 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase): Column('email', String(50)), Column('user_id', Integer, ForeignKey('users.id')), test_needs_fk=True, - ) + ) Table( 'imhandles', metadata, @@ -47,8 +49,7 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase): Column('network', String(50)), Column('handle', String(50)), test_needs_fk=True, - ) - + ) def test_basic(self): class User(Base, fixtures.ComparableEntity): @@ -69,13 +70,14 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase): test_needs_autoincrement=True) u1 = User(name='u1', addresses=[Address(email='one'), - Address(email='two')]) + Address(email='two')]) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [User(name='u1', - addresses=[Address(email='one'), Address(email='two')])]) + eq_(sess.query(User).all(), [ + User(name='u1', + addresses=[Address(email='one'), Address(email='two')])]) a1 = sess.query(Address).filter(Address.email == 'two').one() eq_(a1, Address(email='two')) eq_(a1.user, User(name='u1')) @@ -100,13 
+102,14 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase): test_needs_autoincrement=True) u1 = User(nom='u1', addresses=[Address(email='one'), - Address(email='two')]) + Address(email='two')]) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [User(nom='u1', - addresses=[Address(email='one'), Address(email='two')])]) + eq_(sess.query(User).all(), [ + User(nom='u1', + addresses=[Address(email='one'), Address(email='two')])]) a1 = sess.query(Address).filter(Address.email == 'two').one() eq_(a1, Address(email='two')) eq_(a1.user, User(nom='u1')) @@ -131,61 +134,66 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase): test_needs_autoincrement=True) handles = relationship('IMHandle', backref='user') - u1 = User(name='u1', handles=[IMHandle(network='blabber', - handle='foo'), IMHandle(network='lol', handle='zomg' - )]) + u1 = User(name='u1', handles=[ + IMHandle(network='blabber', handle='foo'), + IMHandle(network='lol', handle='zomg')]) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [User(name='u1', - handles=[IMHandle(network='blabber', handle='foo'), - IMHandle(network='lol', handle='zomg')])]) + eq_(sess.query(User).all(), [ + User(name='u1', handles=[IMHandle(network='blabber', handle='foo'), + IMHandle(network='lol', handle='zomg')])]) a1 = sess.query(IMHandle).filter(IMHandle.handle == 'zomg' - ).one() + ).one() eq_(a1, IMHandle(network='lol', handle='zomg')) eq_(a1.user, User(name='u1')) + class DeferredReflectBase(DeclarativeReflectionBase): + def teardown(self): super(DeferredReflectBase, self).teardown() _DeferredMapperConfig._configs.clear() Base = None + class DeferredReflectPKFKTest(DeferredReflectBase): + @classmethod def define_tables(cls, metadata): Table("a", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - ) + Column('id', Integer, + primary_key=True, test_needs_autoincrement=True), + ) 
Table("b", metadata, - Column('id', Integer, - ForeignKey('a.id'), - primary_key=True), - Column('x', Integer, primary_key=True) - ) + Column('id', Integer, + ForeignKey('a.id'), + primary_key=True), + Column('x', Integer, primary_key=True) + ) def test_pk_fk(self): class B(decl.DeferredReflection, fixtures.ComparableEntity, - Base): + Base): __tablename__ = 'b' a = relationship("A") class A(decl.DeferredReflection, fixtures.ComparableEntity, - Base): + Base): __tablename__ = 'a' decl.DeferredReflection.prepare(testing.db) + class DeferredReflectionTest(DeferredReflectBase): @classmethod def define_tables(cls, metadata): Table('users', metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), + Column('id', Integer, + primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), test_needs_fk=True) Table( 'addresses', @@ -195,7 +203,7 @@ class DeferredReflectionTest(DeferredReflectBase): Column('email', String(50)), Column('user_id', Integer, ForeignKey('users.id')), test_needs_fk=True, - ) + ) def _roundtrip(self): @@ -203,25 +211,26 @@ class DeferredReflectionTest(DeferredReflectBase): Address = Base._decl_class_registry['Address'] u1 = User(name='u1', addresses=[Address(email='one'), - Address(email='two')]) + Address(email='two')]) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [User(name='u1', - addresses=[Address(email='one'), Address(email='two')])]) + eq_(sess.query(User).all(), [ + User(name='u1', + addresses=[Address(email='one'), Address(email='two')])]) a1 = sess.query(Address).filter(Address.email == 'two').one() eq_(a1, Address(email='two')) eq_(a1.user, User(name='u1')) def test_basic_deferred(self): class User(decl.DeferredReflection, fixtures.ComparableEntity, - Base): + Base): __tablename__ = 'users' addresses = relationship("Address", backref="user") class Address(decl.DeferredReflection, fixtures.ComparableEntity, - Base): + Base): __tablename__ = 
'addresses' decl.DeferredReflection.prepare(testing.db) @@ -249,12 +258,12 @@ class DeferredReflectionTest(DeferredReflectBase): def test_redefine_fk_double(self): class User(decl.DeferredReflection, fixtures.ComparableEntity, - Base): + Base): __tablename__ = 'users' addresses = relationship("Address", backref="user") class Address(decl.DeferredReflection, fixtures.ComparableEntity, - Base): + Base): __tablename__ = 'addresses' user_id = Column(Integer, ForeignKey('users.id')) @@ -262,10 +271,11 @@ class DeferredReflectionTest(DeferredReflectBase): self._roundtrip() def test_mapper_args_deferred(self): - """test that __mapper_args__ is not called until *after* table reflection""" + """test that __mapper_args__ is not called until *after* + table reflection""" class User(decl.DeferredReflection, fixtures.ComparableEntity, - Base): + Base): __tablename__ = 'users' @decl.declared_attr @@ -296,10 +306,11 @@ class DeferredReflectionTest(DeferredReflectBase): @testing.requires.predictable_gc def test_cls_not_strong_ref(self): class User(decl.DeferredReflection, fixtures.ComparableEntity, - Base): + Base): __tablename__ = 'users' + class Address(decl.DeferredReflection, fixtures.ComparableEntity, - Base): + Base): __tablename__ = 'addresses' eq_(len(_DeferredMapperConfig._configs), 2) del Address @@ -308,26 +319,28 @@ class DeferredReflectionTest(DeferredReflectBase): decl.DeferredReflection.prepare(testing.db) assert not _DeferredMapperConfig._configs + class DeferredSecondaryReflectionTest(DeferredReflectBase): + @classmethod def define_tables(cls, metadata): Table('users', metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), + Column('id', Integer, + primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), test_needs_fk=True) Table('user_items', metadata, - Column('user_id', ForeignKey('users.id'), primary_key=True), - Column('item_id', ForeignKey('items.id'), primary_key=True), - test_needs_fk=True - ) + 
Column('user_id', ForeignKey('users.id'), primary_key=True), + Column('item_id', ForeignKey('items.id'), primary_key=True), + test_needs_fk=True + ) Table('items', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - test_needs_fk=True - ) + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('name', String(50)), + test_needs_fk=True + ) def _roundtrip(self): @@ -340,8 +353,8 @@ class DeferredSecondaryReflectionTest(DeferredReflectBase): sess.add(u1) sess.commit() - eq_(sess.query(User).all(), [User(name='u1', - items=[Item(name='i1'), Item(name='i2')])]) + eq_(sess.query(User).all(), [ + User(name='u1', items=[Item(name='i1'), Item(name='i2')])]) def test_string_resolution(self): class User(decl.DeferredReflection, fixtures.ComparableEntity, Base): @@ -359,7 +372,8 @@ class DeferredSecondaryReflectionTest(DeferredReflectBase): class User(decl.DeferredReflection, fixtures.ComparableEntity, Base): __tablename__ = 'users' - items = relationship("Item", secondary=Table("user_items", Base.metadata)) + items = relationship("Item", + secondary=Table("user_items", Base.metadata)) class Item(decl.DeferredReflection, fixtures.ComparableEntity, Base): __tablename__ = 'items' @@ -367,7 +381,9 @@ class DeferredSecondaryReflectionTest(DeferredReflectBase): decl.DeferredReflection.prepare(testing.db) self._roundtrip() + class DeferredInhReflectBase(DeferredReflectBase): + def _roundtrip(self): Foo = Base._decl_class_registry['Foo'] Bar = Base._decl_class_registry['Bar'] @@ -392,24 +408,25 @@ class DeferredInhReflectBase(DeferredReflectBase): ] ) + class DeferredSingleInhReflectionTest(DeferredInhReflectBase): @classmethod def define_tables(cls, metadata): Table("foo", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('type', String(32)), - Column('data', String(30)), - Column('bar_data', String(30)) - ) + Column('id', Integer, 
primary_key=True, + test_needs_autoincrement=True), + Column('type', String(32)), + Column('data', String(30)), + Column('bar_data', String(30)) + ) def test_basic(self): class Foo(decl.DeferredReflection, fixtures.ComparableEntity, - Base): + Base): __tablename__ = 'foo' __mapper_args__ = {"polymorphic_on": "type", - "polymorphic_identity": "foo"} + "polymorphic_identity": "foo"} class Bar(Foo): __mapper_args__ = {"polymorphic_identity": "bar"} @@ -419,10 +436,10 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase): def test_add_subclass_column(self): class Foo(decl.DeferredReflection, fixtures.ComparableEntity, - Base): + Base): __tablename__ = 'foo' __mapper_args__ = {"polymorphic_on": "type", - "polymorphic_identity": "foo"} + "polymorphic_identity": "foo"} class Bar(Foo): __mapper_args__ = {"polymorphic_identity": "bar"} @@ -433,10 +450,10 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase): def test_add_pk_column(self): class Foo(decl.DeferredReflection, fixtures.ComparableEntity, - Base): + Base): __tablename__ = 'foo' __mapper_args__ = {"polymorphic_on": "type", - "polymorphic_identity": "foo"} + "polymorphic_identity": "foo"} id = Column(Integer, primary_key=True) class Bar(Foo): @@ -445,28 +462,30 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase): decl.DeferredReflection.prepare(testing.db) self._roundtrip() + class DeferredJoinedInhReflectionTest(DeferredInhReflectBase): + @classmethod def define_tables(cls, metadata): Table("foo", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('type', String(32)), - Column('data', String(30)), - test_needs_fk=True, - ) + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('type', String(32)), + Column('data', String(30)), + test_needs_fk=True, + ) Table('bar', metadata, - Column('id', Integer, ForeignKey('foo.id'), primary_key=True), - Column('bar_data', String(30)), - test_needs_fk=True, - ) + 
Column('id', Integer, ForeignKey('foo.id'), primary_key=True), + Column('bar_data', String(30)), + test_needs_fk=True, + ) def test_basic(self): class Foo(decl.DeferredReflection, fixtures.ComparableEntity, - Base): + Base): __tablename__ = 'foo' __mapper_args__ = {"polymorphic_on": "type", - "polymorphic_identity": "foo"} + "polymorphic_identity": "foo"} class Bar(Foo): __tablename__ = 'bar' @@ -477,10 +496,10 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase): def test_add_subclass_column(self): class Foo(decl.DeferredReflection, fixtures.ComparableEntity, - Base): + Base): __tablename__ = 'foo' __mapper_args__ = {"polymorphic_on": "type", - "polymorphic_identity": "foo"} + "polymorphic_identity": "foo"} class Bar(Foo): __tablename__ = 'bar' @@ -492,10 +511,10 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase): def test_add_pk_column(self): class Foo(decl.DeferredReflection, fixtures.ComparableEntity, - Base): + Base): __tablename__ = 'foo' __mapper_args__ = {"polymorphic_on": "type", - "polymorphic_identity": "foo"} + "polymorphic_identity": "foo"} id = Column(Integer, primary_key=True) class Bar(Foo): @@ -507,10 +526,10 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase): def test_add_fk_pk_column(self): class Foo(decl.DeferredReflection, fixtures.ComparableEntity, - Base): + Base): __tablename__ = 'foo' __mapper_args__ = {"polymorphic_on": "type", - "polymorphic_identity": "foo"} + "polymorphic_identity": "foo"} class Bar(Foo): __tablename__ = 'bar' diff --git a/test/ext/test_automap.py b/test/ext/test_automap.py index f24164cb7..0a57b9caa 100644 --- a/test/ext/test_automap.py +++ b/test/ext/test_automap.py @@ -1,13 +1,14 @@ -from sqlalchemy.testing import fixtures, eq_ +from sqlalchemy.testing import fixtures from ..orm._fixtures import FixtureTest from sqlalchemy.ext.automap import automap_base -from sqlalchemy.orm import relationship, interfaces, backref +from sqlalchemy.orm import relationship, interfaces, 
configure_mappers from sqlalchemy.ext.automap import generate_relationship -from sqlalchemy.testing.mock import Mock, call +from sqlalchemy.testing.mock import Mock from sqlalchemy import String, Integer, ForeignKey from sqlalchemy import testing from sqlalchemy.testing.schema import Table, Column + class AutomapTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): @@ -27,6 +28,7 @@ class AutomapTest(fixtures.MappedTest): def test_relationship_explicit_override_o2m(self): Base = automap_base(metadata=self.metadata) prop = relationship("addresses", collection_class=set) + class User(Base): __tablename__ = 'users' @@ -44,6 +46,7 @@ class AutomapTest(fixtures.MappedTest): Base = automap_base(metadata=self.metadata) prop = relationship("users") + class Address(Base): __tablename__ = 'addresses' @@ -57,7 +60,6 @@ class AutomapTest(fixtures.MappedTest): u1 = User(name='u1', address_collection=[a1]) assert a1.users is u1 - def test_relationship_self_referential(self): Base = automap_base(metadata=self.metadata) Base.prepare() @@ -75,17 +77,19 @@ class AutomapTest(fixtures.MappedTest): def classname_for_table(base, tablename, table): return str("cls_" + tablename) - def name_for_scalar_relationship(base, local_cls, referred_cls, constraint): + def name_for_scalar_relationship( + base, local_cls, referred_cls, constraint): return "scalar_" + referred_cls.__name__ - def name_for_collection_relationship(base, local_cls, referred_cls, constraint): + def name_for_collection_relationship( + base, local_cls, referred_cls, constraint): return "coll_" + referred_cls.__name__ Base.prepare( - classname_for_table=classname_for_table, - name_for_scalar_relationship=name_for_scalar_relationship, - name_for_collection_relationship=name_for_collection_relationship - ) + classname_for_table=classname_for_table, + name_for_scalar_relationship=name_for_scalar_relationship, + name_for_collection_relationship=name_for_collection_relationship + ) User = 
Base.classes.cls_users Address = Base.classes.cls_addresses @@ -113,9 +117,10 @@ class AutomapTest(fixtures.MappedTest): class Order(Base): __tablename__ = 'orders' - items_collection = relationship("items", - secondary="order_items", - collection_class=set) + items_collection = relationship( + "items", + secondary="order_items", + collection_class=set) Base.prepare() Item = Base.classes['items'] @@ -133,41 +138,115 @@ class AutomapTest(fixtures.MappedTest): Base = automap_base(metadata=self.metadata) mock = Mock() - def _gen_relationship(base, direction, return_fn, attrname, - local_cls, referred_cls, **kw): + + def _gen_relationship( + base, direction, return_fn, attrname, + local_cls, referred_cls, **kw): mock(base, direction, attrname) - return generate_relationship(base, direction, return_fn, - attrname, local_cls, referred_cls, **kw) + return generate_relationship( + base, direction, return_fn, + attrname, local_cls, referred_cls, **kw) Base.prepare(generate_relationship=_gen_relationship) assert set(tuple(c[1]) for c in mock.mock_calls).issuperset([ - (Base, interfaces.MANYTOONE, "nodes"), - (Base, interfaces.MANYTOMANY, "keywords_collection"), - (Base, interfaces.MANYTOMANY, "items_collection"), - (Base, interfaces.MANYTOONE, "users"), - (Base, interfaces.ONETOMANY, "addresses_collection"), + (Base, interfaces.MANYTOONE, "nodes"), + (Base, interfaces.MANYTOMANY, "keywords_collection"), + (Base, interfaces.MANYTOMANY, "items_collection"), + (Base, interfaces.MANYTOONE, "users"), + (Base, interfaces.ONETOMANY, "addresses_collection"), ]) +class CascadeTest(fixtures.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table( + "a", metadata, + Column('id', Integer, primary_key=True) + ) + Table( + "b", metadata, + Column('id', Integer, primary_key=True), + Column('aid', ForeignKey('a.id'), nullable=True) + ) + Table( + "c", metadata, + Column('id', Integer, primary_key=True), + Column('aid', ForeignKey('a.id'), nullable=False) + ) + Table( + "d", 
metadata, + Column('id', Integer, primary_key=True), + Column( + 'aid', ForeignKey('a.id', ondelete="cascade"), nullable=False) + ) + Table( + "e", metadata, + Column('id', Integer, primary_key=True), + Column( + 'aid', ForeignKey('a.id', ondelete="set null"), + nullable=True) + ) + + def test_o2m_relationship_cascade(self): + Base = automap_base(metadata=self.metadata) + Base.prepare() + + configure_mappers() + + b_rel = Base.classes.a.b_collection + assert not b_rel.property.cascade.delete + assert not b_rel.property.cascade.delete_orphan + assert not b_rel.property.passive_deletes + + assert b_rel.property.cascade.save_update + + c_rel = Base.classes.a.c_collection + assert c_rel.property.cascade.delete + assert c_rel.property.cascade.delete_orphan + assert not c_rel.property.passive_deletes + + assert c_rel.property.cascade.save_update + + d_rel = Base.classes.a.d_collection + assert d_rel.property.cascade.delete + assert d_rel.property.cascade.delete_orphan + assert d_rel.property.passive_deletes + + assert d_rel.property.cascade.save_update + + e_rel = Base.classes.a.e_collection + assert not e_rel.property.cascade.delete + assert not e_rel.property.cascade.delete_orphan + assert e_rel.property.passive_deletes + + assert e_rel.property.cascade.save_update + + class AutomapInhTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('single', metadata, - Column('id', Integer, primary_key=True), - Column('type', String(10)), - test_needs_fk=True - ) - - Table('joined_base', metadata, - Column('id', Integer, primary_key=True), - Column('type', String(10)), - test_needs_fk=True - ) - - Table('joined_inh', metadata, - Column('id', Integer, ForeignKey('joined_base.id'), primary_key=True), - test_needs_fk=True - ) + Table( + 'single', metadata, + Column('id', Integer, primary_key=True), + Column('type', String(10)), + test_needs_fk=True + ) + + Table( + 'joined_base', metadata, + Column('id', Integer, primary_key=True), + Column('type', 
String(10)), + test_needs_fk=True + ) + + Table( + 'joined_inh', metadata, + Column( + 'id', Integer, + ForeignKey('joined_base.id'), primary_key=True), + test_needs_fk=True + ) FixtureTest.define_tables(metadata) @@ -179,7 +258,8 @@ class AutomapInhTest(fixtures.MappedTest): type = Column(String) - __mapper_args__ = {"polymorphic_identity": "u0", + __mapper_args__ = { + "polymorphic_identity": "u0", "polymorphic_on": type} class SubUser1(Single): @@ -200,14 +280,14 @@ class AutomapInhTest(fixtures.MappedTest): type = Column(String) - __mapper_args__ = {"polymorphic_identity": "u0", + __mapper_args__ = { + "polymorphic_identity": "u0", "polymorphic_on": type} class SubJoined(Joined): __tablename__ = 'joined_inh' __mapper_args__ = {"polymorphic_identity": "u1"} - Base.prepare(engine=testing.db, reflect=True) assert SubJoined.__mapper__.inherits is Joined.__mapper__ @@ -217,6 +297,9 @@ class AutomapInhTest(fixtures.MappedTest): def test_conditional_relationship(self): Base = automap_base() + def _gen_relationship(*arg, **kw): return None - Base.prepare(engine=testing.db, reflect=True, generate_relationship=_gen_relationship) + Base.prepare( + engine=testing.db, reflect=True, + generate_relationship=_gen_relationship) diff --git a/test/ext/test_orderinglist.py b/test/ext/test_orderinglist.py index 3223c8048..0eba137e7 100644 --- a/test/ext/test_orderinglist.py +++ b/test/ext/test_orderinglist.py @@ -349,6 +349,28 @@ class OrderingListTest(fixtures.TestBase): self.assert_(srt.bullets[1].text == 'new 2') self.assert_(srt.bullets[2].text == '3') + def test_replace_two(self): + """test #3191""" + + self._setup(ordering_list('position', reorder_on_append=True)) + + s1 = Slide('Slide #1') + + b1, b2, b3, b4 = Bullet('1'), Bullet('2'), Bullet('3'), Bullet('4') + s1.bullets = [b1, b2, b3] + + eq_( + [b.position for b in s1.bullets], + [0, 1, 2] + ) + + s1.bullets = [b4, b2, b1] + eq_( + [b.position for b in s1.bullets], + [0, 1, 2] + ) + + def test_funky_ordering(self): class 
Pos(object): def __init__(self): diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py index 46d5f86e5..9c1f7a985 100644 --- a/test/orm/test_attributes.py +++ b/test/orm/test_attributes.py @@ -2522,6 +2522,53 @@ class ListenerTest(fixtures.ORMTest): f1.barset.add(b1) assert f1.barset.pop().data == 'some bar appended' + def test_named(self): + canary = Mock() + + class Foo(object): + pass + + class Bar(object): + pass + + instrumentation.register_class(Foo) + instrumentation.register_class(Bar) + attributes.register_attribute( + Foo, 'data', uselist=False, + useobject=False) + attributes.register_attribute( + Foo, 'barlist', uselist=True, + useobject=True) + + event.listen(Foo.data, 'set', canary.set, named=True) + event.listen(Foo.barlist, 'append', canary.append, named=True) + event.listen(Foo.barlist, 'remove', canary.remove, named=True) + + f1 = Foo() + b1 = Bar() + f1.data = 5 + f1.barlist.append(b1) + f1.barlist.remove(b1) + eq_( + canary.mock_calls, + [ + call.set( + oldvalue=attributes.NO_VALUE, + initiator=attributes.Event( + Foo.data.impl, attributes.OP_REPLACE), + target=f1, value=5), + call.append( + initiator=attributes.Event( + Foo.barlist.impl, attributes.OP_APPEND), + target=f1, + value=b1), + call.remove( + initiator=attributes.Event( + Foo.barlist.impl, attributes.OP_REMOVE), + target=f1, + value=b1)] + ) + def test_collection_link_events(self): class Foo(object): pass @@ -2559,9 +2606,6 @@ class ListenerTest(fixtures.ORMTest): ) - - - def test_none_on_collection_event(self): """test that append/remove of None in collections emits events. 
diff --git a/test/orm/test_collection.py b/test/orm/test_collection.py index f94c742b3..82331b9af 100644 --- a/test/orm/test_collection.py +++ b/test/orm/test_collection.py @@ -2191,6 +2191,23 @@ class InstrumentationTest(fixtures.ORMTest): f1.attr = l2 eq_(canary, [adapter_1, f1.attr._sa_adapter, None]) + def test_referenced_by_owner(self): + + class Foo(object): + pass + + instrumentation.register_class(Foo) + attributes.register_attribute( + Foo, 'attr', uselist=True, useobject=True) + + f1 = Foo() + f1.attr.append(3) + + adapter = collections.collection_adapter(f1.attr) + assert adapter._referenced_by_owner + + f1.attr = [] + assert not adapter._referenced_by_owner diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py index 214b592b5..4c6d9bbe1 100644 --- a/test/orm/test_eager_relations.py +++ b/test/orm/test_eager_relations.py @@ -1253,8 +1253,9 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): orders=relationship(Order, lazy=False, order_by=orders.c.id), )) q = create_session().query(User) - self.l = q.all() - eq_(self.static.user_all_result, q.order_by(User.id).all()) + def go(): + eq_(self.static.user_all_result, q.order_by(User.id).all()) + self.assert_sql_count(testing.db, go, 1) def test_against_select(self): """test eager loading of a mapper which is against a select""" diff --git a/test/orm/test_events.py b/test/orm/test_events.py index e6efd6fb9..904293102 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -112,6 +112,7 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): mapper(User, users) canary = self.listen_all(User) + named_canary = self.listen_all(User, named=True) sess = create_session() u = User(name='u1') @@ -125,13 +126,15 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): sess.flush() sess.delete(u) sess.flush() - eq_(canary, - ['init', 'before_insert', - 'after_insert', 'expire', - 'refresh', - 'load', - 'before_update', 'after_update', 
'before_delete', - 'after_delete']) + expected = [ + 'init', 'before_insert', + 'after_insert', 'expire', + 'refresh', + 'load', + 'before_update', 'after_update', 'before_delete', + 'after_delete'] + eq_(canary, expected) + eq_(named_canary, expected) def test_insert_before_configured(self): users, User = self.tables.users, self.classes.User @@ -1193,6 +1196,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): 'before_commit', 'after_commit','after_transaction_end'] ) + def test_rollback_hook(self): User, users = self.classes.User, self.tables.users sess, canary = self._listener_fixture() diff --git a/test/orm/test_query.py b/test/orm/test_query.py index c9f0a5db0..f14ad7864 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -2482,6 +2482,8 @@ class YieldTest(_fixtures.FixtureTest): class HintsTest(QueryTest, AssertsCompiledSQL): + __dialect__ = 'default' + def test_hints(self): User = self.classes.User @@ -2517,6 +2519,28 @@ class HintsTest(QueryTest, AssertsCompiledSQL): "ON users_1.id > users.id", dialect=dialect ) + def test_statement_hints(self): + User = self.classes.User + + sess = create_session() + stmt = sess.query(User).\ + with_statement_hint("test hint one").\ + with_statement_hint("test hint two").\ + with_statement_hint("test hint three", "postgresql") + + self.assert_compile( + stmt, + "SELECT users.id AS users_id, users.name AS users_name " + "FROM users test hint one test hint two", + ) + + self.assert_compile( + stmt, + "SELECT users.id AS users_id, users.name AS users_name " + "FROM users test hint one test hint two test hint three", + dialect='postgresql' + ) + class TextTest(QueryTest, AssertsCompiledSQL): __dialect__ = 'default' diff --git a/test/orm/test_rel_fn.py b/test/orm/test_rel_fn.py index f0aa538f4..150b59b75 100644 --- a/test/orm/test_rel_fn.py +++ b/test/orm/test_rel_fn.py @@ -242,6 +242,22 @@ class _JoinFixtures(object): **kw ) + def _join_fixture_o2m_composite_selfref_func_remote_side(self, 
**kw): + return relationships.JoinCondition( + self.composite_selfref, + self.composite_selfref, + self.composite_selfref, + self.composite_selfref, + primaryjoin=and_( + self.composite_selfref.c.group_id == + func.foo(self.composite_selfref.c.group_id), + self.composite_selfref.c.parent_id == + self.composite_selfref.c.id + ), + remote_side=set([self.composite_selfref.c.parent_id]), + **kw + ) + def _join_fixture_o2m_composite_selfref_func_annotated(self, **kw): return relationships.JoinCondition( self.composite_selfref, @@ -729,6 +745,10 @@ class ColumnCollectionsTest(_JoinFixtures, fixtures.TestBase, self._join_fixture_o2m_composite_selfref_func ) + def test_determine_local_remote_pairs_o2m_composite_selfref_func_rs(self): + # no warning + self._join_fixture_o2m_composite_selfref_func_remote_side() + def test_determine_local_remote_pairs_o2m_overlap_func_warning(self): self._assert_non_simple_warning( self._join_fixture_m2o_sub_to_joined_sub_func diff --git a/test/requirements.py b/test/requirements.py index 7eeabef2b..80bd135e9 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -297,6 +297,17 @@ class DefaultRequirements(SuiteRequirements): ) @property + def temp_table_names(self): + """target dialect supports listing of temporary table names""" + + return only_on(['sqlite', 'oracle']) + + @property + def temporary_views(self): + """target database supports temporary views""" + return only_on(['sqlite', 'postgresql']) + + @property def update_nowait(self): """Target database must support SELECT...FOR UPDATE NOWAIT""" return skip_if(["firebird", "mssql", "mysql", "sqlite", "sybase"], @@ -706,6 +717,14 @@ class DefaultRequirements(SuiteRequirements): ) @property + def postgresql_test_dblink(self): + return skip_if( + lambda config: not config.file_config.has_option( + 'sqla_testing', 'postgres_test_db_link'), + "postgres_test_db_link option not specified in config" + ) + + @property def percent_schema_names(self): return skip_if( [ diff --git 
a/test/sql/test_compiler.py b/test/sql/test_compiler.py index fc33db184..ed13e8455 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -2511,6 +2511,23 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): dialect=dialect ) + def test_statement_hints(self): + + stmt = select([table1.c.myid]).\ + with_statement_hint("test hint one").\ + with_statement_hint("test hint two", 'mysql') + + self.assert_compile( + stmt, + "SELECT mytable.myid FROM mytable test hint one", + ) + + self.assert_compile( + stmt, + "SELECT mytable.myid FROM mytable test hint one test hint two", + dialect='mysql' + ) + def test_literal_as_text_fromstring(self): self.assert_compile( and_(text("a"), text("b")), diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 4a484dbac..6b8e1bb40 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -16,7 +16,7 @@ from sqlalchemy import testing from sqlalchemy.testing import ComparesTables, AssertsCompiledSQL from sqlalchemy.testing import eq_, is_, mock from contextlib import contextmanager - +from sqlalchemy import util class MetaDataTest(fixtures.TestBase, ComparesTables): @@ -2124,7 +2124,7 @@ class ColumnDefinitionTest(AssertsCompiledSQL, fixtures.TestBase): assert_raises_message( exc.ArgumentError, - "Column object already assigned to Table 't'", + "Column object 'x' already assigned to Table 't'", Table, 'q', MetaData(), c) def test_incomplete_key(self): @@ -2705,7 +2705,7 @@ class DialectKWArgTest(fixtures.TestBase): lambda arg: "goofy_%s" % arg): with self._fixture(): idx = Index('a', 'b') - idx.kwargs[u'participating_x'] = 7 + idx.kwargs[util.u('participating_x')] = 7 eq_( list(idx.dialect_kwargs), @@ -1,10 +1,8 @@ [tox] -envlist = full +envlist = full,py26,py27,py33,py34 [testenv] deps=pytest - flake8 - coverage mock sitepackages=True @@ -12,7 +10,6 @@ usedevelop=True commands= python -m pytest {posargs} -envdir=pytest [testenv:full] @@ -21,22 +18,23 @@ envdir=pytest setenv= 
DISABLE_SQLALCHEMY_CEXT=1 +# see also .coveragerc +deps=coverage commands= - python -m pytest \ - --cov=lib/sqlalchemy \ - --exclude-tag memory-intensive \ - --exclude-tag timing-intensive \ - -k "not aaa_profiling" \ - {posargs} - python -m coverage xml --include=lib/sqlalchemy/* + python -m pytest --cov=sqlalchemy --cov-report term --cov-report xml \ + --exclude-tag memory-intensive \ + --exclude-tag timing-intensive \ + -k "not aaa_profiling" \ + {posargs} + [testenv:pep8] +deps=flake8 commands = python -m flake8 {posargs} [flake8] - show-source = True -ignore = E711,E712,E721,F841,F811 +ignore = E711,E712,E721 exclude=.venv,.git,.tox,dist,doc,*egg,build |