author    Mike Bayer <mike_mp@zzzcomputing.com>  2007-07-27 04:08:53 +0000
committer Mike Bayer <mike_mp@zzzcomputing.com>  2007-07-27 04:08:53 +0000
commit    ed4fc64bb0ac61c27bc4af32962fb129e74a36bf (patch)
tree      c1cf2fb7b1cafced82a8898e23d2a0bf5ced8526
parent    3a8e235af64e36b3b711df1f069d32359fe6c967 (diff)
download  sqlalchemy-ed4fc64bb0ac61c27bc4af32962fb129e74a36bf.tar.gz
merging 0.4 branch to trunk. see CHANGES for details. 0.3 moves to maintenance branch in branches/rel_0_3.
-rw-r--r-- CHANGES 394
-rw-r--r-- README.unittests 7
-rw-r--r-- doc/build/content/adv_datamapping.txt 180
-rw-r--r-- doc/build/content/copyright.txt 24
-rw-r--r-- doc/build/content/datamapping.txt 151
-rw-r--r-- doc/build/content/docstrings.html 2
-rw-r--r-- doc/build/content/plugins.txt 372
-rw-r--r-- doc/build/content/sqlconstruction.txt 2
-rw-r--r-- doc/build/content/tutorial.txt 74
-rw-r--r-- doc/build/content/unitofwork.txt 171
-rw-r--r-- doc/build/gen_docstrings.py 23
-rw-r--r-- doc/build/genhtml.py 22
-rw-r--r-- doc/build/read_markdown.py 8
-rw-r--r-- examples/adjacencytree/basic_tree.py 16
-rw-r--r-- examples/adjacencytree/byroot_tree.py 61
-rw-r--r-- examples/backref/backref_tree.py 41
-rw-r--r-- examples/collections/large_collection.py 3
-rw-r--r-- examples/elementtree/adjacency_list.py 215
-rw-r--r-- examples/elementtree/optimized_al.py 224
-rw-r--r-- examples/elementtree/pickle.py 65
-rw-r--r-- examples/elementtree/test.xml 9
-rw-r--r-- examples/elementtree/test2.xml 4
-rw-r--r-- examples/elementtree/test3.xml 7
-rw-r--r-- examples/pickle/custom_pickler.py 5
-rw-r--r-- examples/poly_assoc/poly_assoc.py 3
-rw-r--r-- examples/poly_assoc/poly_assoc_fk.py 3
-rw-r--r-- examples/poly_assoc/poly_assoc_generic.py 3
-rw-r--r-- examples/polymorph/concrete.py 3
-rw-r--r-- examples/polymorph/polymorph.py 7
-rw-r--r-- examples/polymorph/single.py 4
-rw-r--r-- examples/sharding/attribute_shard.py 194
-rw-r--r-- examples/vertical/vertical.py 15
-rw-r--r-- lib/sqlalchemy/__init__.py 8
-rw-r--r-- lib/sqlalchemy/ansisql.py 749
-rw-r--r-- lib/sqlalchemy/databases/firebird.py 57
-rw-r--r-- lib/sqlalchemy/databases/information_schema.py 13
-rw-r--r-- lib/sqlalchemy/databases/informix.py 63
-rw-r--r-- lib/sqlalchemy/databases/mssql.py 113
-rw-r--r-- lib/sqlalchemy/databases/mysql.py 107
-rw-r--r-- lib/sqlalchemy/databases/oracle.py 255
-rw-r--r-- lib/sqlalchemy/databases/postgres.py 293
-rw-r--r-- lib/sqlalchemy/databases/sqlite.py 41
-rw-r--r-- lib/sqlalchemy/engine/__init__.py 1
-rw-r--r-- lib/sqlalchemy/engine/base.py 633
-rw-r--r-- lib/sqlalchemy/engine/default.py 147
-rw-r--r-- lib/sqlalchemy/engine/strategies.py 21
-rw-r--r-- lib/sqlalchemy/engine/threadlocal.py 23
-rw-r--r-- lib/sqlalchemy/engine/url.py 8
-rw-r--r-- lib/sqlalchemy/ext/activemapper.py 11
-rw-r--r-- lib/sqlalchemy/ext/assignmapper.py 59
-rw-r--r-- lib/sqlalchemy/ext/associationproxy.py 88
-rw-r--r-- lib/sqlalchemy/ext/proxy.py 113
-rw-r--r-- lib/sqlalchemy/ext/selectresults.py 218
-rw-r--r-- lib/sqlalchemy/ext/sessioncontext.py 28
-rw-r--r-- lib/sqlalchemy/ext/sqlsoup.py 13
-rw-r--r-- lib/sqlalchemy/mods/legacy_session.py 176
-rw-r--r-- lib/sqlalchemy/mods/selectresults.py 2
-rw-r--r-- lib/sqlalchemy/mods/threadlocal.py 53
-rw-r--r-- lib/sqlalchemy/orm/__init__.py 486
-rw-r--r-- lib/sqlalchemy/orm/attributes.py 836
-rw-r--r-- lib/sqlalchemy/orm/collections.py 1182
-rw-r--r-- lib/sqlalchemy/orm/dependency.py 6
-rw-r--r-- lib/sqlalchemy/orm/interfaces.py 496
-rw-r--r-- lib/sqlalchemy/orm/mapper.py 921
-rw-r--r-- lib/sqlalchemy/orm/properties.py 269
-rw-r--r-- lib/sqlalchemy/orm/query.py 1236
-rw-r--r-- lib/sqlalchemy/orm/session.py 224
-rw-r--r-- lib/sqlalchemy/orm/shard.py 112
-rw-r--r-- lib/sqlalchemy/orm/strategies.py 671
-rw-r--r-- lib/sqlalchemy/orm/sync.py 13
-rw-r--r-- lib/sqlalchemy/orm/unitofwork.py 61
-rw-r--r-- lib/sqlalchemy/orm/util.py 187
-rw-r--r-- lib/sqlalchemy/pool.py 22
-rw-r--r-- lib/sqlalchemy/schema.py 375
-rw-r--r-- lib/sqlalchemy/sql.py 1946
-rw-r--r-- lib/sqlalchemy/sql_util.py 42
-rw-r--r-- lib/sqlalchemy/topological.py 3
-rw-r--r-- lib/sqlalchemy/types.py 98
-rw-r--r-- lib/sqlalchemy/util.py 93
-rw-r--r-- setup.py 2
-rw-r--r-- test/base/alltests.py 1
-rw-r--r-- test/base/dependency.py 7
-rw-r--r-- test/base/utils.py 67
-rw-r--r-- test/dialect/alltests.py 1
-rw-r--r-- test/dialect/mysql.py 58
-rw-r--r-- test/dialect/oracle.py 32
-rw-r--r-- test/dialect/postgres.py 128
-rw-r--r-- test/engine/alltests.py 2
-rw-r--r-- test/engine/autoconnect_engine.py 90
-rw-r--r-- test/engine/bind.py 49
-rw-r--r-- test/engine/execute.py 17
-rw-r--r-- test/engine/metadata.py 18
-rw-r--r-- test/engine/parseconnect.py 12
-rw-r--r-- test/engine/pool.py 44
-rw-r--r-- test/engine/proxy_engine.py 204
-rw-r--r-- test/engine/reconnect.py 12
-rw-r--r-- test/engine/reflection.py 130
-rw-r--r-- test/engine/transaction.py 294
-rw-r--r-- test/ext/activemapper.py 61
-rw-r--r-- test/ext/alltests.py 1
-rw-r--r-- test/ext/assignmapper.py 58
-rw-r--r-- test/ext/associationproxy.py 126
-rw-r--r-- test/ext/legacy_objectstore.py 113
-rw-r--r-- test/ext/orderinglist.py 49
-rw-r--r-- test/ext/selectresults.py 239
-rw-r--r-- test/ext/wsgi_test.py 122
-rw-r--r-- test/orm/alltests.py 23
-rw-r--r-- test/orm/association.py 7
-rw-r--r-- test/orm/assorted_eager.py (renamed from test/orm/eagertest3.py) 309
-rw-r--r-- test/orm/attributes.py 104
-rw-r--r-- test/orm/cascade.py 26
-rw-r--r-- test/orm/collection.py 1140
-rw-r--r-- test/orm/compile.py 5
-rw-r--r-- test/orm/cycles.py 41
-rw-r--r-- test/orm/eager_relations.py 133
-rw-r--r-- test/orm/eagertest1.py 69
-rw-r--r-- test/orm/eagertest2.py 239
-rw-r--r-- test/orm/entity.py 10
-rw-r--r-- test/orm/fixtures.py 4
-rw-r--r-- test/orm/generative.py 183
-rw-r--r-- test/orm/inheritance/__init__.py 0
-rw-r--r-- test/orm/inheritance/abc_inheritance.py (renamed from test/orm/abc_inheritance.py) 44
-rw-r--r-- test/orm/inheritance/alltests.py 28
-rw-r--r-- test/orm/inheritance/basic.py (renamed from test/orm/inheritance.py) 416
-rw-r--r-- test/orm/inheritance/concrete.py (renamed from test/orm/inheritance4.py) 9
-rw-r--r-- test/orm/inheritance/magazine.py (renamed from test/orm/inheritance3.py) 81
-rw-r--r-- test/orm/inheritance/manytomany.py 255
-rw-r--r-- test/orm/inheritance/poly_linked_list.py (renamed from test/orm/poly_linked_list.py) 35
-rw-r--r-- test/orm/inheritance/polymorph.py (renamed from test/orm/polymorph.py) 246
-rw-r--r-- test/orm/inheritance/polymorph2.py (renamed from test/orm/inheritance5.py) 142
-rw-r--r-- test/orm/inheritance/productspec.py (renamed from test/orm/inheritance2.py) 7
-rw-r--r-- test/orm/inheritance/single.py (renamed from test/orm/single.py) 7
-rw-r--r-- test/orm/lazy_relations.py 4
-rw-r--r-- test/orm/lazytest1.py 5
-rw-r--r-- test/orm/manytomany.py 18
-rw-r--r-- test/orm/mapper.py 538
-rw-r--r-- test/orm/memusage.py 19
-rw-r--r-- test/orm/merge.py 9
-rw-r--r-- test/orm/onetoone.py 4
-rw-r--r-- test/orm/query.py 564
-rw-r--r-- test/orm/relationships.py 141
-rw-r--r-- test/orm/session.py 262
-rw-r--r-- test/orm/sessioncontext.py 10
-rw-r--r-- test/orm/sharding/__init__.py 0
-rw-r--r-- test/orm/sharding/alltests.py 18
-rw-r--r-- test/orm/sharding/shard.py 154
-rw-r--r-- test/orm/unitofwork.py 115
-rw-r--r-- test/perf/cascade_speed.py 2
-rw-r--r-- test/perf/masscreate.py 5
-rw-r--r-- test/perf/masscreate2.py 6
-rw-r--r-- test/perf/masseagerload.py 102
-rw-r--r-- test/perf/massload.py 17
-rw-r--r-- test/perf/massload2.py 1
-rw-r--r-- test/perf/masssave.py 14
-rw-r--r-- test/perf/ormsession.py 225
-rw-r--r-- test/perf/poolload.py 5
-rw-r--r-- test/perf/threaded_compile.py 3
-rw-r--r-- test/perf/wsgi.py 54
-rw-r--r-- test/rundocs.py 242
-rw-r--r-- test/sql/alltests.py 4
-rw-r--r-- test/sql/case_statement.py 16
-rw-r--r-- test/sql/constraints.py 11
-rw-r--r-- test/sql/defaults.py 129
-rw-r--r-- test/sql/generative.py 275
-rw-r--r-- test/sql/labels.py 18
-rw-r--r-- test/sql/query.py 246
-rw-r--r-- test/sql/quote.py 5
-rw-r--r-- test/sql/rowcount.py 6
-rw-r--r-- test/sql/select.py 188
-rwxr-xr-x test/sql/selectable.py 32
-rw-r--r-- test/sql/testtypes.py 131
-rw-r--r-- test/sql/unicode.py 56
-rw-r--r-- test/testbase.py 474
-rw-r--r-- test/testlib/__init__.py 11
-rw-r--r-- test/testlib/config.py 255
-rw-r--r-- test/testlib/coverage.py (renamed from test/coverage.py) 271
-rw-r--r-- test/testlib/profiling.py 74
-rw-r--r-- test/testlib/schema.py 28
-rw-r--r-- test/testlib/tables.py (renamed from test/tables.py) 27
-rw-r--r-- test/testlib/testing.py 363
-rw-r--r-- test/zblog/mappers.py 1
-rw-r--r-- test/zblog/tables.py 7
-rw-r--r-- test/zblog/tests.py 28
-rw-r--r-- test/zblog/user.py 10
184 files changed, 15668 insertions, 9919 deletions
diff --git a/CHANGES b/CHANGES
index 9d75e2dfb..ec8d8fcce 100644
--- a/CHANGES
+++ b/CHANGES
@@ -1,3 +1,227 @@
+0.4.0
+- orm
+
+ - new collection_class api and implementation [ticket:213]
+ collections are now instrumented via decorations rather than
+ proxying. you can now have collections that manage their own
+ membership, and your class instance will be directly exposed on the
+ relation property. the changes are transparent for most users.
+ - InstrumentedList (as it was) is removed, and relation properties
+ no longer have 'clear()', '.data', or any other added methods
+ beyond those provided by the collection type. you are free, of
+ course, to add them to a custom class.
+ - __setitem__-like assignments now fire remove events for the
+ existing value, if any.
+ - dict-likes used as collection classes no longer need to change
+      __iter__ semantics; itervalues() is used by default instead. this
+ is a backwards incompatible change.
+ - subclassing dict for a mapped collection is no longer needed in
+ most cases. orm.collections provides canned implementations that
+ key objects by a specified column or a custom function of your
+ choice.
+    - collection assignment now requires a compatible type; assigning
+ None to clear a collection or assigning a list to a dict
+ collection will now raise an argument error.
+    - AttributeExtension moved to interfaces, and .delete is now
+      .remove. The event method signature has also been swapped around.
+
+ - major overhaul for Query: all selectXXX methods
+ are deprecated. generative methods are now the standard
+ way to do things, i.e. filter(), filter_by(), all(), one(),
+ etc. Deprecated methods are docstring'ed with their
+ new replacements.
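+
+      a minimal sketch of the generative style ("User" here stands in for
+      any mapped class and is not part of this changelog):
+
+          session.query(User).filter_by(name='ed').all()
+          session.query(User).filter(users_table.c.id > 5).first()
+          session.query(User).filter_by(name='ed').one()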
+
+    - Class-level properties are now usable as query elements... no
+      more '.c.'! "Class.c.propname" is now superseded by "Class.propname".
+ All clause operators are supported, as well as higher level operators
+ such as Class.prop==<some instance> for scalar attributes,
+ Class.prop.contains(<some instance>) and Class.prop.any(<some expression>)
+ for collection-based attributes (all are also negatable). Table-based column
+ expressions as well as columns mounted on mapped classes via 'c' are of
+ course still fully available and can be freely mixed with the new attributes.
+ [ticket:643]
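+
+      for example (a sketch; "User", "Address" and "someaddress" are
+      hypothetical mapped classes/instances):
+
+          session.query(User).filter(User.name == 'ed')
+          session.query(User).filter(User.addresses.contains(someaddress))
+          session.query(User).filter(User.addresses.any(Address.email == 'ed@foo.com'))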
+
+ - removed ancient query.select_by_attributename() capability.
+
+ - the aliasing logic used by eager loading has been generalized, so that
+ it also adds full automatic aliasing support to Query. It's no longer
+ necessary to create an explicit Alias to join to the same tables multiple times;
+ *even for self-referential relationships!!*
+ - join() and outerjoin() take arguments "aliased=True". this causes
+ their joins to be built on aliased tables; subsequent calls
+ to filter() and filter_by() will translate all table expressions
+ (yes, real expressions using the original mapped Table) to be that of
+ the Alias for the duration of that join() (i.e. until reset_joinpoint()
+ or another join() is called).
+ - join() and outerjoin() take arguments "id=<somestring>". when used
+ with "aliased=True", the id can be referenced by add_entity(cls, id=<somestring>)
+ so that you can select the joined instances even if they're from an alias.
+ - join() and outerjoin() now work with self-referential relationships! using
+ "aliased=True", you can join as many levels deep as desired, i.e.
+ query.join(['children', 'children'], aliased=True); filter criterion will
+ be against the rightmost joined table
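+
+        for example (a sketch; "Node" is a hypothetical self-referential
+        mapping):
+
+            session.query(Node).join(['children', 'children'], aliased=True).\
+                filter_by(name='grandchild')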
+
+ - added query.populate_existing() - marks the query to reload
+ all attributes and collections of all instances touched in the query,
+ including eagerly-loaded entities [ticket:660]
+
+ - added eagerload_all(), allows eagerload_all('x.y.z') to specify eager
+ loading of all properties in the given path
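+
+      e.g. (a sketch; "Order" and the 'items.keywords' path are
+      hypothetical):
+
+          from sqlalchemy.orm import eagerload_all
+          session.query(Order).options(eagerload_all('items.keywords')).all()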
+
+    - a rudimentary sharding (horizontal scaling) system is introduced. This system
+ uses a modified Session which can distribute read and write operations among
+ multiple databases, based on user-defined functions defining the
+ "sharding strategy". Instances and their dependents can be distributed
+ and queried among multiple databases based on attribute values, round-robin
+ approaches or any other user-defined system. [ticket:618]
+
+ - Eager loading has been enhanced to allow even more joins in more places.
+ It now functions at any arbitrary depth along self-referential
+ and cyclical structures. When loading cyclical structures, specify "join_depth"
+ on relation() indicating how many times you'd like the table to join
+ to itself; each level gets a distinct table alias. The alias names
+ themselves are generated at compile time using a simple counting
+ scheme now and are a lot easier on the eyes, as well as of course
+ completely deterministic. [ticket:659]
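+
+      a sketch of the flag (hypothetical "Node" adjacency-list mapping):
+
+          mapper(Node, nodes_table, properties={
+              'children': relation(Node, lazy=False, join_depth=2)
+          })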
+
+ - added composite column properties. This allows you to create a
+ type which is represented by more than one column, when using the
+ ORM. Objects of the new type are fully functional in query expressions,
+ comparisons, query.get() clauses, etc. and act as though they are regular
+      single-column scalars... except they're not!
+ Use the function composite(cls, *columns) inside of the
+ mapper's "properties" dict, and instances of cls will be
+ created/mapped to a single attribute, comprised of the values
+      corresponding to *columns [ticket:211]
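+
+      a sketch (the "Point"/"Vertex" names are illustrative only; the
+      mapped type supplies its column values via __composite_values__()):
+
+          class Point(object):
+              def __init__(self, x, y):
+                  self.x, self.y = x, y
+              def __composite_values__(self):
+                  return self.x, self.y
+
+          mapper(Vertex, vertices_table, properties={
+              'start': composite(Point, vertices_table.c.x1, vertices_table.c.y1)
+          })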
+
+ - improved support for custom column_property() attributes which
+      feature correlated subqueries; they work better with eager loading now.
+
+ - along with recent speedups to ResultProxy, total number of
+ function calls significantly reduced for large loads.
+ test/perf/masseagerload.py reports 0.4 as having the fewest number
+ of function calls across all SA versions (0.1, 0.2, and 0.3)
+
+ - primary key "collapse" behavior; the mapper will analyze all columns
+ in its given selectable for primary key "equivalence", that is,
+ columns which are equivalent via foreign key relationship or via an
+ explicit inherit_condition. primarily for joined-table inheritance
+ scenarios where different named PK columns in inheriting tables
+ should "collapse" into a single-valued (or fewer-valued) primary key.
+ fixes things like [ticket:611].
+
+ - joined-table inheritance will now generate the primary key
+ columns of all inherited classes against the root table of the
+ join only. This implies that each row in the root table is distinct
+      to a single instance. If for some rare reason this is not desirable,
+ explicit primary_key settings on individual mappers will override it.
+
+ - When "polymorphic" flags are used with joined-table or single-table
+ inheritance, all identity keys are generated against the root class
+ of the inheritance hierarchy; this allows query.get() to work
+ polymorphically using the same caching semantics as a non-polymorphic get.
+ note that this currently does not work with concrete inheritance.
+
+ - secondary inheritance loading: polymorphic mappers can be
+ constructed *without* a select_table argument. inheriting mappers
+ whose tables were not represented in the initial load will issue a
+ second SQL query immediately, once per instance (i.e. not very
+ efficient for large lists), in order to load the remaining
+ columns.
+ - secondary inheritance loading can also move its second query into
+      a column-level "deferred" load, via the "polymorphic_fetch"
+ argument, which can be set to 'select' or 'deferred'
+
+    - added undefer_group() MapperOption, which marks an entire "group" of
+      "deferred" columns to load as "undeferred".
+
+ - session enhancements/fixes:
+ - session can be bound to Connections
+
+ - rewrite of the "deterministic alias name" logic to be part of the
+ SQL layer, produces much simpler alias and label names more in the
+ style of Hibernate
+
+- sql
+ - all "type" keyword arguments, such as those to bindparam(), column(),
+ Column(), and func.<something>(), renamed to "type_". those objects
+ still name their "type" attribute as "type".
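+
+    e.g.:
+
+        bindparam('x', type_=Integer)         # keyword argument is now "type_"
+        bindparam('x', type_=Integer).type    # attribute is still ".type"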
+ - transactions:
+ - added context manager (with statement) support for transactions
+ - added support for two phase commit, works with mysql and postgres so far.
+ - added a subtransaction implementation that uses savepoints.
+ - added support for savepoints.
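+
+      a sketch of savepoint usage (method names begin()/begin_nested() as
+      assumed from the 0.4 Connection API; "users" is a hypothetical table):
+
+          conn = engine.connect()
+          trans = conn.begin()
+          conn.execute(users.insert(), name='ed')
+          savepoint = conn.begin_nested()   # SAVEPOINT-based subtransaction
+          conn.execute(users.insert(), name='jack')
+          savepoint.rollback()              # rolls back to the savepoint only
+          trans.commit()                    # 'ed' is still committed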
+ - MetaData:
+ - DynamicMetaData has been renamed to ThreadLocalMetaData
+ - BoundMetaData has been removed- regular MetaData is equivalent
+ - Numeric and Float types now have an "asdecimal" flag; defaults to
+ True for Numeric, False for Float. when True, values are returned as
+ decimal.Decimal objects; when False, values are returned as float().
+ the defaults of True/False are already the behavior for PG and MySQL's
+ DBAPI modules. [ticket:646]
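+
+    e.g.:
+
+        Column('price', Numeric(10, 2))          # returns decimal.Decimal
+        Column('ratio', Float(asdecimal=True))   # opts Float into Decimal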
+ - new SQL operator implementation which removes all hardcoded operators
+ from expression structures and moves them into compilation;
+ allows greater flexibility of operator compilation; for example, "+"
+ compiles to "||" when used in a string context, or "concat(a,b)" on
+ MySQL; whereas in a numeric context it compiles to "+". fixes [ticket:475].
+ - "anonymous" alias and label names are now generated at SQL compilation
+ time in a completely deterministic fashion...no more random hex IDs
+ - significant architectural overhaul to SQL elements (ClauseElement).
+ all elements share a common "mutability" framework which allows a
+ consistent approach to in-place modifications of elements as well as
+ generative behavior. improves stability of the ORM which makes
+ heavy usage of mutations to SQL expressions.
+ - select() and union()'s now have "generative" behavior. methods like
+ order_by() and group_by() return a *new* instance - the original instance
+ is left unchanged. non-generative methods remain as well.
+ - the internals of select/union vastly simplified - all decision making
+ regarding "is subquery" and "correlation" pushed to SQL generation phase.
+ select() elements are now *never* mutated by their enclosing containers
+ or by any dialect's compilation process [ticket:52] [ticket:569]
+ - select(scalar=True) argument is deprecated; use select(..).as_scalar().
+ the resulting object obeys the full "column" interface and plays better
+ within expressions
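+
+    e.g. (a sketch with hypothetical users/addresses tables):
+
+        addr_count = select([func.count(addresses.c.id)],
+                            addresses.c.user_id == users.c.id).as_scalar()
+        stmt = select([users.c.name, addr_count])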
+ - added select().with_prefix('foo') allowing any set of keywords to be
+ placed before the columns clause of the SELECT [ticket:504]
+ - added array slice support to row[<index>] [ticket:686]
+ - result sets make a better attempt at matching the DBAPI types present
+ in cursor.description to the TypeEngine objects defined by the dialect,
+ which are then used for result-processing. Note this only takes effect
+ for textual SQL; constructed SQL statements always have an explicit type map.
+ - result sets from CRUD operations close their underlying cursor immediately.
+      they will also autoclose the connection if defined for the operation; this
+ allows more efficient usage of connections for successive CRUD operations
+ with less chance of "dangling connections".
+ - Column defaults and onupdate Python functions (i.e. passed to ColumnDefault)
+ may take zero or one arguments; the one argument is the ExecutionContext,
+ from which you can call "context.parameters[someparam]" to access the other
+ bind parameter values affixed to the statement [ticket:559]
+    - added "explicit" create/drop/execute support for sequences
+ (i.e. you can pass a "connectable" to each of those methods
+ on Sequence)
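+
+    e.g. (a sketch):
+
+        seq = Sequence('user_id_seq')
+        seq.create(engine)   # issues CREATE SEQUENCE on the given connectable
+        seq.drop(engine)     # issues DROP SEQUENCE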
+ - better quoting of identifiers when manipulating schemas
+ - standardized the behavior for table reflection where types can't be located;
+      NullType is substituted instead and a warning is raised.
+ - ColumnCollection (i.e. the 'c' attribute on tables) follows dictionary
+ semantics for "__contains__" [ticket:606]
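+
+    e.g.:
+
+        'user_name' in users_table.c   # True if a column with that key exists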
+
+- engines
+ - Connections gain a .properties collection, with contents scoped to the
+ lifetime of the underlying DBAPI connection
+- extensions
+ - proxyengine is temporarily removed, pending an actually working
+ replacement.
+ - SelectResults has been replaced by Query. SelectResults /
+ SelectResultsExt still exist but just return a slightly modified
+    Query object for backwards-compatibility. the join_to() method
+    from SelectResults isn't present anymore; use join() instead.
+- postgres
+ - Added PGArray datatype for using postgres array datatypes
+- oracle
+  - very rudimentary support for OUT parameters added; use sql.outparam(name, type)
+    to set up an OUT parameter, just like bindparam(); after execution, values are
+    available via result.out_parameters dictionary. [ticket:507]
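+
+    a sketch (the PL/SQL block, connection, and Integer type are
+    illustrative; outparam's keyword signature is assumed from the
+    bindparam() analogy):
+
+        stmt = text("begin :x := get_count(); end;",
+                    bindparams=[outparam('x', type_=Integer)])
+        result = connection.execute(stmt)
+        print result.out_parameters['x']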
+
0.3.11
- orm
- added a check for joining from A->B using join(), along two
@@ -15,20 +239,17 @@
push the pool into overflow at the same time. this issue has been
fixed.
- sql
- - better quoting of identifiers when manipulating schemas
- - got connection-bound metadata to work with implicit execution
- - foreign key specs can have any chararcter in their identifiers
- [ticket:667]
- - added commutativity-awareness to binary clause comparisons to
- each other, improves ORM lazy load optimization [ticket:664]
+ - got connection-bound metadata to work with implicit execution
+    - foreign key specs can have any character in their identifiers
+ [ticket:667]
+ - added commutativity-awareness to binary clause comparisons to
+ each other, improves ORM lazy load optimization [ticket:664]
- orm
- - cleanup to connection-bound sessions, SessionTransaction
-- mysql
- - fixed issue with tables in alternate schemas [ticket:662]
+ - cleanup to connection-bound sessions, SessionTransaction
- postgres
- - fixed max identifier length (63) [ticket:571]
-
+ - fixed max identifier length (63) [ticket:571]
+
0.3.9
- general
- better error message for NoSuchColumnError [ticket:607]
@@ -73,9 +294,6 @@
- small fix to eager loading to better work with eager loads
to polymorphic mappers that are using a straight "outerjoin"
clause
- - fix to the "column_prefix" flag so that the mapper does not
- trip over synonyms (and others) that are named after the column's actual
- "key" (since, column_prefix means "dont use the key").
- sql
- ForeignKey to a table in a schema that's not the default schema
requires the schema to be explicit; i.e. ForeignKey('alt_schema.users.id')
@@ -166,85 +384,85 @@
0.3.8
- engines
- - added detach() to Connection, allows underlying DBAPI connection
- to be detached from its pool, closing on dereference/close()
- instead of being reused by the pool.
- - added invalidate() to Connection, immediately invalidates the
- Connection and its underlying DBAPI connection.
+ - added detach() to Connection, allows underlying DBAPI connection
+ to be detached from its pool, closing on dereference/close()
+ instead of being reused by the pool.
+ - added invalidate() to Connection, immediately invalidates the
+ Connection and its underlying DBAPI connection.
- sql
- - _Label class overrides compare_self to return its ultimate
- object. meaning, if you say someexpr.label('foo') == 5, it
- produces the correct "someexpr == 5".
- - _Label propigates "_hide_froms()" so that scalar selects
- behave more properly with regards to FROM clause #574
- - fix to long name generation when using oid_column as an order by
- (oids used heavily in mapper queries)
- - significant speed improvement to ResultProxy, pre-caches
- TypeEngine dialect implementations and saves on function calls
- per column
- - parenthesis are applied to clauses via a new _Grouping
- construct. uses operator precedence to more intelligently apply
- parenthesis to clauses, provides cleaner nesting of clauses
- (doesnt mutate clauses placed in other clauses, i.e. no 'parens'
- flag)
- - added 'modifier' keyword, works like func.<foo> except does not
- add parenthesis. e.g. select([modifier.DISTINCT(...)]) etc.
- - removed "no group by's in a select thats part of a UNION"
- restriction [ticket:578]
+ - _Label class overrides compare_self to return its ultimate
+ object. meaning, if you say someexpr.label('foo') == 5, it
+ produces the correct "someexpr == 5".
+    - _Label propagates "_hide_froms()" so that scalar selects
+ behave more properly with regards to FROM clause #574
+ - fix to long name generation when using oid_column as an order by
+ (oids used heavily in mapper queries)
+ - significant speed improvement to ResultProxy, pre-caches
+ TypeEngine dialect implementations and saves on function calls
+ per column
+    - parentheses are applied to clauses via a new _Grouping
+      construct. uses operator precedence to more intelligently apply
+      parentheses to clauses, provides cleaner nesting of clauses
+      (doesn't mutate clauses placed in other clauses, i.e. no 'parens'
+      flag)
+ - added 'modifier' keyword, works like func.<foo> except does not
+      add parentheses. e.g. select([modifier.DISTINCT(...)]) etc.
+    - removed "no group by's in a select that's part of a UNION"
+ restriction [ticket:578]
- orm
- - added reset_joinpoint() method to Query, moves the "join point"
- back to the starting mapper. 0.4 will change the behavior of
- join() to reset the "join point" in all cases so this is an
- interim method. for forwards compatibility, ensure joins across
- multiple relations are specified using a single join(), i.e.
- join(['a', 'b', 'c']).
- - fixed bug in query.instances() that wouldnt handle more than
- on additional mapper or one additional column.
- - "delete-orphan" no longer implies "delete". ongoing effort to
- separate the behavior of these two operations.
- - many-to-many relationships properly set the type of bind params
- for delete operations on the association table
- - many-to-many relationships check that the number of rows deleted
- from the association table by a delete operation matches the
- expected results
- - session.get() and session.load() propigate **kwargs through to
- query
- - fix to polymorphic query which allows the original
- polymorphic_union to be embedded into a correlated subquery
- [ticket:577]
- - fix to select_by(<propname>=<object instance>) -style joins in
- conjunction with many-to-many relationships, bug introduced in
- r2556
- - the "primary_key" argument to mapper() is propigated to the
- "polymorphic" mapper. primary key columns in this list get
- normalized to that of the mapper's local table.
- - restored logging of "lazy loading clause" under
- sa.orm.strategies logger, got removed in 0.3.7
- - improved support for eagerloading of properties off of mappers
- that are mapped to select() statements; i.e. eagerloader is
- better at locating the correct selectable with which to attach
- its LEFT OUTER JOIN.
+ - added reset_joinpoint() method to Query, moves the "join point"
+ back to the starting mapper. 0.4 will change the behavior of
+ join() to reset the "join point" in all cases so this is an
+ interim method. for forwards compatibility, ensure joins across
+ multiple relations are specified using a single join(), i.e.
+ join(['a', 'b', 'c']).
+    - fixed bug in query.instances() that wouldn't handle more than
+      one additional mapper or one additional column.
+ - "delete-orphan" no longer implies "delete". ongoing effort to
+ separate the behavior of these two operations.
+ - many-to-many relationships properly set the type of bind params
+ for delete operations on the association table
+ - many-to-many relationships check that the number of rows deleted
+ from the association table by a delete operation matches the
+ expected results
+    - session.get() and session.load() propagate **kwargs through to
+ query
+ - fix to polymorphic query which allows the original
+ polymorphic_union to be embedded into a correlated subquery
+ [ticket:577]
+ - fix to select_by(<propname>=<object instance>) -style joins in
+ conjunction with many-to-many relationships, bug introduced in
+ r2556
+    - the "primary_key" argument to mapper() is propagated to the
+ "polymorphic" mapper. primary key columns in this list get
+ normalized to that of the mapper's local table.
+ - restored logging of "lazy loading clause" under
+ sa.orm.strategies logger, got removed in 0.3.7
+ - improved support for eagerloading of properties off of mappers
+ that are mapped to select() statements; i.e. eagerloader is
+ better at locating the correct selectable with which to attach
+ its LEFT OUTER JOIN.
- mysql
- - Nearly all MySQL column types are now supported for declaration
- and reflection. Added NCHAR, NVARCHAR, VARBINARY, TINYBLOB,
- LONGBLOB, YEAR
- - The sqltypes.Binary passthrough now always builds a BLOB,
- avoiding problems with very old database versions
- - support for column-level CHARACTER SET and COLLATE declarations,
- as well as ASCII, UNICODE, NATIONAL and BINARY shorthand.
+ - Nearly all MySQL column types are now supported for declaration
+ and reflection. Added NCHAR, NVARCHAR, VARBINARY, TINYBLOB,
+ LONGBLOB, YEAR
+ - The sqltypes.Binary passthrough now always builds a BLOB,
+ avoiding problems with very old database versions
+ - support for column-level CHARACTER SET and COLLATE declarations,
+ as well as ASCII, UNICODE, NATIONAL and BINARY shorthand.
- firebird
- - set max identifier length to 31
- - supports_sane_rowcount() set to False due to ticket #370.
- versioned_id_col feature wont work in FB.
- - some execution fixes
+ - set max identifier length to 31
+ - supports_sane_rowcount() set to False due to ticket #370.
+      versioned_id_col feature won't work in FB.
+ - some execution fixes
-extensions
- - new association proxy implementation, implementing complete
- proxies to list, dict and set-based relation collections
- - added orderinglist, a custom list class that synchronizes an
- object attribute with that object's position in the list
- - small fix to SelectResultsExt to not bypass itself during
- select().
- - added filter(), filter_by() to assignmapper
+ - new association proxy implementation, implementing complete
+ proxies to list, dict and set-based relation collections
+ - added orderinglist, a custom list class that synchronizes an
+ object attribute with that object's position in the list
+ - small fix to SelectResultsExt to not bypass itself during
+ select().
+ - added filter(), filter_by() to assignmapper
0.3.7
- engines
diff --git a/README.unittests b/README.unittests
index 729cd42a5..a4a9b5197 100644
--- a/README.unittests
+++ b/README.unittests
@@ -54,6 +54,13 @@ Help is available via:
default)
--coverage Dump a full coverage report after running
+NON-SQLITE DATABASES
+--------------------
+The prefab database connections expect to log in to localhost on the
+default port as user "scott", password "tiger", database "test" (where
+applicable). E.g. for postgresql this translates to
+"postgres://scott:tiger@127.0.0.1:5432/test".
+
RUNNING INDIVIDUAL TESTS
-------------------------
Any unittest module can be run directly from the module file (same commandline options):
diff --git a/doc/build/content/adv_datamapping.txt b/doc/build/content/adv_datamapping.txt
index 07815a583..033fb3359 100644
--- a/doc/build/content/adv_datamapping.txt
+++ b/doc/build/content/adv_datamapping.txt
@@ -113,44 +113,168 @@ Synonym can be established with the flag "proxy=True", to create a class-level p
>>> x._email
'john@doe.com'
-#### Custom List Classes {@name=customlist}
+#### Entity Collections {@name=entitycollections}
-Feature Status: [Alpha API][alpha_api]
+Mapping a one-to-many or many-to-many relationship results in a collection of values accessible through an attribute on the parent instance. By default, this collection is a `list`:
-A one-to-many or many-to-many relationship results in a list-holding element being attached to all instances of a class. The actual list is an "instrumented" list, which transparently maintains a relationship to a plain Python list. The implementation of the underlying plain list can be changed to be any object that implements a `list`-style `append` and `__iter__` method. A common need is for a list-based relationship to actually be a dictionary. This can be achieved by subclassing `dict` to have `list`-like behavior.
+ {python}
+    mapper(Parent, parent_table, properties={
+        'children': relation(Child)
+    })
-In this example, a class `MyClass` is defined, which is associated with a parent object `MyParent`. The collection of `MyClass` objects on each `MyParent` object will be a dictionary, storing each `MyClass` instance keyed to its `name` attribute.
+ parent = Parent()
+ parent.children.append(Child())
+ print parent.children[0]
+
+Collections are not limited to lists. Sets, mutable sequences and almost any other Python object that can act as a container can be used in place of the default list.
{python}
- # a class to be stored in the list
- class MyClass(object):
- def __init__(self, name):
- self.name = name
-
- # create a dictionary that will act like a list, and store
- # instances of MyClass
- class MyDict(dict):
+ # use a set
+    mapper(Parent, parent_table, properties={
+        'children': relation(Child, collection_class=set)
+    })
+
+ parent = Parent()
+ child = Child()
+ parent.children.add(child)
+ assert child in parent.children
+
+##### Custom Entity Collections {@name=customcollections}
+
+You can use your own types for collections as well. For most cases, simply inherit from `list` or `set` and add the custom behavior.
+
+Collections in SQLAlchemy are transparently *instrumented*. Instrumentation means that normal operations on the collection are tracked and result in changes being written to the database at flush time. Additionally, collection operations can fire *events* which indicate some secondary operation must take place. Examples of a secondary operation include saving the child item in the parent's `Session` (i.e. the `save-update` cascade), as well as synchronizing the state of a bi-directional relationship (i.e. a `backref`).
+
+The collections package understands the basic interface of lists, sets and dicts and will automatically apply instrumentation to those built-in types and their subclasses. Object-derived types that implement a basic collection interface are detected and instrumented via duck-typing:
+
+ {python}
+ class ListLike(object):
+ def __init__(self):
+ self.data = []
def append(self, item):
- self[item.name] = item
+ self.data.append(item)
+ def remove(self, item):
+ self.data.remove(item)
+ def extend(self, items):
+ self.data.extend(items)
def __iter__(self):
- return self.values()
+ return iter(self.data)
+ def foo(self):
+ return 'foo'
- # parent class
- class MyParent(object):
- pass
-
- # mappers, constructed normally
- mapper(MyClass, myclass_table)
- mapper(MyParent, myparent_table, properties={
- 'myclasses' : relation(MyClass, collection_class=MyDict)
+`append`, `remove`, and `extend` are known list-like methods, and will be instrumented automatically. `__iter__` is not a mutator method and won't be instrumented, and `foo` won't be either.
+
+Duck-typing (i.e. guesswork) isn't rock-solid, of course, so you can be explicit about the interface you are implementing by providing an `__emulates__` class attribute:
+
+ {python}
+ class SetLike(object):
+ __emulates__ = set
+
+ def __init__(self):
+ self.data = set()
+ def append(self, item):
+ self.data.add(item)
+ def remove(self, item):
+ self.data.remove(item)
+ def __iter__(self):
+ return iter(self.data)
+
+This class looks list-like because of `append`, but `__emulates__` forces it to be treated as set-like. `remove` is known to be part of the set interface and will be instrumented.
+
+But this class won't work quite yet: a little glue is needed to adapt it for use by SQLAlchemy. The ORM needs to know which methods to use to append, remove and iterate over members of the collection. When using a type like `list` or `set`, the appropriate methods are well-known and used automatically when present. This set-like class does not provide the expected `add` method, so we must supply an explicit mapping for the ORM via a decorator.
+
+##### Collection Decorators {@name=collectiondecorators}
+
+Decorators can be used to tag the individual methods the ORM needs to manage collections. Use them when your class doesn't quite meet the regular interface for its container type, or you simply would like to use a different method to get the job done.
+
+ {python}
+ from sqlalchemy.orm.collections import collection
+
+ class SetLike(object):
+ __emulates__ = set
+
+ def __init__(self):
+ self.data = set()
+
+ @collection.appender
+ def append(self, item):
+ self.data.add(item)
+
+ def remove(self, item):
+ self.data.remove(item)
+
+ def __iter__(self):
+ return iter(self.data)
+
+And that's all that's needed to complete the example. SQLAlchemy will add instances via the `append` method. `remove` and `__iter__` are the default methods for sets and will be used for removing and iteration. Default methods can be changed as well:
+
+ {python}
+ from sqlalchemy.orm.collections import collection
+
+ class MyList(list):
+ @collection.remover
+ def zark(self, item):
+            # do something special, then actually remove the item
+            list.remove(self, item)
+
+ @collection.iterator
+ def hey_use_this_instead_for_iteration(self):
+            # ... return an iterator over the members
+            return list.__iter__(self)
+
+There is no requirement to be list- or set-like at all. Collection classes can be any shape, so long as they have the append, remove and iterate interface marked for SQLAlchemy's use. Append and remove methods will be called with a mapped entity as the single argument, and iterator methods are called with no arguments and must return an iterator.
+
+##### Dictionary-Based Collections {@name=dictcollections}
+
+A `dict` can be used as a collection, but a keying strategy is needed to map entities loaded by the ORM to key, value pairs. The [collections](rel:docstrings_sqlalchemy.orm.collections) package provides several built-in types for dictionary-based collections:
+
+ {python}
+ from sqlalchemy.orm.collections import column_mapped_collection, attr_mapped_collection, mapped_collection
+
+ mapper(Item, items_table, properties={
+ # key by column
+        'notes': relation(Note, collection_class=column_mapped_collection(notes_table.c.keyword)),
+        # or named attribute
+        'notes2': relation(Note, collection_class=attr_mapped_collection('keyword')),
+        # or any callable
+        'notes3': relation(Note, collection_class=mapped_collection(lambda entity: entity.a + entity.b)),
})
-
- # elements on 'myclasses' can be accessed via string keyname
- myparent = MyParent()
- myparent.myclasses.append(MyClass('this is myclass'))
- myclass = myparent.myclasses['this is myclass']
-Note: SQLAlchemy 0.4 has an overhauled and much improved implementation for custom list classes, with some slight API changes.
+ # ...
+ item = Item()
+ item.notes['color'] = Note('color', 'blue')
+ print item.notes['color']
+
+These functions each provide a `dict` subclass with decorated `set` and `remove` methods and the keying strategy of your choice.
+
+The [collections.MappedCollection](rel:docstrings_sqlalchemy.orm.collections.MappedCollection) class can be used as a base class for your custom types or as a mix-in to quickly add `dict` collection support to other classes. It uses a keying function to route appends and removals through `__setitem__` and `__delitem__`:
+
+ {python}
+ from sqlalchemy.util import OrderedDict
+ from sqlalchemy.orm.collections import MappedCollection
+
+ class NodeMap(OrderedDict, MappedCollection):
+ """Holds 'Node' objects, keyed by the 'name' attribute with insert order maintained."""
+
+ def __init__(self, *args, **kw):
+ MappedCollection.__init__(self, keyfunc=lambda node: node.name)
+ OrderedDict.__init__(self, *args, **kw)
+
+The ORM understands the `dict` interface just like lists and sets, and will automatically instrument all dict-like methods if you choose to subclass `dict` or provide dict-like collection behavior in a duck-typed class. You must decorate appender and remover methods, however; there are no compatible methods in the basic dictionary interface for SQLAlchemy to use by default. Iteration will go through `itervalues()` unless otherwise decorated.
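+
+A minimal sketch of that decoration (the `Note`/`keyword` names are illustrative only):
+
+    {python}
+    from sqlalchemy.orm.collections import collection
+
+    class KeywordDict(dict):
+        @collection.appender
+        def _append(self, note):
+            self[note.keyword] = note
+
+        @collection.remover
+        def _remove(self, note):
+            del self[note.keyword]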
+
+##### Instrumentation and Custom Types {@name=adv_collections}
+
+Many custom types and existing library classes can be used as an entity collection type as-is without further ado. However, it is important to note that the instrumentation process _will_ modify the type, adding decorators around methods automatically.
+
+The decorations are lightweight and no-op outside of relations, but they do add unneeded overhead when triggered elsewhere. When using a library class as a collection, it can be good practice to use the "trivial subclass" trick to restrict the decorations to just your usage in relations. For example:
+
+ {python}
+ class MyAwesomeList(some.great.library.AwesomeList):
+ pass
+
+ # ... relation(..., collection_class=MyAwesomeList)
+
+The ORM uses this approach for built-ins, quietly substituting a trivial subclass when a `list`, `set` or `dict` is used directly.
+
+The collections package provides additional decorators and support for authoring custom types. See the [package documentation](rel:docstrings_sqlalchemy.orm.collections) for more information and discussion of advanced usage and Python 2.3-compatible decoration options.
#### Custom Join Conditions {@name=customjoin}
diff --git a/doc/build/content/copyright.txt b/doc/build/content/copyright.txt
new file mode 100644
index 000000000..bc76e9f64
--- /dev/null
+++ b/doc/build/content/copyright.txt
@@ -0,0 +1,24 @@
+Appendix: Copyright {@name=copyright}
+================
+
+This is the MIT license: http://www.opensource.org/licenses/mit-license.php
+
+Copyright (c) 2005, 2006, 2007 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael
+Bayer.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this
+software and associated documentation files (the "Software"), to deal in the Software
+without restriction, including without limitation the rights to use, copy, modify, merge,
+publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons
+to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or
+substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
+FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
+
diff --git a/doc/build/content/datamapping.txt b/doc/build/content/datamapping.txt
index 13dff70d3..d35fecfb3 100644
--- a/doc/build/content/datamapping.txt
+++ b/doc/build/content/datamapping.txt
@@ -6,13 +6,13 @@ Data Mapping {@name=datamapping}
### Basic Data Mapping {@name=datamapping}
-Data mapping describes the process of defining `Mapper` objects, which associate `Table` objects with user-defined classes.
+Data mapping describes the process of defining `Mapper` objects, which associate table metadata with user-defined classes.
-When a `Mapper` is created to associate a `Table` object with a class, all of the columns defined in the `Table` object are associated with the class via property accessors, which add overriding functionality to the normal process of setting and getting object attributes. These property accessors keep track of changes to object attributes, so that they may be stored to the database when the application "flushes" the current state of objects. This pattern is called a *Unit of Work* pattern.
+When a `Mapper` is created to associate a `Table` object with a class, all of the columns defined in the `Table` object are associated with the class via property accessors, which add overriding functionality to the normal process of setting and getting object attributes. These property accessors keep track of changes to object attributes; these changes will be stored to the database when the application "flushes" the current state of objects. This pattern is called a *Unit of Work* pattern.
### Synopsis {@name=synopsis}
-Starting with a `Table` definition and a minimal class construct, the two are associated with each other via the `mapper()` function [[api](rel:docstrings_sqlalchemy.orm.mapper_Mapper)], which generates an object called a `Mapper`. SA associates the class and all instances of that class with this particular `Mapper`, which is stored in a registry such that SQLAlchemy knows how to find it automatically.
+Starting with a `Table` definition and a minimal class construct, the two are associated with each other via the `mapper()` function [[api](rel:docstrings_sqlalchemy.orm.mapper_Mapper)], which generates an object called a `Mapper`. SA associates the class and all instances of that class with this particular `Mapper`, which is then stored in a global registry.
{python}
from sqlalchemy import *
@@ -48,7 +48,7 @@ The `session` represents a "workspace" which can load objects and persist change
{python}
# select
- {sql}user = session.query(User).filter_by(user_name='fred')[0]
+ {sql}user = session.query(User).filter_by(user_name='fred').first()
SELECT users.user_id AS users_user_id, users.user_name AS users_user_name,
users.fullname AS users_fullname, users.password AS users_password
FROM users
@@ -67,7 +67,7 @@ The `session` represents a "workspace" which can load objects and persist change
[{'user_name': 'fred jones', 'user_id': 1}]
COMMIT
-Things to note from the above include that the loaded `User` object has an attribute named `user_name` on it, which corresponds to the `user_name` column in `users_table`; this attribute was configured at the class level by the `Mapper`, as part of it's post-initialization process (this process occurs normally when the mapper is first used). Our modify operation on this attribute caused the object to be marked as "dirty", which was picked up automatically within the subsequent `flush()` process. The `flush()` is the point at which all changes to objects within the `Session` are persisted to the database; afterwards, the `User` object is no longer marked as "dirty" until it is again modified.
+Things to note from the above include that the loaded `User` object has an attribute named `user_name` on it, which corresponds to the `user_name` column in `users_table`; this attribute was configured at the class level by the `Mapper`, as part of its post-initialization process (this process occurs normally when the mapper is first used). Our modify operation on this attribute caused the object to be marked as "dirty", which was picked up automatically within the subsequent `flush()` process. The `flush()` is the point at which all changes to objects within the `Session` are persisted to the database, and the `User` object is no longer marked as "dirty" until it is again modified.
### The Query Object {@name=query}
@@ -93,22 +93,12 @@ A query which joins across multiple tables may also be used to request multiple
{python}
query = session.query(User, Address)
-Once we have a query, we can start loading objects. The Query object, when first created, represents all the instances of its main class. You can iterate through it directly:
+Once we have a query, we can start loading objects. The methods `filter()` and `filter_by()` handle narrowing results, and the methods `all()`, `one()`, and `first()` return all, exactly one, or the first result of the total set of results. Note that most methods are *generative*: each call that modifies criteria returns a **new** `Query` instance rather than executing anything.
- {python}
- {sql}for user in session.query(User):
- print user.name
- SELECT users.user_id AS users_user_id, users.user_name AS users_user_name,
- users.fullname AS users_fullname, users.password AS users_password
- FROM users
- ORDER BY users.oid
- {}
-
-...and the SQL will be issued at the point where the query is evaluated as a list. To narrow results, the two main methods are `filter()` and `filter_by()`. `filter_by()` uses keyword arguments, which translate into equality clauses joined together via 'AND':
+The `filter_by()` method works with keyword arguments, which are combined via AND:
{python}
- {sql}for user in session.query(User).filter_by(name='john', fullname='John Smith'):
- print user.name
+ {sql}result = session.query(User).filter_by(name='john', fullname='John Smith').all()
SELECT users.user_id AS users_user_id, users.user_name AS users_user_name,
users.fullname AS users_fullname, users.password AS users_password
FROM users
@@ -116,11 +106,10 @@ Once we have a query, we can start loading objects. The Query object, when firs
ORDER BY users.oid
{'users_user_name': 'john', 'users_fullname': 'John Smith'}
-`filter()`, on the other hand, works with constructed SQL expressions, like those described in [sql](rel:sql):
+`filter()`, on the other hand, works with constructed SQL expressions, i.e. those described in [sql](rel:sql):
{python}
- {sql}for user in session.query(User).filter(users_table.c.name=='john'):
- print user.name
+ {sql}result = session.query(User).filter(users_table.c.name=='john').all()
SELECT users.user_id AS users_user_id, users.user_name AS users_user_name,
users.fullname AS users_fullname, users.password AS users_password
FROM users
@@ -128,44 +117,29 @@ Once we have a query, we can start loading objects. The Query object, when firs
ORDER BY users.oid
{'users_user_name': 'john'}
-Evaluating the query using an array slice returns a new Query which will apply LIMIT and OFFSET clauses when iterated:
+Sometimes, constructing SQL via expressions can be cumbersome. For a quick expression, the `filter()` method can also accommodate straight text:
{python}
- {sql}for u in session.query(User)[1:3]:
- print u
+ {sql}result = session.query(User).filter("user_id>224").all()
SELECT users.user_id AS users_user_id, users.user_name AS users_user_name,
users.fullname AS users_fullname, users.password AS users_password
- FROM users ORDER BY users.oid
- LIMIT 2 OFFSET 1
+ FROM users
+ WHERE users.user_id>224
+ ORDER BY users.oid
{}
-A single array index adds LIMIT/OFFSET and returns a result immediately:
+When using text, bind parameters can be specified the same way as in a `text()` clause, using a colon. To specify the bind parameter values, use the `params()` method:
{python}
- {sql}user = session.query(User).filter(user_table.c.name=='john')[2]
+ {sql}result = session.query(User).filter("user_id>:value").params(value=224).all()
SELECT users.user_id AS users_user_id, users.user_name AS users_user_name,
users.fullname AS users_fullname, users.password AS users_password
FROM users
- WHERE users.user_name = :users_user_name
+ WHERE users.user_id>:value
ORDER BY users.oid
- LIMIT 1 OFFSET 2
- {'users_user_name': 'john'}
+ {'value': 224}
-There are also methods which will immediately issue the SQL represented by a `Query` without using an iterative context or array index; these methods are `all()`, `one()`, and `first()`. `all()` returns a list of all instances, `one()` returns exactly one instance as a scalar, and `first()` returns the first instance also as a scalar:
-
- {python}
- query = session.query(User).filter(users_table.c.name=='john')
-
- # get all results into a list
- allusers = query.all()
-
- # get the first user
- user = query.first()
-
- # get exactly one user; raises an exception if not exactly one result is returned
- user = query.one()
-
-Note that most methods on `Query` are *generative*, in that they return a new `Query` instance that is a modified version of the previous one. It's only when you evaluate the query in an iterative context, use an array index, or call `all()`, `first()`, or `one()` (as well as some other methods we'll cover later), that SQL is issued. For example, you can issue `filter()` or `filter_by()` as many times as needed; the various criteria are joined together using `AND`:
+Multiple `filter()` and `filter_by()` expressions may be combined; the resulting statement joins them using AND.
{python}
result = session.query(User).filter(users_table.c.user_id>224).filter_by(name='john').
@@ -178,43 +152,23 @@ Note that most methods on `Query` are *generative*, in that they return a new `Q
ORDER BY users.oid
{'users_user_name': 'john', 'users_fullname': 'John Smith', 'users_user_id': 224}
-If you need to use other conjunctions besides `AND`, all SQL conjunctions are available explicitly within expressions, such as `and_()` and `or_()`, when using `filter()`:
+`filter_by()`'s keyword arguments can also take mapped object instances as comparison arguments. We'll illustrate this later when we talk about object relationships.
+
+Note that all conjunctions are available explicitly, such as `and_()` and `or_()`, when using `filter()`:
{python}
result = session.query(User).filter(
and_(users_table.c.user_id>224, or_(users_table.c.name=='john', users_table.c.name=='ed'))
).all()
-Sometimes, constructing criterion via expressions can be cumbersome. For a quick, string-based expression, the `filter()` method can also accomodate straight text:
-
- {python}
- {sql}result = session.query(User).filter("user_id>224").all()
- SELECT users.user_id AS users_user_id, users.user_name AS users_user_name,
- users.fullname AS users_fullname, users.password AS users_password
- FROM users
- WHERE users.user_id>224
- ORDER BY users.oid
- {}
-
-When using text, bind parameters can be specified the same way as in a `text()` clause, using a colon. To specify the bind parameter values, use the `params()` method:
-
- {python}
- {sql}result = session.query(User).filter("user_id>:value and user_name=:name").params(value=224, name='jack').all()
- SELECT users.user_id AS users_user_id, users.user_name AS users_user_name,
- users.fullname AS users_fullname, users.password AS users_password
- FROM users
- WHERE users.user_id>:value and user_name=:name
- ORDER BY users.oid
- {'value': 224, 'name': 'jack'}
-
-It's also straightforward to use an entirely string-based statement, using `from_statement()`; just ensure that the columns clause of the statement contains the column names normally used by the mapper (below illustrated using an asterisk):
+It's also straightforward to use an entirely string-based statement, using `from_statement()`; just ensure that the columns clause of the statement contains the column names normally used by the mapper (here illustrated using an asterisk):
{python}
{sql}result = session.query(User).from_statement("SELECT * FROM users").all()
SELECT * FROM users
{}
-`from_statement()` can also accomodate full `select()` constructs:
+`from_statement()` can also accommodate `select()` constructs:
{python}
result = session.query(User).from_statement(
@@ -228,14 +182,14 @@ It's also straightforward to use an entirely string-based statement, using `from
ORDER BY users.oid
{'users_user_name': 'e'}
-The current criterion represented by a `Query` can be distilled into a count of rows using `count()`. This is another function which executes SQL immediately, returning an integer result:
+Any set of filter criteria (or none at all) can be distilled into a count of rows using `count()`:
{python}
    {sql}num = session.query(User).filter(users_table.c.user_id>224).count()
SELECT count(users.id) FROM users WHERE users.user_id>:users_user_id
{'users_user_id': 224}
-To add limit and offset values explicitly at any time, you can use `limit()` and `offset()`:
+Rows are limited and offset using `limit()` and `offset()`:
{python}
{sql}result = session.query(User).limit(20).offset(5).all()
@@ -245,25 +199,62 @@ To add limit and offset values explicitly at any time, you can use `limit()` and
LIMIT 20 OFFSET 5
{}
-Ordering is applied, using `Column` objects and related SQL constructs, with `order_by()`:
+Ordering is applied, using `Column` objects and related SQL constructs, with `order_by()`:
{python}
- query = session.query(User).order_by(desc(users_table.c.user_name))
- {sql}for user in query:
- print user
+ {sql}result = session.query(User).order_by(desc(users_table.c.user_name)).all()
SELECT users.user_id AS users_user_id, users.user_name AS users_user_name,
users.fullname AS users_fullname, users.password AS users_password
FROM users ORDER BY users.user_name DESC
{}
-There's also a way to combine scalar results with objects, using `add_column()`. This is often used for functions and aggregates. When `add_column()` (or its cousin `add_entity()`, described later) is used, tuples are returned:
+The `first()` and `one()` methods will also limit rows, and both will return a single object, instead of a list. In the case of `first()`, rows are limited to just one, and the result is returned as a scalar. In the case of `one()`, rows are limited to *two*; however, only one is returned. If two rows are matched, an exception is raised.
+
+ {python}
+ # load the first result
+ user = session.query(User).first()
+
+ # load exactly *one* result - if more than one result matches, an exception is raised
+ user = session.query(User).filter_by(name='jack').one()
+
+The `Query`, when evaluated as an iterator, executes its SELECT statement immediately, using whatever criterion has been built up:
+
+ {python}
+ {sql}result = list(session.query(User))
+ SELECT users.user_id AS users_user_id, users.user_name AS users_user_name,
+ users.fullname AS users_fullname, users.password AS users_password
+ FROM users ORDER BY users.oid
+ {}
+
+Array indexes and slices work too, adding the corresponding LIMIT and OFFSET clauses:
{python}
- for r in session.query(User).add_column(func.max(users_table.c.name)).group_by([c for c in users_table.c]):
+ {sql}result = list(session.query(User)[1:3])
+ SELECT users.user_id AS users_user_id, users.user_name AS users_user_name,
+ users.fullname AS users_fullname, users.password AS users_password
+ FROM users ORDER BY users.oid
+ LIMIT 2 OFFSET 1
+ {}
+
+A scalar index returns a scalar result immediately:
+
+ {python}
+ {sql}user = session.query(User)[2]
+ SELECT users.user_id AS users_user_id, users.user_name AS users_user_name,
+ users.fullname AS users_fullname, users.password AS users_password
+ FROM users ORDER BY users.oid
+ LIMIT 1 OFFSET 2
+ {}
+
+There's also a way to combine scalar results with objects, using `add_column()`. This is often used for functions and aggregates. When `add_column()` (or its cousin `add_entity()`, described later) is used, tuples are returned:
+
+ {python}
+    result = session.query(User).add_column(func.max(users_table.c.user_name)).group_by([c for c in users_table.c]).all()
+ for r in result:
print "user:", r[0]
print "max name:", r[1]
-Later in this chapter, we'll discuss how to configure relations between mapped classes. Once that's done, we'll discuss how to return multiple objects at once, as well as how to join, in [datamapping_joins](rel:datamapping_joins).
+Later in this chapter, we'll discuss how to configure relations between mapped classes. Once that's done, we'll discuss how to use table joins in [datamapping_joins](rel:datamapping_joins).
#### Loading by Primary Key {@name=primarykey}
diff --git a/doc/build/content/docstrings.html b/doc/build/content/docstrings.html
index c0a9e1ac2..0f125cecf 100644
--- a/doc/build/content/docstrings.html
+++ b/doc/build/content/docstrings.html
@@ -4,7 +4,7 @@
<%namespace name="formatting" file="formatting.html"/>
<%namespace name="nav" file="nav.html"/>
<%namespace name="pydoc" file="pydoc.html"/>
-<%def name="title()">SQLAlchemy 0.3 Documentation - Modules and Classes</%def>
+<%def name="title()">SQLAlchemy 0.4 Documentation - Modules and Classes</%def>
<%!
filename = 'docstrings'
diff --git a/doc/build/content/plugins.txt b/doc/build/content/plugins.txt
index b7fc74fdb..dbc85a6f9 100644
--- a/doc/build/content/plugins.txt
+++ b/doc/build/content/plugins.txt
@@ -281,84 +281,249 @@ To continue the `MyClass` example:
**Author:** Mike Bayer and Jason Kirtland<br/>
**Version:** 0.3.1 or greater
-`associationproxy` is used to create a transparent proxy to the associated object in an association relationship, thereby decreasing the verbosity of the pattern in cases where explicit access to the association object is not required. The association relationship pattern is a richer form of a many-to-many relationship, which is described in [datamapping_association](rel:datamapping_association). It is strongly recommended to fully understand the association object pattern in its explicit form before using this extension; see the examples in the SQLAlchemy distribution under the directory `examples/association/`.
+`associationproxy` is used to create a simplified, read/write view of a relationship. It can be used to cherry-pick fields from a collection of related objects or to greatly simplify access to associated objects in an association relationship.
-When dealing with association relationships, the **association object** refers to the object that maps to a row in the association table (i.e. the many-to-many table), while the **associated object** refers to the "endpoint" of the association, i.e. the ultimate object referenced by the parent. The proxy can return collections of objects attached to association objects, and can also create new association objects given only the associated object. An example using the Keyword mapping described in the data mapping documentation is as follows:
+#### Simplifying Relations
{python}
- from sqlalchemy.ext.associationproxy import association_proxy
+ users_table = Table('users', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String(64)),
+ )
+ keywords_table = Table('keywords', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('keyword', String(64))
+ )
+
+ userkeywords_table = Table('userkeywords', metadata,
+ Column('user_id', Integer, ForeignKey("users.id"),
+ primary_key=True),
+ Column('keyword_id', Integer, ForeignKey("keywords.id"),
+ primary_key=True)
+ )
+
class User(object):
- pass
+ def __init__(self, name):
+ self.name = name
class Keyword(object):
+ def __init__(self, keyword):
+ self.keyword = keyword
+
+    mapper(User, users_table, properties={
+        'kw': relation(Keyword, secondary=userkeywords_table)
+    })
+    mapper(Keyword, keywords_table)
+
+Above are three simple tables, modeling users, keywords and a many-to-many relationship between the two. Each ``Keyword`` object is little more than a container for a name, and accessing keywords via the relation is awkward:
+
+ {python}
+ user = User('jek')
+ user.kw.append(Keyword('cheese inspector'))
+ print user.kw
+ # [<__main__.Keyword object at 0xb791ea0c>]
+ print user.kw[0].keyword
+ # 'cheese inspector'
+    print [keyword.keyword for keyword in user.kw]
+ # ['cheese inspector']
+
+With ``association_proxy`` you have a "view" of the relation that contains just the `.keyword` of the related objects. The proxy is a Python property, and unlike the mapper relation, is defined in your class:
+
+ {python}
+ from sqlalchemy.ext.associationproxy import association_proxy
+
+ class User(object):
def __init__(self, name):
- self.keyword_name = name
+ self.name = name
- class Article(object):
- # create "keywords" proxied association.
- # the collection is called 'keyword_associations', the endpoint
- # attribute of each association object is called 'keyword'. the
- # class itself of the association object will be figured out automatically .
- keywords = association_proxy('keyword_associations', 'keyword')
+ # proxy the 'keyword' attribute from the 'kw' relation
+ keywords = association_proxy('kw', 'keyword')
- class KeywordAssociation(object):
- pass
+ # ...
+ >>> user.kw
+ [<__main__.Keyword object at 0xb791ea0c>]
+ >>> user.keywords
+ ['cheese inspector']
+ >>> user.keywords.append('snack ninja')
+ >>> user.keywords
+ ['cheese inspector', 'snack ninja']
+ >>> user.kw
+ [<__main__.Keyword object at 0x9272a4c>, <__main__.Keyword object at 0xb7b396ec>]
- # create mappers normally
- # note that we set up 'keyword_associations' on Article,
- # and 'keyword' on KeywordAssociation.
- mapper(Article, articles_table, properties={
- 'keyword_associations':relation(KeywordAssociation, lazy=False, cascade="all, delete-orphan")
- }
- )
- mapper(KeywordAssociation, itemkeywords_table,
- primary_key=[itemkeywords_table.c.article_id, itemkeywords_table.c.keyword_id],
- properties={
- 'keyword' : relation(Keyword, lazy=False),
- 'user' : relation(User, lazy=False)
- }
+The proxy is read/write. New associated objects are created on demand when values are added to the proxy, and modifying or removing an entry through the proxy also affects the underlying collection.
+
+- The association proxy property is backed by a mapper-defined relation, either a collection or scalar.
+- You can access and modify both the proxy and the backing relation. Changes in one are immediate in the other.
+- The proxy acts like the type of the underlying collection. A list gets a list-like proxy, a dict a dict-like proxy, and so on.
+- Multiple proxies for the same relation are fine.
+- Proxies are lazy, and won't trigger a load of the backing relation until they are accessed.
+- The relation is inspected to determine the type of the related objects.
+- To construct new instances, the type is called with the value being assigned, or key and value for dicts.
+- A ``creator`` function can be used to create instances instead.
+
+Above, the ``Keyword.__init__`` takes a single argument ``keyword``, which maps conveniently to the value being set through the proxy. A ``creator`` function could have been used instead if more flexibility was required, as sketched below.
+
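+Here's a minimal sketch of such a ``creator`` (the `_lowercase_keyword` function is hypothetical; `Keyword` and the `kw` relation are as above):
+
+    {python}
+    def _lowercase_keyword(value):
+        # hypothetical creator: normalize the value before
+        # constructing the associated Keyword
+        return Keyword(value.lower())
+
+    class User(object):
+        def __init__(self, name):
+            self.name = name
+        keywords = association_proxy('kw', 'keyword', creator=_lowercase_keyword)
+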
+Because the proxies are backed by a regular relation collection, all of the usual hooks and patterns for using collections are still in effect. The most convenient behavior is the automatic setting of "parent"-type relationships on assignment. In the example above, nothing special had to be done to associate the Keyword to the User: simply adding it to the collection is sufficient.
+
+#### Simplifying Association Object Relations
+
+Association proxies are also useful for keeping [association objects](rel:datamapping_association) out of the way during regular use. For example, the ``userkeywords`` table might have a bunch of auditing columns that need to get updated when changes are made: columns that are updated often but seldom, if ever, accessed in your application. A proxy can provide a very natural access pattern for the relation.
+
+ {python}
+ from sqlalchemy.ext.associationproxy import association_proxy
+
+ # users_table and keywords_table tables as above, then:
+
+ userkeywords_table = Table('userkeywords', metadata,
+ Column('user_id', Integer, ForeignKey("users.id"), primary_key=True),
+ Column('keyword_id', Integer, ForeignKey("keywords.id"), primary_key=True),
+ # add some auditing columns
+ Column('updated_at', DateTime, default=datetime.now),
+ Column('updated_by', Integer, default=get_current_uid, onupdate=get_current_uid),
)
- mapper(User, users_table)
+
+ def _create_uk_by_keyword(keyword):
+ """A creator function."""
+ return UserKeyword(keyword=keyword)
+
+ class User(object):
+ def __init__(self, name):
+ self.name = name
+ keywords = association_proxy('user_keywords', 'keyword', creator=_create_uk_by_keyword)
+
+ class Keyword(object):
+ def __init__(self, keyword):
+ self.keyword = keyword
+ def __repr__(self):
+ return 'Keyword(%s)' % repr(self.keyword)
+
+ class UserKeyword(object):
+ def __init__(self, user=None, keyword=None):
+ self.user = user
+ self.keyword = keyword
+
+ mapper(User, users_table, properties={
+ 'user_keywords': relation(UserKeyword)
+ })
mapper(Keyword, keywords_table)
+ mapper(UserKeyword, userkeywords_table, properties={
+ 'user': relation(User),
+ 'keyword': relation(Keyword),
+ })
- # now, Keywords can be attached to an Article directly;
- # KeywordAssociation will be created by the association_proxy, and have the
- # 'keyword' attribute set to the new Keyword.
- # note that these KeywordAssociation objects will not have a User attached to them.
- article = Article()
- article.keywords.append(Keyword('blue'))
- article.keywords.append(Keyword('red'))
- session.save(article)
- session.flush()
-
- # the "keywords" collection also returns the underlying Keyword objects
- article = session.query(Article).get_by(id=12)
- for k in article.keywords:
- print "Keyword:", k.keyword_name
- # the original 'keyword_associations' relation exists normally with no awareness of the proxy
- article.keyword_associations.append(KeywordAssociation())
- print [ka for ka in article.keyword_associations]
-
-Note that the above operations on the `keywords` collection are proxying operations to and from the `keyword_associations` collection, which exists normally and can be accessed directly. `association_proxy` will also detect if the collection is list or scalar based and will configure the proxied property to act the same way.
+ user = User('log')
+ kw1 = Keyword('new_from_blammo')
-For the common case where the association object's creation needs to be specified by the application, `association_proxy` takes an optional callable `creator()` which takes a single associated object as an argument, and returns a new association object.
+ # Adding a Keyword requires creating a UserKeyword association object
+ user.user_keywords.append(UserKeyword(user, kw1))
+
+    # And accessing Keywords requires traversing UserKeywords
+ print user.user_keywords[0]
+ # <__main__.UserKeyword object at 0xb79bbbec>
+
+ print user.user_keywords[0].keyword
+ # Keyword('new_from_blammo')
+
+ # Lots of work.
+
+ # It's much easier to go through the association proxy!
+ for kw in (Keyword('its_big'), Keyword('its_heavy'), Keyword('its_wood')):
+ user.keywords.append(kw)
+
+ print user.keywords
+ # [Keyword('new_from_blammo'), Keyword('its_big'), Keyword('its_heavy'), Keyword('its_wood')]
+
+
+#### Building Complex Views
{python}
- def create_keyword_association(keyword):
- ka = KeywordAssociation()
- ka.keyword = keyword
- return ka
-
- class Article(object):
- # create "keywords" proxied association
- keywords = association_proxy('keyword_associations', 'keyword', creator=create_keyword_association)
+    stocks_table = Table("stocks", metadata,
+ Column('symbol', String(10), primary_key=True),
+ Column('description', String(100), nullable=False),
+ Column('last_price', Numeric)
+ )
+
+    brokers_table = Table("brokers", metadata,
+        Column('id', Integer, primary_key=True),
+ Column('name', String(100), nullable=False)
+ )
+
+    holdings_table = Table("holdings", metadata,
+ Column('broker_id', Integer, ForeignKey('brokers.id'), primary_key=True),
+ Column('symbol', String(10), ForeignKey('stocks.symbol'), primary_key=True),
+ Column('shares', Integer)
+ )
+
+Above are three tables, modeling stocks, their brokers and the number of shares of a stock held by each broker. This situation is quite different from the association example above. `shares` is a _property of the relation_, an important one that we need to use all the time.
-Proxy properties are implemented by the `AssociationProxy` class, which is
-also available in the module. The `association_proxy` function is not present
-in SQLAlchemy versions 0.3.1 through 0.3.7, instead instantiate the class
-directly:
+For this example, it would be very convenient if `Broker` objects had a dictionary collection that mapped `Stock` instances to the shares held for each. That's easy.
+
+ {python}
+ from sqlalchemy.ext.associationproxy import association_proxy
+ from sqlalchemy.orm.collections import attribute_mapped_collection
+
+ def _create_holding(stock, shares):
+ """A creator function, constructs Holdings from Stock and share quantity."""
+ return Holding(stock=stock, shares=shares)
+
+ class Broker(object):
+ def __init__(self, name):
+ self.name = name
+
+ holdings = association_proxy('by_stock', 'shares', creator=_create_holding)
+
+ class Stock(object):
+ def __init__(self, symbol, description=None):
+ self.symbol = symbol
+ self.description = description
+ self.last_price = 0
+
+ class Holding(object):
+ def __init__(self, broker=None, stock=None, shares=0):
+ self.broker = broker
+ self.stock = stock
+ self.shares = shares
+
+ mapper(Stock, stocks_table)
+ mapper(Broker, brokers_table, properties={
+ 'by_stock': relation(Holding,
+ collection_class=attribute_mapped_collection('stock'))
+ })
+ mapper(Holding, holdings_table, properties={
+ 'stock': relation(Stock),
+ 'broker': relation(Broker)
+ })
+
+Above, we've set up the 'by_stock' relation collection to act as a dictionary, using the `.stock` property of each Holding as a key.
+
+Populating and accessing that dictionary manually is slightly inconvenient because of the complexity of the `Holding` association object:
+
+ {python}
+ stock = Stock('ZZK')
+ broker = Broker('paj')
+
+    broker.by_stock[stock] = Holding(broker, stock, 10)
+    print broker.by_stock[stock].shares
+ # 10
+
+The `holdings` proxy we've added to the `Broker` class hides the details of the `Holding` while also giving access to `.shares`:
+
+ {python}
+ for stock in (Stock('JEK'), Stock('STPZ')):
+ broker.holdings[stock] = 123
+
+ for stock, shares in broker.holdings.items():
+ print stock, shares
+
+    # let's take a peek at that holdings_table after committing changes to the db
+ print list(holdings_table.select().execute())
+    # [(1, 'ZZK', 10), (1, 'JEK', 123), (1, 'STPZ', 123)]
+
+Further examples can be found in the `examples/` directory in the SQLAlchemy distribution.
+
+The `association_proxy` convenience function is not present in SQLAlchemy versions 0.3.1 through 0.3.7; instead, instantiate the class directly:
{python}
from sqlalchemy.ext.associationproxy import AssociationProxy
@@ -374,7 +539,7 @@ directly:
`orderinglist` is a helper for mutable ordered relations. It will intercept
list operations performed on a relation collection and automatically
synchronize changes in list position with an attribute on the related objects.
-(See [advdatamapping_properties_customlist](rel:advdatamapping_properties_customlist) for more information on the general pattern.)
+(See [advdatamapping_properties_entitycollections](rel:advdatamapping_properties_customcollections) for more information on the general pattern.)
Example: Two tables that store slides in a presentation. Each slide
has a number of bullet points, displayed in order by the 'position'
@@ -447,81 +612,6 @@ documentation](rel:docstrings_sqlalchemy.ext.orderinglist) for more
information, and also check out the unit tests for examples of stepped
numbering, alphabetical and Fibonacci numbering.
-### threadlocal
-
-**Author:** Mike Bayer and Daniel Miller
-
-`threadlocal` is an extension that was created primarily to provide backwards compatibility with the older SQLAlchemy 0.1 series. It uses three features which SQLAlchemy 0.2 and above provide as distinct features: `SessionContext`, `assign_mapper`, and the `TLEngine`, which is the `Engine` used with the threadlocal `create_engine()` strategy. It is **strongly** recommended that these three features are understood individually before using threadlocal.
-
-In SQLAlchemy 0.1, users never dealt with explcit connections and didn't have a very explicit `Session` interface, instead relying upon a more magical global object called `objectstore`. The `objectstore` idea was wildly popular with about half of SA's users, and completely unpopular with the other half. The threadlocal mod basically brings back `objectstore`, which is in fact just a `SessionContext` where you can call `Session` methods directly off of it, instead of saying `context.current`. For `threadlocal` to faithfully produce 0.1 behavior, it is invoked as a *mod* which globally installs the objectstore's mapper extension, such that all `Mapper`s will automatically assign all new instances of mapped classes to the objectstore's contextual `Session`. Additionally, it also changes the default engine strategy used by `create_engine` to be the "threadlocal" strategy, which in normal practice does not affect much.
-
-When you import threadlocal, what you get is:
-
-* the "objectstore" session context object is now added to the `sqlalchemy` namespace.
-* a global `MapperExtension` is set up for all mappers which assigns "objectstore"'s session as the default session context, used by new instances as well as `Query` objects (see the section [plugins_sessioncontext_sessioncontextext](rel:plugins_sessioncontext_sessioncontextext)).
-* a new function "assign_mapper" is added to the `sqlalchemy` namespace, which calls the `assignmapper` mapper function using the new "objectstore" context.
-* the `create_engine` function is modified so that "threadlocal", and not "plain", is the default engine strategy.
-
-So an important point to understand is, **don't use the threadlocal mod unless you explcitly are looking for that behavior**. Unfortunately, the easy import of the "threadlocal" mod has found its way into several tutorials on external websites, which produces application-wide behavior that is in conflict with the SQLAlchemy tutorial and data mapping documentation.
-
-While "threadlocal" is only about 10 lines of code, it is strongly advised that users instead make usage of `SessionContext` and `assign_mapper` explictly to eliminate confusion. Additionally, the "threadlocal" strategy on `create_engine()` also exists primarily to provide patterns used in 0.1 and is probably not worth using either, unless you specifically need those patterns.
-
-Basic usage of threadlocal involves importing the mod, *before* any usage of the `sqlalchemy` namespace, since threadlocal is going to add the "objectstore" and "assign_mapper" keywords to "sqlalchemy".
-
-To use `objectstore`:
-
- {python}
- import sqlalchemy.mods.threadlocal
- from sqlalchemy import *
-
- metadata = MetaData('sqlite:///')
- user_table = Table('users', metadata,
- Column('user_id', Integer, primary_key=True),
- Column('user_name', String(50), nullable=False)
- )
-
- class User(object):
- pass
- mapper(User, user_table)
-
- # "user" object is added to the session automatically
- user = User()
-
- # flush the contextual session
- objectstore.flush()
-
-The actual `Session` is available as:
-
- {python}
- objectstore.get_session()
-
-To use `assign_mapper`:
-
- {python}
- import sqlalchemy.mods.threadlocal
- from sqlalchemy import *
-
- metadata = MetaData('sqlite:///')
- user_table = Table('users', metadata,
- Column('user_id', Integer, primary_key=True),
- Column('user_name', String(50), nullable=False)
- )
-
- class User(object):
- pass
-
- # note that no "context" object is needed
- assign_mapper(User, user_table)
-
- # call up a user
- user = User.selectfirst(user_table.c.user_id==7)
-
- # call 'delete' on the user
- user.delete()
-
- # flush
- objectstore.flush()
-
### ActiveMapper
@@ -598,15 +688,3 @@ SqlSoup creates mapped classes on the fly from tables, which are automatically r
Full SqlSoup documentation is on the [SQLAlchemy Wiki](http://www.sqlalchemy.org/trac/wiki/SqlSoup).
-### ProxyEngine
-
-**Author:** Jason Pellerin
-
-The `ProxyEngine` is used to "wrap" an `Engine`, and via subclassing `ProxyEngine` one can instrument the functionality of an arbitrary `Engine` instance through the decorator pattern. It also provides a `connect()` method which will send all `Engine` requests to different underlying engines. Its functionality in that regard is largely superceded now by `MetaData` which is a better solution.
-
- {python}
- from sqlalchemy.ext.proxy import ProxyEngine
- proxy = ProxyEngine()
-
- proxy.connect('postgres://user:pw@host/db')
-
diff --git a/doc/build/content/sqlconstruction.txt b/doc/build/content/sqlconstruction.txt
index 2c3a68aea..a672fb5ce 100644
--- a/doc/build/content/sqlconstruction.txt
+++ b/doc/build/content/sqlconstruction.txt
@@ -318,7 +318,7 @@ Select statements can also generate a WHERE clause based on the parameters you g
#### Operators {@name=operators}
-Supported column operators so far are all the numerical comparison operators, i.e. '==', '>', '>=', etc., as well as `like()`, `startswith()`, `endswith()`, `between()`, and `in_()`. Boolean operators include `not_()`, `and_()` and `or_()`, which also can be used inline via '~', '&amp;', and '|'. Math operators are '+', '-', '*', '/'. Any custom operator can be specified via the `op()` function shown below.
+Supported column operators so far are all the numerical comparison operators, i.e. '==', '>', '>=', etc., as well as `like()`, `startswith()`, `endswith()`, `between()`, and `in_()`. Boolean operators include `not_()`, `and_()` and `or_()`, which also can be used inline via '~', '&amp;', and '|'. Math operators are '+', '-', '*', '/'. Any custom operator can be specified via the `op()` function shown below.
{python}
# "like" operator
diff --git a/doc/build/content/tutorial.txt b/doc/build/content/tutorial.txt
index 93f3c5064..d2077043d 100644
--- a/doc/build/content/tutorial.txt
+++ b/doc/build/content/tutorial.txt
@@ -111,7 +111,7 @@ With `metadata` as our established home for tables, lets make a Table for it:
As you might have guessed, we have just defined a table named `users` which has three columns: `user_id` (which is a primary key column), `user_name` and `password`. Currently it is just an object that doesn't necessarily correspond to an existing table in our database. To actually create the table, we use the `create()` method. To make it interesting, we will have SQLAlchemy echo the SQL statements it sends to the database, by setting the `echo` flag on the `Engine` associated with our `MetaData`:
{python}
- >>> metadata.engine.echo = True
+ >>> metadata.bind.echo = True
>>> users_table.create() # doctest:+ELLIPSIS,+NORMALIZE_WHITESPACE
CREATE TABLE users (
user_id INTEGER NOT NULL,
@@ -137,7 +137,7 @@ Inserting is achieved via the `insert()` method, which defines a *clause object*
{python}
>>> i = users_table.insert()
>>> i # doctest:+ELLIPSIS
- <sqlalchemy.sql._Insert object at 0x...>
+ <sqlalchemy.sql.Insert object at 0x...>
>>> # the string form of the Insert object is a generic SQL representation
>>> print i
INSERT INTO users (user_id, user_name, password) VALUES (?, ?, ?)
@@ -210,7 +210,7 @@ You can see that when we print out the rows returned by an execution result, it
{python}
>>> row.keys()
- ['user_id', 'user_name', 'password']
+ [u'user_id', u'user_name', u'password']
>>> row['user_id'], row[1], row[users_table.c.password]
(4, u'Harry', None)
@@ -333,7 +333,7 @@ The Session has all kinds of methods on it to manage and inspect its collection
{python}
>>> query = session.query(User)
>>> print query.filter_by(user_name='Harry').all()
- SELECT users.user_name AS users_user_name, users.password AS users_password, users.user_id AS users_user_id
+ SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, users.password AS users_password
FROM users
WHERE users.user_name = ? ORDER BY users.oid
['Harry']
@@ -344,7 +344,7 @@ All querying for objects is performed via an instance of `Query`. The various `
Let's turn off the database echoing for a moment, and try out a few methods on `Query`. The two methods used to narrow results are `filter()` and `filter_by()`, and the two most common methods used to load results are `all()` and `first()`. The `get()` method is used for a quick lookup by primary key. `filter_by()` works with keyword arguments, and `filter()` works with `ClauseElement` objects, which are constructed by using `Column` objects inside of Python expressions, in the same way as we did with our SQL select example in the previous section of this tutorial. Using `ClauseElement` structures to query objects is more verbose but more flexible:
{python}
- >>> metadata.engine.echo = False
+ >>> metadata.bind.echo = False
>>> print query.filter(User.c.user_id==3).all()
[User(u'Fred',None)]
>>> print query.get(2)
@@ -400,7 +400,7 @@ With a new user "ed" and some changes made on "Mary" and "Harry", lets also mark
Then to send all of our changes to the database, we `flush()` the Session. Lets turn echo back on to see this happen!:
{python}
- >>> metadata.engine.echo = True
+ >>> metadata.bind.echo = True
>>> session.flush()
BEGIN
UPDATE users SET password=? WHERE users.user_id = ?
@@ -444,6 +444,10 @@ We then create a mapper for the `User` class which contains a relationship to th
... })
<sqlalchemy.orm.mapper.Mapper object at 0x...>
+Since we've made new mappers, we have to throw away the old `Query` object and get a new one:
+
+    {python}
+    >>> query = session.query(User)
+
The `relation()` function takes either a class or a Mapper as its first argument, and has many options to further control its behavior. When this mapping relationship is used, each new `User` instance will contain an attribute called `addresses`. SQLAlchemy will automatically determine that this relationship is a one-to-many relationship, and will subsequently create `addresses` as a list. When a new `User` is created, this list will begin as empty.
The order in which the mapping definitions for `User` and `Address` is created is *not significant*. When the `mapper()` function is called, it creates an *uncompiled* mapping record corresponding to the given class/table combination. When the mappers are first used, the entire collection of mappers created up until that point will be compiled, which involves the establishment of class instrumentation as well as the resolution of all mapping relationships.
@@ -457,13 +461,13 @@ We can then treat the `addresses` attribute on each `User` object like a regular
{python}
>>> mary = query.filter_by(user_name='Mary').first() # doctest: +NORMALIZE_WHITESPACE
- SELECT users.user_name AS users_user_name, users.password AS users_password, users.user_id AS users_user_id
+ SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, users.password AS users_password
FROM users
WHERE users.user_name = ? ORDER BY users.oid
LIMIT 1 OFFSET 0
['Mary']
>>> print [a for a in mary.addresses]
- SELECT email_addresses.user_id AS email_addresses_user_id, email_addresses.address_id AS email_addresses_address_id, email_addresses.email_address AS email_addresses_email_address
+ SELECT email_addresses.address_id AS email_addresses_address_id, email_addresses.email_address AS email_addresses_email_address, email_addresses.user_id AS email_addresses_user_id
FROM email_addresses
WHERE ? = email_addresses.user_id ORDER BY email_addresses.oid
[1]
@@ -485,50 +489,22 @@ Main documentation for using mappers: [datamapping](rel:datamapping)
### Transactions
-You may have noticed from the example above that when we say `session.flush()`, SQLAlchemy indicates the names `BEGIN` and `COMMIT` to indicate a transaction with the database. The `flush()` method, since it may execute many statements in a row, will automatically use a transaction in order to execute these instructions. But what if we want to use `flush()` inside of a larger transaction? This is performed via the `SessionTransaction` object, which we can establish using `session.create_transaction()`. Below, we will perform a more complicated `SELECT` statement, make several changes to our collection of users and email addresess, and then create a new user with two email addresses, within the context of a transaction. We will perform a `flush()` in the middle of it to write the changes we have so far, and then allow the remaining changes to be written when we finally `commit()` the transaction. We enclose our operations within a `try/except` block to ensure that resources are properly freed:
+You may have noticed from the example above that when we say `session.flush()`, SQLAlchemy echoes `BEGIN` and `COMMIT`, indicating a transaction with the database. The `flush()` method, since it may execute many statements in a row, will automatically use a transaction in order to execute these instructions. But what if we want to use `flush()` inside of a larger transaction? The easiest way is to use a "transactional" session; that is, when the session is created, you're automatically in a transaction which you can commit or roll back at any time. As a bonus, such a session can call `flush()` for you whenever a query is issued, so that whatever changes you've made are immediately visible to queries (and since it's all in a transaction, nothing gets committed until you tell it so).
+
+Below, we create a session with `autoflush=True`, which implies that it's transactional. We can query for things as soon as they are created without the need for calling `flush()`. At the end, we call `commit()` to persist everything permanently.
{python}
- >>> transaction = session.create_transaction()
- >>> try: # doctest: +NORMALIZE_WHITESPACE
- ... (ed, harry, mary) = session.query(User).filter(
+ >>> metadata.bind.echo = False
+ >>> session = create_session(autoflush=True)
+ >>> (ed, harry, mary) = session.query(User).filter(
... User.c.user_name.in_('Ed', 'Harry', 'Mary')
- ... ).order_by(User.c.user_name).all()
- ... del mary.addresses[1]
- ... harry.addresses.append(Address('harry2@gmail.com'))
- ... session.flush()
- ... print "***flushed the session***"
- ... fred = User()
- ... fred.user_name = 'fred_again'
- ... fred.addresses.append(Address('fred@fred.com'))
- ... fred.addresses.append(Address('fredsnewemail@fred.com'))
- ... session.save(fred)
- ... transaction.commit()
- ... except:
- ... transaction.rollback()
- ... raise
- BEGIN
- SELECT users.user_name AS users_user_name, users.password AS users_password, users.user_id AS users_user_id
- FROM users
- WHERE users.user_name IN (?, ?, ?) ORDER BY users.user_name
- ['Ed', 'Harry', 'Mary']
- SELECT email_addresses.user_id AS email_addresses_user_id, email_addresses.address_id AS email_addresses_address_id, email_addresses.email_address AS email_addresses_email_address
- FROM email_addresses
- WHERE ? = email_addresses.user_id ORDER BY email_addresses.oid
- [4]
- UPDATE email_addresses SET user_id=? WHERE email_addresses.address_id = ?
- [None, 3]
- INSERT INTO email_addresses (email_address, user_id) VALUES (?, ?)
- ['harry2@gmail.com', 4]
- ***flushed the session***
- INSERT INTO users (user_name, password) VALUES (?, ?)
- ['fred_again', None]
- INSERT INTO email_addresses (email_address, user_id) VALUES (?, ?)
- ['fred@fred.com', 6]
- INSERT INTO email_addresses (email_address, user_id) VALUES (?, ?)
- ['fredsnewemail@fred.com', 6]
- COMMIT
-
-The `SessionTransaction` process above is due to be greatly simplified in version 0.4 of SQLAlchemy, where the `Session` will be able to wrap its whole lifespan in a transaction automatically.
+ ... ).order_by(User.c.user_name).all() # doctest: +NORMALIZE_WHITESPACE
+ >>> del mary.addresses[1]
+ >>> harry_address = Address('harry2@gmail.com')
+ >>> harry.addresses.append(harry_address)
+ >>> session.query(User).join('addresses').filter_by(email_address='harry2@gmail.com').first() # doctest: +NORMALIZE_WHITESPACE
+ User(u'Harry',u'harrysnewpassword')
+ >>> session.commit()
Main documentation: [unitofwork](rel:unitofwork)
diff --git a/doc/build/content/unitofwork.txt b/doc/build/content/unitofwork.txt
index 4ae2c3c91..ef0118901 100644
--- a/doc/build/content/unitofwork.txt
+++ b/doc/build/content/unitofwork.txt
@@ -14,13 +14,11 @@ SQLAlchemy's unit of work includes these functions:
* The ability to maintain and process a list of modified objects, and based on the relationships set up by the mappers for those objects as well as the foreign key relationships of the underlying tables, figure out the proper order of operations so that referential integrity is maintained, and also so that on-the-fly values such as newly created primary keys can be propagated to dependent objects that need them before they are saved. The central algorithm for this is the *topological sort*.
* The ability to define custom functionality that occurs within the unit-of-work flush phase, such as "before insert", "after insert", etc. This is accomplished via MapperExtension; a sketch appears after this list.
* an Identity Map, which is a dictionary storing the one and only instance of an object for a particular table/primary key combination. This allows many parts of an application to get a handle to a particular object without any chance of modifications going to two different places.
-* The sole interface to the unit of work is provided via the `Session` object. Transactional capability, which rides on top of the transactions provided by `Engine` objects, is provided by the `SessionTransaction` object.
-* Thread-locally scoped Session behavior is available as an option, which allows new objects to be automatically added to the Session corresponding to by the *default Session context*. Without a default Session context, an application must explicitly create a Session manually as well as add new objects to it. The default Session context, disabled by default, can also be plugged in with other user-defined schemes, which may also take into account the specific class being dealt with for a particular operation.
-* The Session object borrows conceptually from that of [Hibernate](http://www.hibernate.org), a leading ORM for Java that was a great influence on the creation of the [JSR-220](http://jcp.org/aboutJava/communityprocess/pfd/jsr220/index.html) specification. SQLAlchemy, under no obligation to conform to EJB specifications, is in general very different from Hibernate, providing a different paradigm for producing queries, a SQL API that is useable independently of the ORM, and of course Pythonic configuration as opposed to XML; however, JSR-220/Hibernate makes some pretty good suggestions with regards to the mechanisms of persistence.
+* The sole interface to the unit of work is provided via the `Session` object. Transactional capability is included.
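+
+A rough sketch of such a flush-phase hook (the `AuditedWidget` class and `widgets_table` here are hypothetical; assumes a mapped `created_at` column):
+
+    {python}
+    from datetime import datetime
+    from sqlalchemy.orm import MapperExtension, EXT_CONTINUE
+
+    class TimestampExtension(MapperExtension):
+        def before_insert(self, mapper, connection, instance):
+            # stamp the row just before the unit of work INSERTs it
+            instance.created_at = datetime.now()
+            return EXT_CONTINUE
+
+    mapper(AuditedWidget, widgets_table, extension=TimestampExtension())
+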
### Object States {@name=states}
-When dealing with mapped instances with regards to Sessions, an instance may be *attached* or *unattached* to a particular Session. An instance also may or may not correspond to an actual row in the database. The product of these two binary conditions yields us four general states a particular instance can have within the perspective of the Session:
+When dealing with mapped instances with regards to Sessions, an instance may be *attached* or *unattached* to a particular Session. An instance also may or may not correspond to an actual row in the database. These conditions break up into four distinct states:
* *Transient* - a transient instance exists within memory only and is not associated with any Session. It also has no database identity and does not have a corresponding record in the database. When a new instance of a class is constructed, and no default session context exists with which to automatically attach the new instance, it is a transient instance. The instance can then be saved to a particular session in which case it becomes a *pending* instance. If a default session context exists, new instances are added to that Session by default and therefore become *pending* instances immediately.
@@ -216,6 +214,8 @@ It also can be called with a list of objects; in this form, the flush operation
This second form of flush should be used carefully as it will not necessarily locate other dependent objects within the session, whose database representation may have foreign constraint relationships with the objects being operated upon.
+There's also a way to have `flush()` called automatically before each query; this is called "autoflush" and is described below.
+
##### Notes on Flush {@name=whatis}
A common misconception about the `flush()` operation is that once performed, the newly persisted instances will automatically have related objects attached to them, based on the values of primary key identities that have been assigned to the instances before they were persisted. An example would be, you create a new `Address` object, set `address.user_id` to 5, and then `flush()` the session. The erroneous assumption would be that there is now a `User` object of identity "5" attached to the `Address` object, but in fact this is not the case. If you were to `refresh()` the `Address`, invalidating its current state and re-loading, *then* it would have the appropriate `User` object present.
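+
+A short sketch of the distinction described above (hypothetical; assumes the tutorial's `User`/`Address` mapping and an existing session `sess`):
+
+    {python}
+    address = Address('someuser@example.com')
+    address.user_id = 5
+    sess.save(address)
+    sess.flush()
+    # flush() does not populate the in-memory 'user' relation
+    assert address.user is None
+    # refresh() invalidates the instance's state and re-loads it
+    sess.refresh(address)
+    assert address.user is not None
+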
@@ -350,95 +350,144 @@ Note that cascading doesn't do anything that isn't possible by manually calling
The default value for `cascade` on `relation()`s is `save-update`, and the `private=True` keyword argument is a synonym for `cascade="all, delete-orphan"`.
-### SessionTransaction {@name=transaction}
+### Using Session Transactions {@name=transaction}
+
+The Session can manage transactions automatically, including across multiple engines. When the Session is in a transaction, as it receives requests to execute SQL statements, it adds each individual Connection/Engine encountered to its transactional state. At commit time, all unflushed data is flushed, and each individual transaction is committed. If the underlying databases support two-phase semantics, the Session may make use of this as well, when two-phase transactions are enabled.
-SessionTransaction is a multi-engine transaction manager, which aggregates one or more Engine/Connection pairs and keeps track of a Transaction object for each one. As the Session receives requests to execute SQL statements, it uses the Connection that is referenced by the SessionTransaction. At commit time, the underyling Session is flushed, and each Transaction is the committed.
+The easiest way to use a Session with transactions is just to declare it as transactional. The session will remain in a transaction at all times:
-Example usage is as follows:
+ {python}
+ sess = create_session(transactional=True)
+ item1 = sess.query(Item).get(1)
+ item2 = sess.query(Item).get(2)
+ item1.foo = 'bar'
+ item2.bar = 'foo'
+
+ # commit- will immediately go into a new transaction afterwards
+ sess.commit()
+
+Alternatively, a transaction can be begun explicitly using `begin()`:
{python}
sess = create_session()
- trans = sess.create_transaction()
+ sess.begin()
try:
item1 = sess.query(Item).get(1)
item2 = sess.query(Item).get(2)
item1.foo = 'bar'
item2.bar = 'foo'
except:
- trans.rollback()
+ sess.rollback()
raise
- trans.commit()
+ sess.commit()
-The SessionTransaction object supports Python 2.5's with statement so that the example above can be written as:
+Session also supports Python 2.5's `with` statement, so that the example above can be written as:
{python}
sess = create_session()
- with sess.create_transaction():
+ with sess.begin():
item1 = sess.query(Item).get(1)
item2 = sess.query(Item).get(2)
item1.foo = 'bar'
item2.bar = 'foo'
-The `create_transaction()` method creates a new SessionTransaction object but does not declare any connection/transaction resources. At the point of the first `get()` call, a connection resource is opened off the engine that corresponds to the Item classes' mapper and is stored within the `SessionTransaction` with an open `Transaction`. When `trans.commit()` is called, the `flush()` method is called on the `Session` and the corresponding update statements are issued to the database within the scope of the transaction already opened; afterwards, the underying Transaction is committed, and connection resources are freed.
+For MySQL and Postgres (and soon Oracle), "nested" transactions which use SAVEPOINT behavior are available via the `begin_nested()` method:
-`SessionTransaction`, like the `Transaction` off of `Connection` also supports "nested" behavior, and is safe to pass to other functions which then issue their own `begin()`/`commit()` pair; only the outermost `begin()`/`commit()` pair actually affects the transaction, and any call to `rollback()` within a particular call stack will issue a rollback.
+ {python}
+ sess = create_session()
+ sess.begin()
+ sess.save(u1)
+ sess.save(u2)
+ sess.flush()
-Note that while SessionTransaction is capable of tracking multiple transactions across multiple databases, it currently is in no way a fully functioning two-phase commit engine; generally, when dealing with multiple databases simultaneously, there is the distinct possibility that a transaction can succeed on the first database and fail on the second, which for some applications may be an invalid state. If this is an issue, its best to either refrain from spanning transactions across databases, or to look into some of the available technologies in this area, such as [Zope](http://www.zope.org) which offers a two-phase commit engine; some users have already created their own SQLAlchemy/Zope hybrid implementations to deal with scenarios like these.
+ sess.begin_nested() # establish a savepoint
+ sess.save(u3)
+ sess.rollback() # rolls back u3, keeps u1 and u2
-SessionTransaction Facts:
+ sess.commit() # commits u1 and u2
- * SessionTransaction, like its parent Session object, is **not threadsafe**.
- * SessionTransaction will no longer be necessary in SQLAlchemy 0.4, where its functionality is to be merged with the Session itself.
-
-#### Using SQL with SessionTransaction {@name=sql}
+Finally, for MySQL, Postgres, and soon Oracle as well, the session can be instructed to use two-phase commit semantics using the flag `twophase=True`, which coordinates transactions across multiple databases:
-The SessionTransaction can interact with direct SQL queries in two general ways. Either specific `Connection` objects can be associated with the `SessionTransaction`, which are then useable both for direct SQL as well as within `flush()` operations performed by the `SessionTransaction`, or via accessing the `Connection` object automatically referenced within the `SessionTransaction`.
+ {python}
+ engine1 = create_engine('postgres://db1')
+ engine2 = create_engine('postgres://db2')
+
+ sess = create_session(twophase=True, transactional=True)
+
+ # bind User operations to engine 1
+ sess.bind_mapper(User, engine1)
+
+ # bind Account operations to engine 2
+ sess.bind_mapper(Account, engine2)
+
+ # .... work with accounts and users
+
+ # commit. session will issue a flush to all DBs, and a prepare step to all DBs,
+ # before committing both transactions
+ sess.commit()
+
+#### AutoFlush {@name=autoflush}
-To associate a specific `Connection` with the `SessionTransaction`, use the `add()` method:
+A transactional session can also conveniently issue `flush()` calls before each query. This makes whatever you've saved to the session immediately available to subsequent queries. Creating the session with `autoflush=True` implies `transactional=True`:
- {python title="Associate a Connection with the SessionTransaction"}
- connection = engine.connect()
- trans = session.create_transaction()
- try:
- trans.add(connection)
- connection.execute(mytable.update(), {'col1':4, 'col2':17})
- session.flush() # flush() operation will use the same connection
- except:
- trans.rollback()
- raise
- trans.commit()
+ {python}
+ sess = create_session(autoflush=True)
+ u1 = User(name='jack')
+ sess.save(u1)
+
+ # reload user1
+ u2 = sess.query(User).filter_by(name='jack').one()
+ assert u2 is u1
-The `add()` method will key the `Connection`'s underlying `Engine` to this `SessionTransaction`. When mapper operations are performed against this `Engine`, the `Connection` explicitly added will be used. This **overrides** any other `Connection` objects that the underlying Session was associated with, corresponding to the underlying `Engine` of that `Connection`. However, if the `SessionTransaction` itself is already associated with a `Connection`, then an exception is thrown.
+ # commit session, flushes whatever is remaining
+ sess.commit()
+
+#### Using SQL with Sessions and Transactions {@name=sql}
-The other way is just to use the `Connection` referenced by the `SessionTransaction`. This is performed via the `connection()` method, and requires passing in a class or `Mapper` which indicates which underlying `Connection` should be returned (recall that different `Mappers` may use different underlying `Engines`). If the `class_or_mapper` argument is `None`, then the `Session` must be globally bound to a specific `Engine` when it was constructed, else the method returns `None`.
+SQL constructs and string statements can be executed via the `Session`. Normally you'd do this when your `Session` is transactional and you'd like your free-standing SQL statements to participate in the same transaction.
- {python title="Get a Connection from the SessionTransaction"}
- trans = session.create_transaction()
- try:
- connection = trans.connection(UserClass) # get the Connection used by the UserClass' Mapper
- connection.execute(mytable.update(), {'col1':4, 'col2':17})
- except:
- trans.rollback()
- raise
- trans.commit()
-
-The `connection()` method also exists on the `Session` object itself, and can be called regardless of whether or not a `SessionTransaction` is in progress. If a `SessionTransaction` is in progress, it will return the connection referenced by the transaction. If an `Engine` is being used with `threadlocal` strategy, the `Connection` returned will correspond to the connection resources that are bound to the current thread, if any (i.e. it is obtained by calling `contextual_connect()`).
+The two ways to do this are to use the connection/execution services of the Session, or to have your Session participate in a regular SQL transaction.
-#### Using Engine-level Transactions with Sessions
+First, a Session that's associated with an Engine or Connection can execute statements immediately (whether or not it's transactional):
-The transactions issued by `SessionTransaction` as well as internally by the `Session`'s `flush()` operation use the same `Transaction` object off of `Connection` that is publically available. Recall that this object supports "nestable" behavior, meaning any number of actors can call `begin()` off a particular `Connection` object, and they will all be managed within the scope of a single transaction. Therefore, the `flush()` operation can similarly take place within the scope of a regular `Transaction`:
+ {python}
+ sess = create_session(bind=engine, transactional=True)
+ result = sess.execute("select * from table where id=:id", {'id':7})
+ result2 = sess.execute(select([mytable], mytable.c.id==7))
- {python title="Transactions with Sessions"}
- connection = engine.connect() # Connection
- session = create_session(bind=connection) # Session bound to the Connection
- trans = connection.begin() # start transaction
- try:
- stuff = session.query(MyClass).select() # Session operation uses connection
- stuff[2].foo = 'bar'
- connection.execute(mytable.insert(), dict(id=12, value="bar")) # use connection explicitly
- session.flush() # Session flushes with "connection", using transaction "trans"
- except:
- trans.rollback() # or rollback
- raise
- trans.commit() # commit
+To get at the current connection used by the session, which will be part of the current transaction if one is in progress, use `connection()`:
+
+    {python}
+    connection = sess.connection()
+
+A second scenario is that of a Session which is not directly bound to a connectable. This session executes statements relative to a particular `Mapper`, since the mappers are bound to tables which are in turn bound to connectables via their `MetaData` (either the session or the mapped tables need to be bound). In this case, the Session can conceivably be associated with multiple databases through different mappers; so it wants you to send along a `mapper` argument, which can be any mapped class or mapper instance:
+
+ {python}
+ sess = create_session(transactional=True)
+ result = sess.execute("select * from table where id=:id", {'id':7}, mapper=MyMappedClass)
+ result2 = sess.execute(select([mytable], mytable.c.id==7), mapper=MyMappedClass)
+
+ connection = sess.connection(MyMappedClass)
+The third scenario is when you are using `Connection` and `Transaction` yourself, and want the `Session` to participate. This is easy, as you just bind the `Session` to the connection:
+
+ {python}
+ conn = engine.connect()
+ trans = conn.begin()
+ sess = create_session(bind=conn)
+ # ... etc
+ trans.commit()
+
+It's safe to use a `Session` which is transactional or autoflushing here, and to call `begin()`/`commit()` on the session as well; the outermost `Transaction` object, the one we declared explicitly, controls the scope of the transaction.
+
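+A minimal sketch of that layering (assumes `engine` and a mapped `Item` class as in the earlier examples):
+
+    {python}
+    conn = engine.connect()
+    trans = conn.begin()               # outermost transaction
+    sess = create_session(bind=conn)
+    sess.begin()                       # joins the enclosing transaction
+    item = sess.query(Item).get(1)
+    item.foo = 'bar'
+    sess.commit()                      # flushes; COMMIT is deferred
+    trans.commit()                     # the real COMMIT happens here
+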
+When using the `threadlocal` engine context, things are that much easier; the `Session` uses the same connection/transaction as everyone else in the current thread, whether or not you explicitly bind it:
+
+ {python}
+ engine = create_engine('postgres://mydb', strategy="threadlocal")
+ engine.begin()
+
+ sess = create_session() # session takes place in the transaction like everyone else
+
+ # ... go nuts
+
+ engine.commit() # commit
+
diff --git a/doc/build/gen_docstrings.py b/doc/build/gen_docstrings.py
index 6b370fe80..346497d3e 100644
--- a/doc/build/gen_docstrings.py
+++ b/doc/build/gen_docstrings.py
@@ -3,12 +3,13 @@ import docstring
import re
from sqlalchemy import schema, types, ansisql, engine, sql, pool, orm, exceptions, databases
-import sqlalchemy.ext.proxy as proxy
+import sqlalchemy.orm.shard
import sqlalchemy.ext.sessioncontext as sessioncontext
-import sqlalchemy.mods.threadlocal as threadlocal
import sqlalchemy.ext.selectresults as selectresults
import sqlalchemy.ext.orderinglist as orderinglist
import sqlalchemy.ext.associationproxy as associationproxy
+import sqlalchemy.ext.assignmapper as assignmapper
+import sqlalchemy.ext.sqlsoup as sqlsoup
def make_doc(obj, classes=None, functions=None, **kwargs):
"""generate a docstring.ObjectDoc structure for an individual module, list of classes, and list of functions."""
@@ -21,6 +22,7 @@ def make_all_docs():
objects = [
make_doc(obj=sql,include_all_classes=True),
make_doc(obj=schema),
+ make_doc(obj=pool),
make_doc(obj=types),
make_doc(obj=engine),
make_doc(obj=engine.url),
@@ -29,18 +31,21 @@ def make_all_docs():
make_doc(obj=engine.threadlocal),
make_doc(obj=ansisql),
make_doc(obj=orm),
- make_doc(obj=orm.mapperlib, classes=[orm.mapperlib.MapperExtension, orm.mapperlib.Mapper]),
+ make_doc(obj=orm.collections, classes=[orm.collections.collection,
+ orm.collections.MappedCollection,
+ orm.collections.CollectionAdapter]),
make_doc(obj=orm.interfaces),
+ make_doc(obj=orm.mapperlib, classes=[orm.mapperlib.MapperExtension, orm.mapperlib.Mapper]),
+ make_doc(obj=orm.properties),
make_doc(obj=orm.query, classes=[orm.query.Query, orm.query.QueryContext, orm.query.SelectionContext]),
make_doc(obj=orm.session, classes=[orm.session.Session, orm.session.SessionTransaction]),
+ make_doc(obj=orm.shard),
make_doc(obj=exceptions),
- make_doc(obj=pool),
- make_doc(obj=sessioncontext),
- make_doc(obj=threadlocal),
- make_doc(obj=selectresults),
- make_doc(obj=proxy),
- make_doc(obj=orderinglist, classes=[orderinglist.OrderingList]),
+ make_doc(obj=assignmapper),
make_doc(obj=associationproxy, classes=[associationproxy.AssociationProxy]),
+ make_doc(obj=orderinglist, classes=[orderinglist.OrderingList]),
+ make_doc(obj=sessioncontext),
+ make_doc(obj=sqlsoup),
] + [make_doc(getattr(__import__('sqlalchemy.databases.%s' % m).databases, m)) for m in databases.__all__]
return objects
diff --git a/doc/build/genhtml.py b/doc/build/genhtml.py
index 3cb1f7b2c..ddc2e8a92 100644
--- a/doc/build/genhtml.py
+++ b/doc/build/genhtml.py
@@ -24,9 +24,13 @@ files = [
'types',
'pooling',
'plugins',
- 'docstrings'
+ 'docstrings',
]
+post_files = [
+ 'copyright'
+]
+
parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]")
parser.add_option("--file", action="store", dest="file", help="only generate file <file>")
parser.add_option("--docstrings", action="store_true", dest="docstrings", help="only generate docstrings")
@@ -34,9 +38,11 @@ parser.add_option("--version", action="store", dest="version", default=sqlalchem
(options, args) = parser.parse_args()
if options.file:
- files = [options.file]
+ to_gen = [options.file]
+else:
+ to_gen = files + post_files
-title='SQLAlchemy 0.3 Documentation'
+title='SQLAlchemy 0.4 Documentation'
version = options.version
root = toc.TOCElement('', 'root', '', version=version, doctitle=title)
@@ -46,7 +52,7 @@ shutil.copy('./content/docstrings.html', './output/docstrings.html')
shutil.copy('./content/documentation.html', './output/documentation.html')
if not options.docstrings:
- read_markdown.parse_markdown_files(root, files)
+ read_markdown.parse_markdown_files(root, [f for f in files if f in to_gen])
if not options.file or options.docstrings:
docstrings = gen_docstrings.make_all_docs()
@@ -54,8 +60,12 @@ if not options.file or options.docstrings:
pickle.dump(docstrings, file('./output/compiled_docstrings.pickle', 'w'))
- pickle.dump(root, file('./output/table_of_contents.pickle', 'w'))
+if not options.docstrings:
+ read_markdown.parse_markdown_files(root, [f for f in post_files if f in to_gen])
+if not options.file or options.docstrings:
+ pickle.dump(root, file('./output/table_of_contents.pickle', 'w'))
+
template_dirs = ['./templates', './output']
output = os.path.dirname(os.getcwd())
@@ -69,7 +79,7 @@ def genfile(name, outname):
outfile.write(t.render(attributes={}))
if not options.docstrings:
- for filename in files:
+ for filename in to_gen:
try:
genfile(filename, os.path.join(os.getcwd(), '../', filename + ".html"))
except:
diff --git a/doc/build/read_markdown.py b/doc/build/read_markdown.py
index aade38a8c..c80589fc2 100644
--- a/doc/build/read_markdown.py
+++ b/doc/build/read_markdown.py
@@ -1,10 +1,12 @@
"""loads Markdown files, converts each one to HTML and parses the HTML into an ElementTree structure.
The collection of ElementTrees are further parsed to generate a table of contents structure, and are
- manipulated to replace various markdown-generated HTML with specific Myghty tags before being written
- to Myghty templates, which then re-access the table of contents structure at runtime.
+ manipulated to replace various markdown-generated HTML with specific Mako tags before being written
+ to Mako templates, which then re-access the table of contents structure at runtime.
Much thanks to Alexey Shamrin, who came up with the original idea and did all the heavy Markdown/Elementtree
-lifting for this module."""
+lifting for this module.
+"""
+
import sys, re, os
from toc import TOCElement
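(An aside on the pipeline the docstring above describes: a minimal sketch of the
first two stages, Markdown to HTML and HTML to ElementTree, assuming the
python-markdown package this build relies on; the filename is hypothetical.

    import markdown
    from elementtree import ElementTree

    html = markdown.markdown(open('content/tutorial.txt').read())
    # wrap the fragment in a root tag so it parses as a single tree
    tree = ElementTree.fromstring('<html>' + html + '</html>')

The later stages, TOC extraction and rewriting markdown-generated tags into Mako
constructs, then operate on this tree before the templates are written out.)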
diff --git a/examples/adjacencytree/basic_tree.py b/examples/adjacencytree/basic_tree.py
index 9676fae89..53bdc8298 100644
--- a/examples/adjacencytree/basic_tree.py
+++ b/examples/adjacencytree/basic_tree.py
@@ -1,9 +1,12 @@
"""a basic Adjacency List model tree."""
from sqlalchemy import *
+from sqlalchemy.orm import *
from sqlalchemy.util import OrderedDict
+from sqlalchemy.orm.collections import attribute_mapped_collection
-metadata = MetaData('sqlite:///', echo=True)
+metadata = MetaData('sqlite:///')
+metadata.bind.echo = True
trees = Table('treenodes', metadata,
Column('node_id', Integer, Sequence('treenode_id_seq',optional=False), primary_key=True),
@@ -11,17 +14,10 @@ trees = Table('treenodes', metadata,
Column('node_name', String(50), nullable=False),
)
-class NodeList(OrderedDict):
- """subclasses OrderedDict to allow usage as a list-based property."""
- def append(self, node):
- self[node.name] = node
- def __iter__(self):
- return iter(self.values())
class TreeNode(object):
"""a rich Tree class which includes path-based operations"""
def __init__(self, name):
- self.children = NodeList()
self.name = name
self.parent = None
self.id = None
@@ -30,7 +26,7 @@ class TreeNode(object):
if isinstance(node, str):
node = TreeNode(node)
node.parent = self
- self.children.append(node)
+ self.children[node.name] = node
def __repr__(self):
return self._getstring(0, False)
def __str__(self):
@@ -47,7 +43,7 @@ mapper(TreeNode, trees, properties=dict(
id=trees.c.node_id,
name=trees.c.node_name,
parent_id=trees.c.parent_node_id,
- children=relation(TreeNode, cascade="all", backref=backref("parent", remote_side=[trees.c.node_id]), collection_class=NodeList),
+ children=relation(TreeNode, cascade="all", backref=backref("parent", remote_side=[trees.c.node_id]), collection_class=attribute_mapped_collection('name')),
))
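(A note on the collection_class change above: attribute_mapped_collection('name')
makes TreeNode.children act as a dictionary keyed on each child's .name, which is
what the removed NodeList subclass emulated by hand. A minimal sketch reusing the
classes in this example:

    root = TreeNode('rootnode')
    root.append('node1')    # TreeNode.append assigns self.children[node.name] = node
    assert root.children['node1'].name == 'node1'
)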
print "\n\n\n----------------------------"
diff --git a/examples/adjacencytree/byroot_tree.py b/examples/adjacencytree/byroot_tree.py
index 5ec055392..a61bde875 100644
--- a/examples/adjacencytree/byroot_tree.py
+++ b/examples/adjacencytree/byroot_tree.py
@@ -3,7 +3,8 @@ introduces a new selection method which selects an entire tree of nodes at once,
advantage of a custom MapperExtension to assemble incoming nodes into their correct structure."""
from sqlalchemy import *
-from sqlalchemy.util import OrderedDict
+from sqlalchemy.orm import *
+from sqlalchemy.orm.collections import attribute_mapped_collection
engine = create_engine('sqlite:///:memory:', echo=True)
@@ -28,82 +29,69 @@ treedata = Table(
)
-class NodeList(OrderedDict):
- """subclasses OrderedDict to allow usage as a list-based property."""
- def append(self, node):
- self[node.name] = node
- def __iter__(self):
- return iter(self.values())
-
-
class TreeNode(object):
"""a hierarchical Tree class, which adds the concept of a "root node". The root is
the topmost node in a tree, or in other words a node whose parent ID is NULL.
All child nodes that are descendants of a particular root, as well as the root node itself,
- reference this root node.
- this is useful as a way to identify all nodes in a tree as belonging to a single
- identifiable root. Any node can return its root node and therefore the "tree" that it
- belongs to, and entire trees can be selected from the database in one query, by
- identifying their common root ID."""
+    reference this root node."""
def __init__(self, name):
- """for data integrity, a TreeNode requires its name to be passed as a parameter
- to its constructor, so there is no chance of a TreeNode that doesnt have a name."""
self.name = name
- self.children = NodeList()
self.root = self
- self.parent = None
- self.id = None
- self.data =None
- self.parent_id = None
- self.root_id=None
+
def _set_root(self, root):
self.root = root
- for c in self.children:
+ for c in self.children.values():
c._set_root(root)
+
def append(self, node):
if isinstance(node, str):
node = TreeNode(node)
- node.parent = self
node._set_root(self.root)
self.children.append(node)
+
def __repr__(self):
return self._getstring(0, False)
+
def __str__(self):
return self._getstring(0, False)
+
def _getstring(self, level, expand = False):
s = (' ' * level) + "%s (%s,%s,%s, %d): %s" % (self.name, self.id,self.parent_id,self.root_id, id(self), repr(self.data)) + '\n'
if expand:
s += ''.join([n._getstring(level+1, True) for n in self.children.values()])
return s
+
def print_nodes(self):
return self._getstring(0, True)
class TreeLoader(MapperExtension):
- """an extension that will plug-in additional functionality to the Mapper."""
+
def after_insert(self, mapper, connection, instance):
"""runs after the insert of a new TreeNode row. The primary key of the row is not determined
until the insert is complete, since most DBs use autoincrementing columns. If this node is
the root node, we will take the new primary key and update it as the value of the node's
"root ID" as well, since its root node is itself."""
+
if instance.root is instance:
connection.execute(mapper.mapped_table.update(TreeNode.c.id==instance.id, values=dict(root_node_id=instance.id)))
instance.root_id = instance.id
- def append_result(self, mapper, selectcontext, row, instance, identitykey, result, isnew):
+ def append_result(self, mapper, selectcontext, row, instance, result, **flags):
"""runs as results from a SELECT statement are processed, and newly created or already-existing
instances that correspond to each row are appended to result lists. This method will only
append root nodes to the result list, and will attach child nodes to their appropriate parent
node as they arrive from the select results. This allows a SELECT statement which returns
both root and child nodes in one query to return a list of "roots"."""
+
+ isnew = flags.get('isnew', False)
+
if instance.parent_id is None:
result.append(instance)
else:
- if isnew or context.populate_existing:
+ if isnew or selectcontext.populate_existing:
parentnode = selectcontext.identity_map[mapper.identity_key(instance.parent_id)]
- parentnode.children.append_without_event(instance)
- # fire off lazy loader before the instance is part of the session
- instance.children
+ parentnode.children.append(instance)
return False
class TreeData(object):
@@ -127,12 +115,19 @@ mapper(TreeNode, trees, properties=dict(
name=trees.c.node_name,
parent_id=trees.c.parent_node_id,
root_id=trees.c.root_node_id,
- root=relation(TreeNode, primaryjoin=trees.c.root_node_id==trees.c.node_id, remote_side=trees.c.node_id, lazy=None, uselist=False),
- children=relation(TreeNode, primaryjoin=trees.c.parent_node_id==trees.c.node_id, lazy=None, uselist=True, cascade="delete,save-update", collection_class=NodeList),
- data=relation(mapper(TreeData, treedata, properties=dict(id=treedata.c.data_id)), cascade="delete,delete-orphan,save-update", lazy=False)
+ root=relation(TreeNode, primaryjoin=trees.c.root_node_id==trees.c.node_id, remote_side=trees.c.node_id, lazy=None),
+ children=relation(TreeNode,
+ primaryjoin=trees.c.parent_node_id==trees.c.node_id,
+ lazy=None,
+ cascade="all",
+ collection_class=attribute_mapped_collection('name'),
+ backref=backref('parent', primaryjoin=trees.c.parent_node_id==trees.c.node_id, remote_side=trees.c.node_id)
+ ),
+ data=relation(TreeData, cascade="all, delete-orphan", lazy=False)
), extension = TreeLoader())
+mapper(TreeData, treedata, properties={'id':treedata.c.data_id})
session = create_session()
diff --git a/examples/backref/backref_tree.py b/examples/backref/backref_tree.py
deleted file mode 100644
index 7386d034c..000000000
--- a/examples/backref/backref_tree.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from sqlalchemy import *
-
-metadata = MetaData('sqlite:///', echo=True)
-
-class Tree(object):
- def __init__(self, name='', father=None):
- self.name = name
- self.father = father
- def __str__(self):
- return '<TreeNode: %s>' % self.name
- def __repr__(self):
- return self.__str__()
-
-table = Table('tree', metadata,
- Column('id', Integer, primary_key=True),
- Column('name', String(64), nullable=False),
- Column('father_id', Integer, ForeignKey('tree.id'), nullable=True))
-table.create()
-
-mapper(Tree, table,
- properties={
- 'childs':relation(Tree, remote_side=table.c.father_id, primaryjoin=table.c.father_id==table.c.id, backref=backref('father', remote_side=table.c.id))},
- )
-
-root = Tree('root')
-child1 = Tree('child1', root)
-child2 = Tree('child2', root)
-child3 = Tree('child3', child1)
-
-child4 = Tree('child4')
-child1.childs.append(child4)
-
-session = create_session()
-session.save(root)
-session.flush()
-
-print root.childs
-print child1.childs
-print child2.childs
-print child2.father
-print child3.father
diff --git a/examples/collections/large_collection.py b/examples/collections/large_collection.py
index d592441ab..3c53db121 100644
--- a/examples/collections/large_collection.py
+++ b/examples/collections/large_collection.py
@@ -1,7 +1,8 @@
"""illlustrates techniques for dealing with very large collections"""
from sqlalchemy import *
-meta = MetaData('sqlite://', echo=True)
+meta = MetaData('sqlite://')
+meta.bind.echo = True
org_table = Table('organizations', meta,
Column('org_id', Integer, primary_key=True),
diff --git a/examples/elementtree/adjacency_list.py b/examples/elementtree/adjacency_list.py
new file mode 100644
index 000000000..204662f56
--- /dev/null
+++ b/examples/elementtree/adjacency_list.py
@@ -0,0 +1,215 @@
+"""illustrates an explicit way to persist an XML document expressed using ElementTree.
+
+This example explicitly marshals/unmarshals the ElementTree document into
+mapped entities which have their own tables. Compare to pickle.py which
+uses pickle to accomplish the same task. Note that the usage of both
+styles of persistence is identical, as is the structure of the main Document class.
+"""
+
+################################# PART I - Imports/Configuration ###########################################
+from sqlalchemy import *
+from sqlalchemy.orm import *
+
+import sys, os, StringIO, re
+
+import logging
+logging.basicConfig()
+
+# uncomment to show SQL statements
+#logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
+
+# uncomment to show SQL statements and result sets
+#logging.getLogger('sqlalchemy.engine').setLevel(logging.DEBUG)
+
+
+from elementtree import ElementTree
+from elementtree.ElementTree import Element, SubElement
+
+meta = MetaData()
+meta.engine = 'sqlite://'
+
+################################# PART II - Table Metadata ###########################################
+
+# stores a top level record of an XML document.
+documents = Table('documents', meta,
+ Column('document_id', Integer, primary_key=True),
+ Column('filename', String(30), unique=True),
+ Column('element_id', Integer, ForeignKey('elements.element_id'))
+)
+
+# stores XML nodes in an adjacency list model. This corresponds to
+# Element and SubElement objects.
+elements = Table('elements', meta,
+ Column('element_id', Integer, primary_key=True),
+ Column('parent_id', Integer, ForeignKey('elements.element_id')),
+ Column('tag', Unicode(30), nullable=False),
+ Column('text', Unicode),
+ Column('tail', Unicode)
+ )
+
+# stores attributes. This corresponds to the dictionary of attributes
+# stored by an Element or SubElement.
+attributes = Table('attributes', meta,
+ Column('element_id', Integer, ForeignKey('elements.element_id'), primary_key=True),
+ Column('name', Unicode(100), nullable=False, primary_key=True),
+ Column('value', Unicode(255)))
+
+meta.create_all()
+
+#################################### PART III - Model #############################################
+
+# our document class. contains a string name,
+# and the ElementTree root element.
+class Document(object):
+ def __init__(self, name, element):
+ self.filename = name
+ self.element = element
+
+ def __str__(self):
+ buf = StringIO.StringIO()
+ self.element.write(buf)
+ return buf.getvalue()
+
+#################################### PART IV - Persistence Mapping ###################################
+
+# Node class. a non-public class which will represent
+# the DB-persisted Element/SubElement object. We cannot create mappers for
+# ElementTree elements directly because they are at the very least not new-style
+# classes, and also may be backed by native implementations.
+# so here we construct an adapter.
+class _Node(object):
+ pass
+
+# Attribute class. also internal, this will represent the key/value attributes stored for
+# a particular Node.
+class _Attribute(object):
+ def __init__(self, name, value):
+ self.name = name
+ self.value = value
+
+# setup mappers. Document will eagerly load a list of _Node objects.
+mapper(Document, documents, properties={
+ '_root':relation(_Node, lazy=False, cascade="all")
+})
+
+mapper(_Node, elements, properties={
+ 'children':relation(_Node, cascade="all"),
+ 'attributes':relation(_Attribute, lazy=False, cascade="all, delete-orphan"), # eagerly load attributes
+})
+
+mapper(_Attribute, attributes)
+
+# define marshalling functions that convert from _Node/_Attribute to/from ElementTree objects.
+# this will set the ElementTree element as "document._element", and append the root _Node
+# object to the "_root" mapped collection.
+class ElementTreeMarshal(object):
+ def __get__(self, document, owner):
+ if document is None:
+ return self
+
+ if hasattr(document, '_element'):
+ return document._element
+
+ def traverse(node, parent=None):
+ if parent is not None:
+ elem = ElementTree.SubElement(parent, node.tag)
+ else:
+ elem = ElementTree.Element(node.tag)
+ elem.text = node.text
+ elem.tail = node.tail
+ for attr in node.attributes:
+ elem.attrib[attr.name] = attr.value
+ for child in node.children:
+ traverse(child, parent=elem)
+ return elem
+
+ document._element = ElementTree.ElementTree(traverse(document._root))
+ return document._element
+
+ def __set__(self, document, element):
+ def traverse(node):
+ n = _Node()
+ n.tag = node.tag
+ n.text = node.text
+ n.tail = node.tail
+ n.children = [traverse(n2) for n2 in node]
+ n.attributes = [_Attribute(k, v) for k, v in node.attrib.iteritems()]
+ return n
+
+ document._root = traverse(element.getroot())
+ document._element = element
+
+ def __delete__(self, document):
+ del document._element
+ document._root = []
+
+# override Document's "element" attribute with the marshaller.
+Document.element = ElementTreeMarshal()
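(For readers unfamiliar with the descriptor protocol the marshal relies on: a
class attribute whose value defines __get__/__set__/__delete__ intercepts all
instance attribute access, which is how "element" above can lazily rebuild the
ElementTree. A standalone sketch, unrelated to the mapping in this example:

    class Upper(object):
        def __get__(self, obj, owner):
            return obj._v.upper()
        def __set__(self, obj, value):
            obj._v = value

    class Thing(object):
        v = Upper()

    t = Thing()
    t.v = 'hi'
    assert t.v == 'HI'
)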
+
+########################################### PART V - Basic Persistence Example ############################
+
+line = "\n--------------------------------------------------------"
+
+# save to DB
+session = create_session()
+
+# get ElementTree documents
+for file in ('test.xml', 'test2.xml', 'test3.xml'):
+ filename = os.path.join(os.path.dirname(sys.argv[0]), file)
+ doc = ElementTree.parse(filename)
+ session.save(Document(file, doc))
+
+print "\nSaving three documents...", line
+session.flush()
+print "Done."
+
+# clear session (to illustrate a full load), restore
+session.clear()
+
+print "\nFull text of document 'text.xml':", line
+document = session.query(Document).filter_by(filename="test.xml").first()
+
+print document
+
+############################################ PART VI - Searching for Paths #######################################
+
+# manually search for a document which contains "/somefile/header/field1:hi"
+print "\nManual search for /somefile/header/field1=='hi':", line
+n1 = elements.alias('n1')
+n2 = elements.alias('n2')
+n3 = elements.alias('n3')
+j = documents.join(n1).join(n2, n1.c.element_id==n2.c.parent_id).join(n3, n2.c.element_id==n3.c.parent_id)
+d = session.query(Document).select_from(j).filter(n1.c.tag=='somefile').filter(n2.c.tag=='header').filter(and_(n3.c.tag=='field1', n3.c.text=='hi')).one()
+print d
+
+# generalize the above approach into an extremely impoverished xpath function:
+def find_document(path, compareto):
+ j = documents
+ prev_elements = None
+ query = session.query(Document)
+ for i, match in enumerate(re.finditer(r'/([\w_]+)(?:\[@([\w_]+)(?:=(.*))?\])?', path)):
+ (token, attrname, attrvalue) = match.group(1, 2, 3)
+ a = elements.alias("n%d" % i)
+ query = query.filter(a.c.tag==token)
+ if attrname:
+ attr_alias = attributes.alias('a%d' % i)
+ if attrvalue:
+ query = query.filter(and_(a.c.element_id==attr_alias.c.element_id, attr_alias.c.name==attrname, attr_alias.c.value==attrvalue))
+ else:
+ query = query.filter(and_(a.c.element_id==attr_alias.c.element_id, attr_alias.c.name==attrname))
+ if prev_elements is not None:
+ j = j.join(a, prev_elements.c.element_id==a.c.parent_id)
+ else:
+ j = j.join(a)
+ prev_elements = a
+ return query.options(lazyload('_root')).select_from(j).filter(prev_elements.c.text==compareto).all()
+
+for path, compareto in (
+ ('/somefile/header/field1', 'hi'),
+ ('/somefile/field1', 'hi'),
+ ('/somefile/header/field2', 'there'),
+ ('/somefile/header/field2[@attr=foo]', 'there')
+ ):
+ print "\nDocuments containing '%s=%s':" % (path, compareto), line
+ print [d.filename for d in find_document(path, compareto)]
+
diff --git a/examples/elementtree/optimized_al.py b/examples/elementtree/optimized_al.py
new file mode 100644
index 000000000..17b6489de
--- /dev/null
+++ b/examples/elementtree/optimized_al.py
@@ -0,0 +1,224 @@
+"""This script duplicates adjacency_list.py, but optimizes the loading
+of XML nodes to be based on a "flattened" datamodel. Any number of XML documents,
+each of arbitrary complexity, can be loaded in their entirety via a single query
+which joins on only three tables.
+
+"""
+
+################################# PART I - Imports/Configuration ###########################################
+from sqlalchemy import *
+from sqlalchemy.orm import *
+
+import sys, os, StringIO, re
+
+import logging
+logging.basicConfig()
+
+# uncomment to show SQL statements
+#logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
+
+# uncomment to show SQL statements and result sets
+#logging.getLogger('sqlalchemy.engine').setLevel(logging.DEBUG)
+
+
+from elementtree import ElementTree
+from elementtree.ElementTree import Element, SubElement
+
+meta = MetaData()
+meta.engine = 'sqlite://'
+
+################################# PART II - Table Metadata ###########################################
+
+# stores a top level record of an XML document.
+documents = Table('documents', meta,
+ Column('document_id', Integer, primary_key=True),
+ Column('filename', String(30), unique=True),
+)
+
+# stores XML nodes in an adjacency list model. This corresponds to
+# Element and SubElement objects.
+elements = Table('elements', meta,
+ Column('element_id', Integer, primary_key=True),
+ Column('parent_id', Integer, ForeignKey('elements.element_id')),
+ Column('document_id', Integer, ForeignKey('documents.document_id')),
+ Column('tag', Unicode(30), nullable=False),
+ Column('text', Unicode),
+ Column('tail', Unicode)
+ )
+
+# stores attributes. This corresponds to the dictionary of attributes
+# stored by an Element or SubElement.
+attributes = Table('attributes', meta,
+ Column('element_id', Integer, ForeignKey('elements.element_id'), primary_key=True),
+ Column('name', Unicode(100), nullable=False, primary_key=True),
+ Column('value', Unicode(255)))
+
+meta.create_all()
+
+#################################### PART III - Model #############################################
+
+# our document class. contains a string name,
+# and the ElementTree root element.
+class Document(object):
+ def __init__(self, name, element):
+ self.filename = name
+ self.element = element
+
+ def __str__(self):
+ buf = StringIO.StringIO()
+ self.element.write(buf)
+ return buf.getvalue()
+
+#################################### PART IV - Persistence Mapping ###################################
+
+# Node class. a non-public class which will represent
+# the DB-persisted Element/SubElement object. We cannot create mappers for
+# ElementTree elements directly because they are at the very least not new-style
+# classes, and also may be backed by native implementations.
+# so here we construct an adapter.
+class _Node(object):
+ pass
+
+# Attribute class. also internal, this will represent the key/value attributes stored for
+# a particular Node.
+class _Attribute(object):
+ def __init__(self, name, value):
+ self.name = name
+ self.value = value
+
+# setup mappers. Document will eagerly load a list of _Node objects.
+# they will be ordered in primary key/insert order, so that we can reconstruct
+# an ElementTree structure from the list.
+mapper(Document, documents, properties={
+ '_nodes':relation(_Node, lazy=False, cascade="all, delete-orphan")
+})
+
+# the _Node objects change the way they load so that a list of _Nodes will organize
+# themselves hierarchically using the HierarchicalLoader. this depends on the ordering of
+# nodes being hierarchical as well; relation() always applies at least ROWID/primary key
+# ordering to rows, which will suffice.
+mapper(_Node, elements, properties={
+    'children':relation(_Node, lazy=None),  # doesn't load; used only for the save relationship
+ 'attributes':relation(_Attribute, lazy=False, cascade="all, delete-orphan"), # eagerly load attributes
+})
+
+mapper(_Attribute, attributes)
+
+# define marshalling functions that convert from _Node/_Attribute to/from ElementTree objects.
+# this will set the ElementTree element as "document._element", and append the root _Node
+# object to the "_nodes" mapped collection.
+class ElementTreeMarshal(object):
+ def __get__(self, document, owner):
+ if document is None:
+ return self
+
+ if hasattr(document, '_element'):
+ return document._element
+
+ nodes = {}
+ root = None
+ for node in document._nodes:
+ if node.parent_id is not None:
+ parent = nodes[node.parent_id]
+ elem = ElementTree.SubElement(parent, node.tag)
+ nodes[node.element_id] = elem
+ else:
+ parent = None
+ elem = root = ElementTree.Element(node.tag)
+ nodes[node.element_id] = root
+ for attr in node.attributes:
+ elem.attrib[attr.name] = attr.value
+ elem.text = node.text
+ elem.tail = node.tail
+
+ document._element = ElementTree.ElementTree(root)
+ return document._element
+
+ def __set__(self, document, element):
+ def traverse(node):
+ n = _Node()
+ n.tag = node.tag
+ n.text = node.text
+ n.tail = node.tail
+ document._nodes.append(n)
+ n.children = [traverse(n2) for n2 in node]
+ n.attributes = [_Attribute(k, v) for k, v in node.attrib.iteritems()]
+ return n
+
+ traverse(element.getroot())
+ document._element = element
+
+ def __delete__(self, document):
+ del document._element
+ document._nodes = []
+
+# override Document's "element" attribute with the marshaller.
+Document.element = ElementTreeMarshal()
+
+########################################### PART V - Basic Persistence Example ############################
+
+line = "\n--------------------------------------------------------"
+
+# save to DB
+session = create_session()
+
+# get ElementTree documents
+for file in ('test.xml', 'test2.xml', 'test3.xml'):
+ filename = os.path.join(os.path.dirname(sys.argv[0]), file)
+ doc = ElementTree.parse(filename)
+ session.save(Document(file, doc))
+
+print "\nSaving three documents...", line
+session.flush()
+print "Done."
+
+# clear session (to illustrate a full load), restore
+session.clear()
+
+print "\nFull text of document 'text.xml':", line
+document = session.query(Document).filter_by(filename="test.xml").first()
+
+print document
+
+############################################ PART VI - Searching for Paths #######################################
+
+# manually search for a document which contains "/somefile/header/field1:hi"
+print "\nManual search for /somefile/header/field1=='hi':", line
+n1 = elements.alias('n1')
+n2 = elements.alias('n2')
+n3 = elements.alias('n3')
+j = documents.join(n1).join(n2, n1.c.element_id==n2.c.parent_id).join(n3, n2.c.element_id==n3.c.parent_id)
+d = session.query(Document).select_from(j).filter(n1.c.tag=='somefile').filter(n2.c.tag=='header').filter(and_(n3.c.tag=='field1', n3.c.text=='hi')).one()
+print d
+
+# generalize the above approach into an extremely impoverished xpath function:
+def find_document(path, compareto):
+ j = documents
+ prev_elements = None
+ query = session.query(Document)
+ for i, match in enumerate(re.finditer(r'/([\w_]+)(?:\[@([\w_]+)(?:=(.*))?\])?', path)):
+ (token, attrname, attrvalue) = match.group(1, 2, 3)
+ a = elements.alias("n%d" % i)
+ query = query.filter(a.c.tag==token)
+ if attrname:
+ attr_alias = attributes.alias('a%d' % i)
+ if attrvalue:
+ query = query.filter(and_(a.c.element_id==attr_alias.c.element_id, attr_alias.c.name==attrname, attr_alias.c.value==attrvalue))
+ else:
+ query = query.filter(and_(a.c.element_id==attr_alias.c.element_id, attr_alias.c.name==attrname))
+ if prev_elements is not None:
+ j = j.join(a, prev_elements.c.element_id==a.c.parent_id)
+ else:
+ j = j.join(a)
+ prev_elements = a
+ return query.options(lazyload('_nodes')).select_from(j).filter(prev_elements.c.text==compareto).all()
+
+for path, compareto in (
+ ('/somefile/header/field1', 'hi'),
+ ('/somefile/field1', 'hi'),
+ ('/somefile/header/field2', 'there'),
+ ('/somefile/header/field2[@attr=foo]', 'there')
+ ):
+ print "\nDocuments containing '%s=%s':" % (path, compareto), line
+ print [d.filename for d in find_document(path, compareto)]
+
diff --git a/examples/elementtree/pickle.py b/examples/elementtree/pickle.py
new file mode 100644
index 000000000..443ca85c3
--- /dev/null
+++ b/examples/elementtree/pickle.py
@@ -0,0 +1,65 @@
+"""illustrates a quick and dirty way to persist an XML document expressed using ElementTree and pickle.
+
+This is a trivial example using PickleType to marshal/unmarshal the ElementTree
+document into a binary column. Compare to adjacency_list.py, which stores the individual components of the ElementTree
+structure in distinct rows using two additional mapped entities. Note that the usage of both
+styles of persistence is identical, as is the structure of the main Document class.
+"""
+
+from sqlalchemy import *
+from sqlalchemy.orm import *
+
+import sys, os
+
+import logging
+logging.basicConfig()
+
+# uncomment to show SQL statements
+#logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
+
+# uncomment to show SQL statements and result sets
+#logging.getLogger('sqlalchemy.engine').setLevel(logging.DEBUG)
+
+from elementtree import ElementTree
+
+meta = MetaData()
+meta.engine = 'sqlite://'
+
+# stores a top level record of an XML document.
+# the "element" column will store the ElementTree document as a BLOB.
+documents = Table('documents', meta,
+ Column('document_id', Integer, primary_key=True),
+ Column('filename', String(30), unique=True),
+ Column('element', PickleType)
+)
+
+meta.create_all()
+
+# our document class. contains a string name,
+# and the ElementTree root element.
+class Document(object):
+ def __init__(self, name, element):
+ self.filename = name
+ self.element = element
+
+# setup mapper.
+mapper(Document, documents)
+
+###### time to test ! #########
+
+# get ElementTree document
+filename = os.path.join(os.path.dirname(sys.argv[0]), "test.xml")
+doc = ElementTree.parse(filename)
+
+# save to DB
+session = create_session()
+session.save(Document("test.xml", doc))
+session.flush()
+
+# clear session (to illustrate a full load), restore
+session.clear()
+document = session.query(Document).filter_by(filename="test.xml").first()
+
+# print
+document.element.write(sys.stdout)
+
diff --git a/examples/elementtree/test.xml b/examples/elementtree/test.xml
new file mode 100644
index 000000000..edb44ccc2
--- /dev/null
+++ b/examples/elementtree/test.xml
@@ -0,0 +1,9 @@
+<somefile>
+ This is somefile.
+ <header name="foo" value="bar" hoho="lala">
+ <field1>hi</field1>
+ <field2>there</field2>
+ Some additional text within the header.
+ </header>
+ Some more text within somefile.
+</somefile> \ No newline at end of file
diff --git a/examples/elementtree/test2.xml b/examples/elementtree/test2.xml
new file mode 100644
index 000000000..69d3167a8
--- /dev/null
+++ b/examples/elementtree/test2.xml
@@ -0,0 +1,4 @@
+<somefile>
+ <field1>hi</field1>
+ <field2>there</field2>
+</somefile> \ No newline at end of file
diff --git a/examples/elementtree/test3.xml b/examples/elementtree/test3.xml
new file mode 100644
index 000000000..6a7a2343e
--- /dev/null
+++ b/examples/elementtree/test3.xml
@@ -0,0 +1,7 @@
+<somefile>
+ test3
+ <header name="aheader" value="bar" hoho="lala">
+ <field1>one</field1>
+ <field2 attr='foo'>there</field2>
+ </header>
+</somefile> \ No newline at end of file
diff --git a/examples/pickle/custom_pickler.py b/examples/pickle/custom_pickler.py
index 4b259c1f8..b45e16e7c 100644
--- a/examples/pickle/custom_pickler.py
+++ b/examples/pickle/custom_pickler.py
@@ -6,7 +6,8 @@ from cStringIO import StringIO
from pickle import Pickler, Unpickler
import threading
-meta = MetaData('sqlite://', echo=True)
+meta = MetaData('sqlite://')
+meta.bind.echo = True
class MyExt(MapperExtension):
def populate_instance(self, mapper, selectcontext, row, instance, identitykey, isnew):
@@ -25,7 +26,7 @@ class MyPickler(object):
def persistent_id(self, obj):
if getattr(obj, "id", None) is None:
sess = MyPickler.sessions.current
- newsess = create_session(bind_to=sess.connection(class_mapper(Bar)))
+ newsess = create_session(bind=sess.connection(class_mapper(Bar)))
newsess.save(obj)
newsess.flush()
key = "%s:%s" % (type(obj).__name__, obj.id)
diff --git a/examples/poly_assoc/poly_assoc.py b/examples/poly_assoc/poly_assoc.py
index a2ac6140f..c13ffbfa1 100644
--- a/examples/poly_assoc/poly_assoc.py
+++ b/examples/poly_assoc/poly_assoc.py
@@ -21,8 +21,9 @@ the associated target object from those which associate with it.
"""
from sqlalchemy import *
+from sqlalchemy.orm import *
-metadata = MetaData('sqlite://', echo=False)
+metadata = MetaData('sqlite://')
#######
# addresses table, class, 'addressable interface'.
diff --git a/examples/poly_assoc/poly_assoc_fk.py b/examples/poly_assoc/poly_assoc_fk.py
index f3cedac72..22ee50009 100644
--- a/examples/poly_assoc/poly_assoc_fk.py
+++ b/examples/poly_assoc/poly_assoc_fk.py
@@ -20,8 +20,9 @@ poly_assoc_generic.py.
"""
from sqlalchemy import *
+from sqlalchemy.orm import *
-metadata = MetaData('sqlite://', echo=False)
+metadata = MetaData('sqlite://')
#######
# addresses table, class, 'addressable interface'.
diff --git a/examples/poly_assoc/poly_assoc_generic.py b/examples/poly_assoc/poly_assoc_generic.py
index 9cc7321db..4fca31019 100644
--- a/examples/poly_assoc/poly_assoc_generic.py
+++ b/examples/poly_assoc/poly_assoc_generic.py
@@ -7,8 +7,9 @@ function "association" which creates a new polymorphic association
"""
from sqlalchemy import *
+from sqlalchemy.orm import *
-metadata = MetaData('sqlite://', echo=False)
+metadata = MetaData('sqlite://')
def association(cls, table):
"""create an association 'interface'."""
diff --git a/examples/polymorph/concrete.py b/examples/polymorph/concrete.py
index 593d3f480..5f12e9a3d 100644
--- a/examples/polymorph/concrete.py
+++ b/examples/polymorph/concrete.py
@@ -1,4 +1,5 @@
from sqlalchemy import *
+from sqlalchemy.orm import *
metadata = MetaData()
@@ -49,7 +50,7 @@ manager_mapper = mapper(Manager, managers_table, inherits=employee_mapper, concr
engineer_mapper = mapper(Engineer, engineers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='engineer')
-session = create_session(bind_to=engine)
+session = create_session(bind=engine)
m1 = Manager("pointy haired boss", "manager1")
e1 = Engineer("wally", "engineer1")
diff --git a/examples/polymorph/polymorph.py b/examples/polymorph/polymorph.py
index 231a9d8e4..4f3aeb7d2 100644
--- a/examples/polymorph/polymorph.py
+++ b/examples/polymorph/polymorph.py
@@ -1,10 +1,11 @@
from sqlalchemy import *
+from sqlalchemy.orm import *
import sets
-# this example illustrates a polymorphic load of two classes, where each class has a very
-# different set of properties
+# this example illustrates a polymorphic load of two classes
-metadata = MetaData('sqlite://', echo=True)
+metadata = MetaData('sqlite://')
+metadata.bind.echo = True
# a table to store companies
companies = Table('companies', metadata,
diff --git a/examples/polymorph/single.py b/examples/polymorph/single.py
index dcdb3c890..dfc426416 100644
--- a/examples/polymorph/single.py
+++ b/examples/polymorph/single.py
@@ -1,6 +1,8 @@
from sqlalchemy import *
+from sqlalchemy.orm import *
-metadata = MetaData('sqlite://', echo='debug')
+metadata = MetaData('sqlite://')
+metadata.bind.echo = 'debug'
# a table to store companies
companies = Table('companies', metadata,
diff --git a/examples/sharding/attribute_shard.py b/examples/sharding/attribute_shard.py
new file mode 100644
index 000000000..e95b978ae
--- /dev/null
+++ b/examples/sharding/attribute_shard.py
@@ -0,0 +1,194 @@
+"""a basic example of using the SQLAlchemy Sharding API.
+Sharding refers to horizontally scaling data across multiple
+databases.
+
+In this example, four sqlite databases will store information about
+weather data on a database-per-continent basis.
+
+To set up a sharding system, you need:
+  1. multiple databases, each assigned a 'shard id'
+ 2. a function which can return a single shard id, given an instance
+ to be saved; this is called "shard_chooser"
+ 3. a function which can return a list of shard ids which apply to a particular
+ instance identifier; this is called "id_chooser". If it returns all shard ids,
+ all shards will be searched.
+ 4. a function which can return a list of shard ids to try, given a particular
+ Query ("query_chooser"). If it returns all shard ids, all shards will be
+ queried and the results joined together.
+"""
+
+# step 1. imports
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from sqlalchemy.orm.shard import ShardedSession
+from sqlalchemy.sql import ColumnOperators
+import datetime, operator
+
+# step 2. databases
+echo = True
+db1 = create_engine('sqlite:///shard1.db', echo=echo)
+db2 = create_engine('sqlite:///shard2.db', echo=echo)
+db3 = create_engine('sqlite:///shard3.db', echo=echo)
+db4 = create_engine('sqlite:///shard4.db', echo=echo)
+
+
+# step 3. create session function. this binds the shard ids
+# to databases within a ShardedSession and returns it.
+def create_session():
+ s = ShardedSession(shard_chooser, id_chooser, query_chooser)
+ s.bind_shard('north_america', db1)
+ s.bind_shard('asia', db2)
+ s.bind_shard('europe', db3)
+ s.bind_shard('south_america', db4)
+ return s
+
+# step 4. table setup.
+meta = MetaData()
+
+# we need a way to create identifiers which are unique across all
+# databases. one easy way would be to just use a composite primary key, where one
+# value is the shard id. but here, we'll show something more "generic", an
+# id generation function. we'll use a simplistic "id table" stored in database
+# #1. Any other method will do just as well; UUID, hilo, application-specific, etc.
+
+ids = Table('ids', meta,
+ Column('nextid', Integer, nullable=False))
+
+def id_generator(ctx):
+ # in reality, might want to use a separate transaction for this.
+ c = db1.connect()
+ nextid = c.execute(ids.select(for_update=True)).scalar()
+ c.execute(ids.update(values={ids.c.nextid : ids.c.nextid + 1}))
+ return nextid
+
+# table setup. we'll store a lead table of continents/cities,
+# and a secondary table storing locations.
+# a particular row will be placed in the database whose shard id corresponds to the
+# 'continent'. in this setup, secondary rows in 'weather_reports' will
+# be placed in the same DB as that of the parent, but this can be changed
+# if you're willing to write more complex sharding functions.
+
+weather_locations = Table("weather_locations", meta,
+ Column('id', Integer, primary_key=True, default=id_generator),
+ Column('continent', String(30), nullable=False),
+ Column('city', String(50), nullable=False)
+ )
+
+weather_reports = Table("weather_reports", meta,
+ Column('id', Integer, primary_key=True),
+ Column('location_id', Integer, ForeignKey('weather_locations.id')),
+ Column('temperature', Float),
+ Column('report_time', DateTime, default=datetime.datetime.now),
+)
+
+# create tables
+for db in (db1, db2, db3, db4):
+ meta.drop_all(db)
+ meta.create_all(db)
+
+# establish initial "id" in db1
+db1.execute(ids.insert(), nextid=1)
+
+
+# step 5. define sharding functions.
+
+# we'll use a straight mapping of a particular set of "continent"
+# attributes to shard id.
+shard_lookup = {
+ 'North America':'north_america',
+ 'Asia':'asia',
+ 'Europe':'europe',
+ 'South America':'south_america'
+}
+
+# shard_chooser - looks at the given instance and returns a shard id
+# note that we need to define conditions for
+# the WeatherLocation class, as well as our secondary Report class which will
+# point back to its WeatherLocation via its 'location' attribute.
+def shard_chooser(mapper, instance):
+ if isinstance(instance, WeatherLocation):
+ return shard_lookup[instance.continent]
+ else:
+ return shard_chooser(mapper, instance.location)
+
+# id_chooser. given a primary key, returns a list of shards
+# to search. here, we don't have any particular information from a
+# pk so we just return all shard ids. often, you'd want to do some
+# kind of round-robin strategy here so that requests are evenly
+# distributed among DBs
+def id_chooser(ident):
+ return ['north_america', 'asia', 'europe', 'south_america']
+
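(A sketch of the round-robin strategy the comment above alludes to, so repeated
lookups start their search at a different shard each time; itertools is standard
library and the helper name is hypothetical:

    import itertools
    SHARDS = ['north_america', 'asia', 'europe', 'south_america']
    _ring = itertools.cycle(range(len(SHARDS)))

    def id_chooser_round_robin(ident):
        i = _ring.next()                    # rotate the starting shard
        return SHARDS[i:] + SHARDS[:i]
)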
+# query_chooser. this also returns a list of shard ids, which can
+# just be all of them. but here we'll search into the Query in order
+# to try to narrow down the list of shards to query.
+def query_chooser(query):
+ ids = []
+
+ # here we will traverse through the query's criterion, searching
+ # for SQL constructs. we'll grab continent names as we find them
+ # and convert to shard ids
+ class FindContinent(sql.ClauseVisitor):
+ def visit_binary(self, binary):
+ if binary.left is weather_locations.c.continent:
+ if binary.operator == operator.eq:
+ ids.append(shard_lookup[binary.right.value])
+ elif binary.operator == ColumnOperators.in_op:
+ for bind in binary.right.clauses:
+ ids.append(shard_lookup[bind.value])
+
+ FindContinent().traverse(query._criterion)
+ if len(ids) == 0:
+ return ['north_america', 'asia', 'europe', 'south_america']
+ else:
+ return ids
+
+# step 6. mapped classes.
+class WeatherLocation(object):
+ def __init__(self, continent, city):
+ self.continent = continent
+ self.city = city
+
+class Report(object):
+ def __init__(self, temperature):
+ self.temperature = temperature
+
+# step 7. mappers
+mapper(WeatherLocation, weather_locations, properties={
+ 'reports':relation(Report, backref='location')
+})
+
+mapper(Report, weather_reports)
+
+
+# save and load objects!
+
+tokyo = WeatherLocation('Asia', 'Tokyo')
+newyork = WeatherLocation('North America', 'New York')
+toronto = WeatherLocation('North America', 'Toronto')
+london = WeatherLocation('Europe', 'London')
+dublin = WeatherLocation('Europe', 'Dublin')
+brasilia = WeatherLocation('South America', 'Brasilia')
+quito = WeatherLocation('South America', 'Quito')
+
+tokyo.reports.append(Report(80.0))
+newyork.reports.append(Report(75))
+quito.reports.append(Report(85))
+
+sess = create_session()
+for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:
+ sess.save(c)
+sess.flush()
+
+sess.clear()
+
+t = sess.query(WeatherLocation).get(tokyo.id)
+assert t.city == tokyo.city
+assert t.reports[0].temperature == 80.0
+
+north_american_cities = sess.query(WeatherLocation).filter(WeatherLocation.continent == 'North America')
+assert [c.city for c in north_american_cities] == ['New York', 'Toronto']
+
+asia_and_europe = sess.query(WeatherLocation).filter(WeatherLocation.continent.in_('Europe', 'Asia'))
+assert set([c.city for c in asia_and_europe]) == set(['Tokyo', 'London', 'Dublin'])
+
diff --git a/examples/vertical/vertical.py b/examples/vertical/vertical.py
index e9fff9163..e3b48c336 100644
--- a/examples/vertical/vertical.py
+++ b/examples/vertical/vertical.py
@@ -3,9 +3,12 @@ represented in distinct database rows. This allows objects to be created with d
fields that are all persisted in a normalized fashion."""
from sqlalchemy import *
+from sqlalchemy.orm import *
+from sqlalchemy.orm.collections import mapped_collection
import datetime
-e = MetaData('sqlite://', echo=True)
+e = MetaData('sqlite://')
+e.bind.echo = True
# this table represents Entity objects. each Entity gets a row in this table,
# with a primary key and a title.
@@ -37,14 +40,6 @@ entity_values = Table('entity_values', e,
e.create_all()
-class EntityDict(dict):
- """this is a dictionary that implements an append() and an __iter__ method.
- such a dictionary can be used with SQLAlchemy list-based attributes."""
- def append(self, entityvalue):
- self[entityvalue.field.name] = entityvalue
- def __iter__(self):
- return iter(self.values())
-
class Entity(object):
"""represents an Entity. The __getattr__ method is overridden to search the
object's _entities dictionary for the appropriate value, and the __setattr__
@@ -123,7 +118,7 @@ mapper(
)
mapper(Entity, entities, properties = {
- '_entities' : relation(EntityValue, lazy=False, cascade='all', collection_class=EntityDict)
+ '_entities' : relation(EntityValue, lazy=False, cascade='all', collection_class=mapped_collection(lambda entityvalue: entityvalue.field.name))
})
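(mapped_collection() is the general form of the dict-based collections seen in the
other examples in this commit: it accepts any callable that derives a key from a
child object. Roughly speaking,

    from sqlalchemy.orm.collections import mapped_collection, attribute_mapped_collection

    by_name = attribute_mapped_collection('name')              # fixed attribute key
    by_name_too = mapped_collection(lambda child: child.name)  # equivalent, spelled out

The lambda form is needed here because the key is computed through a relation,
each EntityValue's field.name, rather than a plain attribute.)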
# create two entities. the objects can be used about as regularly as
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py
index ad2615131..6e95fd7e1 100644
--- a/lib/sqlalchemy/__init__.py
+++ b/lib/sqlalchemy/__init__.py
@@ -7,10 +7,8 @@
from sqlalchemy.types import *
from sqlalchemy.sql import *
from sqlalchemy.schema import *
-from sqlalchemy.orm import *
from sqlalchemy.engine import create_engine
-from sqlalchemy.schema import default_metadata
def __figure_version():
try:
@@ -25,8 +23,6 @@ def __figure_version():
return '(not installed)'
except:
return '(not installed)'
-
+
__version__ = __figure_version()
-
-def global_connect(*args, **kwargs):
- default_metadata.connect(*args, **kwargs)
+
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index 9994d5288..22227d56a 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -10,9 +10,10 @@ Contains default implementations for the abstract objects in the sql
module.
"""
-from sqlalchemy import schema, sql, engine, util, sql_util, exceptions
+import string, re, sets, operator
+
+from sqlalchemy import schema, sql, engine, util, exceptions
from sqlalchemy.engine import default
-import string, re, sets, weakref, random
ANSI_FUNCS = sets.ImmutableSet(['CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP',
'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP',
@@ -40,6 +41,41 @@ RESERVED_WORDS = util.Set(['all', 'analyse', 'analyze', 'and', 'any', 'array',
LEGAL_CHARACTERS = util.Set(string.ascii_lowercase + string.ascii_uppercase + string.digits + '_$')
ILLEGAL_INITIAL_CHARACTERS = util.Set(string.digits + '$')
+BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
+BIND_PARAMS_ESC = re.compile(r'\x5c(:\w+)(?!:)', re.UNICODE)
+
+OPERATORS = {
+ operator.and_ : 'AND',
+ operator.or_ : 'OR',
+ operator.inv : 'NOT',
+ operator.add : '+',
+ operator.mul : '*',
+ operator.sub : '-',
+ operator.div : '/',
+ operator.mod : '%',
+ operator.truediv : '/',
+ operator.lt : '<',
+ operator.le : '<=',
+ operator.ne : '!=',
+ operator.gt : '>',
+ operator.ge : '>=',
+ operator.eq : '=',
+ sql.ColumnOperators.concat_op : '||',
+ sql.ColumnOperators.like_op : 'LIKE',
+ sql.ColumnOperators.notlike_op : 'NOT LIKE',
+ sql.ColumnOperators.ilike_op : 'ILIKE',
+ sql.ColumnOperators.notilike_op : 'NOT ILIKE',
+ sql.ColumnOperators.between_op : 'BETWEEN',
+ sql.ColumnOperators.in_op : 'IN',
+ sql.ColumnOperators.notin_op : 'NOT IN',
+ sql.ColumnOperators.comma_op : ', ',
+ sql.Operators.from_ : 'FROM',
+ sql.Operators.as_ : 'AS',
+ sql.Operators.exists : 'EXISTS',
+ sql.Operators.is_ : 'IS',
+ sql.Operators.isnot : 'IS NOT'
+}
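(A quick illustration of what the two regexes above match and skip: ':name' bind
parameters are captured, while Postgres-style '::text' casts and
backslash-escaped colons are left alone. The statement text is made up:

    text = "SELECT a::text FROM t WHERE b = :b AND c = '\\:literal'"
    text = BIND_PARAMS.sub('?', text)                        # qmark conversion
    text = BIND_PARAMS_ESC.sub(lambda m: m.group(1), text)   # un-escape \:params
    assert text == "SELECT a::text FROM t WHERE b = ? AND c = ':literal'"
)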
+
class ANSIDialect(default.DefaultDialect):
def __init__(self, cache_identifiers=True, **kwargs):
super(ANSIDialect,self).__init__(**kwargs)
@@ -66,14 +102,16 @@ class ANSIDialect(default.DefaultDialect):
"""
return ANSIIdentifierPreparer(self)
-class ANSICompiler(sql.Compiled):
+class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
"""Default implementation of Compiled.
Compiles ClauseElements into ANSI-compliant SQL strings.
"""
- __traverse_options__ = {'column_collections':False}
+ __traverse_options__ = {'column_collections':False, 'entry':True}
+ operators = OPERATORS
+
def __init__(self, dialect, statement, parameters=None, **kwargs):
"""Construct a new ``ANSICompiler`` object.
@@ -92,7 +130,7 @@ class ANSICompiler(sql.Compiled):
correspond to the keys present in the parameters.
"""
- sql.Compiled.__init__(self, dialect, statement, parameters, **kwargs)
+ super(ANSICompiler, self).__init__(dialect, statement, parameters, **kwargs)
# if we are insert/update. set to true when we visit an INSERT or UPDATE
self.isinsert = self.isupdate = False
@@ -104,21 +142,6 @@ class ANSICompiler(sql.Compiled):
# actually present in the generated SQL
self.bind_names = {}
- # a dictionary which stores the string representation for every ClauseElement
- # processed by this compiler.
- self.strings = {}
-
- # a dictionary which stores the string representation for ClauseElements
- # processed by this compiler, which are to be used in the FROM clause
- # of a select. items are often placed in "froms" as well as "strings"
- # and sometimes with different representations.
- self.froms = {}
-
- # slightly hacky. maps FROM clauses to WHERE clauses, and used in select
- # generation to modify the WHERE clause of the select. currently a hack
- # used by the oracle module.
- self.wheres = {}
-
# when the compiler visits a SELECT statement, the clause object is appended
# to this stack. various visit operations will check this stack to determine
# additional choices (TODO: it seems to be all typemap stuff. shouldn't this only
@@ -137,12 +160,6 @@ class ANSICompiler(sql.Compiled):
# for aliases
self.generated_ids = {}
- # True if this compiled represents an INSERT
- self.isinsert = False
-
- # True if this compiled represents an UPDATE
- self.isupdate = False
-
# default formatting style for bind parameters
self.bindtemplate = ":%s"
@@ -158,64 +175,76 @@ class ANSICompiler(sql.Compiled):
# an ANSIIdentifierPreparer that formats the quoting of identifiers
self.preparer = dialect.identifier_preparer
-
+
+ # a dictionary containing attributes about all select()
+ # elements located within the clause, regarding which are subqueries, which are
+ # selected from, and which elements should be correlated to an enclosing select.
+ # used mostly to determine the list of FROM elements for each select statement, as well
+ # as some dialect-specific rules regarding subqueries.
+ self.correlate_state = {}
+
# for UPDATE and INSERT statements, a set of columns whose values are being set
# from a SQL expression (i.e., not one of the bind parameter values). if present,
# default-value logic in the Dialect knows not to fire off column defaults
# and also knows postfetching will be needed to get the values represented by these
# parameters.
self.inline_params = None
-
+
def after_compile(self):
# this re will search for params like :param
# it has a negative lookbehind for an extra ':' so that it doesn't match
# postgres '::text' tokens
- match = re.compile(r'(?<!:):([\w_]+)', re.UNICODE)
+ text = self.string
+ if ':' not in text:
+ return
+
if self.paramstyle=='pyformat':
- self.strings[self.statement] = match.sub(lambda m:'%(' + m.group(1) +')s', self.strings[self.statement])
+ text = BIND_PARAMS.sub(lambda m:'%(' + m.group(1) +')s', text)
elif self.positional:
- params = match.finditer(self.strings[self.statement])
+ params = BIND_PARAMS.finditer(text)
for p in params:
self.positiontup.append(p.group(1))
if self.paramstyle=='qmark':
- self.strings[self.statement] = match.sub('?', self.strings[self.statement])
+ text = BIND_PARAMS.sub('?', text)
elif self.paramstyle=='format':
- self.strings[self.statement] = match.sub('%s', self.strings[self.statement])
+ text = BIND_PARAMS.sub('%s', text)
elif self.paramstyle=='numeric':
i = [0]
def getnum(x):
i[0] += 1
return str(i[0])
- self.strings[self.statement] = match.sub(getnum, self.strings[self.statement])
-
- def get_from_text(self, obj):
- return self.froms.get(obj, None)
-
- def get_str(self, obj):
- return self.strings[obj]
-
+ text = BIND_PARAMS.sub(getnum, text)
+ # un-escape any \:params
+ text = BIND_PARAMS_ESC.sub(lambda m: m.group(1), text)
+ self.string = text
+
+ def compile(self):
+ self.string = self.process(self.statement)
+ self.after_compile()
+
+ def process(self, obj, **kwargs):
+ return self.traverse_single(obj, **kwargs)
+
+ def is_subquery(self, select):
+ return self.correlate_state[select].get('is_subquery', False)
+
def get_whereclause(self, obj):
- return self.wheres.get(obj, None)
+ """given a FROM clause, return an additional WHERE condition that should be
+ applied to a SELECT.
+
+ Currently used by Oracle to provide WHERE criterion for JOIN and OUTER JOIN
+ constructs in non-ansi mode.
+ """
+
+ return None
def construct_params(self, params):
- """Return a structure of bind parameters for this compiled object.
-
- This includes bind parameters that might be compiled in via
- the `values` argument of an ``Insert`` or ``Update`` statement
- object, and also the given `**params`. The keys inside of
- `**params` can be any key that matches the
- ``BindParameterClause`` objects compiled within this object.
-
- The output is dependent on the paramstyle of the DBAPI being
- used; if a named style, the return result will be a dictionary
- with keynames matching the compiled statement. If a
- positional style, the output will be a list, with an iterator
- that will return parameter values in an order corresponding to
- the bind positions in the compiled statement.
-
- For an executemany style of call, this method should be called
- for each element in the list of parameter groups that will
- ultimately be executed.
+ """Return a sql.ClauseParameters object.
+
+ Combines the given bind parameter dictionary (string keys to object values)
+ with the _BindParamClause objects stored within this Compiled object
+ to produce a ClauseParameters structure, representing the bind arguments
+ for a single statement execution, or one element of an executemany execution.
"""
if self.parameters is not None:
@@ -225,7 +254,7 @@ class ANSICompiler(sql.Compiled):
bindparams.update(params)
d = sql.ClauseParameters(self.dialect, self.positiontup)
for b in self.binds.values():
- name = self.bind_names.get(b, b.key)
+ name = self.bind_names[b]
d.set_parameter(b, b.value, name)
for key, value in bindparams.iteritems():
@@ -233,7 +262,7 @@ class ANSICompiler(sql.Compiled):
b = self.binds[key]
except KeyError:
continue
- name = self.bind_names.get(b, b.key)
+ name = self.bind_names[b]
d.set_parameter(b, value, name)
return d
@@ -246,8 +275,8 @@ class ANSICompiler(sql.Compiled):
return ""
- def visit_grouping(self, grouping):
- self.strings[grouping] = "(" + self.strings[grouping.elem] + ")"
+ def visit_grouping(self, grouping, **kwargs):
+ return "(" + self.process(grouping.elem) + ")"
def visit_label(self, label):
labelname = self._truncated_identifier("colident", label.name)
@@ -256,9 +285,10 @@ class ANSICompiler(sql.Compiled):
self.typemap.setdefault(labelname.lower(), label.obj.type)
if isinstance(label.obj, sql._ColumnClause):
self.column_labels[label.obj._label] = labelname
- self.strings[label] = self.strings[label.obj] + " AS " + self.preparer.format_label(label, labelname)
+ self.column_labels[label.name] = labelname
+ return " ".join([self.process(label.obj), self.operator_string(sql.ColumnOperators.as_), self.preparer.format_label(label, labelname)])
- def visit_column(self, column):
+ def visit_column(self, column, **kwargs):
# there is actually somewhat of a ruleset when you would *not* necessarily
# want to truncate a column identifier, if its mapped to the name of a
# physical column. but thats very hard to identify at this point, and
@@ -269,107 +299,110 @@ class ANSICompiler(sql.Compiled):
else:
name = column.name
+ if len(self.select_stack):
+ # if we are within a visit to a Select, set up the "typemap"
+ # for this column which is used to translate result set values
+ self.typemap.setdefault(name.lower(), column.type)
+ self.column_labels.setdefault(column._label, name.lower())
+
if column.table is None or not column.table.named_with_column():
- self.strings[column] = self.preparer.format_column(column, name=name)
+ return self.preparer.format_column(column, name=name)
else:
if column.table.oid_column is column:
n = self.dialect.oid_column_name(column)
if n is not None:
- self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False), n)
+ return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=self._anonymize(column.table.name)), n)
elif len(column.table.primary_key) != 0:
pk = list(column.table.primary_key)[0]
pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name))
- self.strings[column] = self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname)
+ return self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=self._anonymize(column.table.name))
else:
- self.strings[column] = None
+ return None
else:
- self.strings[column] = self.preparer.format_column_with_table(column, column_name=name)
+ return self.preparer.format_column_with_table(column, column_name=name, table_name=self._anonymize(column.table.name))
- if len(self.select_stack):
- # if we are within a visit to a Select, set up the "typemap"
- # for this column which is used to translate result set values
- self.typemap.setdefault(name.lower(), column.type)
- self.column_labels.setdefault(column._label, name.lower())
- def visit_fromclause(self, fromclause):
- self.froms[fromclause] = fromclause.name
+ def visit_fromclause(self, fromclause, **kwargs):
+ return fromclause.name
- def visit_index(self, index):
- self.strings[index] = index.name
+ def visit_index(self, index, **kwargs):
+ return index.name
- def visit_typeclause(self, typeclause):
- self.strings[typeclause] = typeclause.type.dialect_impl(self.dialect).get_col_spec()
+ def visit_typeclause(self, typeclause, **kwargs):
+ return typeclause.type.dialect_impl(self.dialect).get_col_spec()
- def visit_textclause(self, textclause):
- self.strings[textclause] = textclause.text
- self.froms[textclause] = textclause.text
+ def visit_textclause(self, textclause, **kwargs):
+ for bind in textclause.bindparams.values():
+ self.process(bind)
if textclause.typemap is not None:
self.typemap.update(textclause.typemap)
+ return textclause.text
- def visit_null(self, null):
- self.strings[null] = 'NULL'
+ def visit_null(self, null, **kwargs):
+ return 'NULL'
- def visit_clauselist(self, list):
- sep = list.operator
- if sep == ',':
- sep = ', '
- elif sep is None or sep == " ":
+ def visit_clauselist(self, clauselist, **kwargs):
+ sep = clauselist.operator
+ if sep is None:
sep = " "
+ elif sep == sql.ColumnOperators.comma_op:
+ sep = ', '
else:
- sep = " " + sep + " "
- self.strings[list] = string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], sep)
+ sep = " " + self.operator_string(clauselist.operator) + " "
+ return string.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None], sep)
def apply_function_parens(self, func):
return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
- def visit_calculatedclause(self, clause):
- self.strings[clause] = self.get_str(clause.clause_expr)
+ def visit_calculatedclause(self, clause, **kwargs):
+ return self.process(clause.clause_expr)
- def visit_cast(self, cast):
+ def visit_cast(self, cast, **kwargs):
if len(self.select_stack):
# not sure if we want to set the typemap here...
self.typemap.setdefault("CAST", cast.type)
- self.strings[cast] = "CAST(%s AS %s)" % (self.strings[cast.clause],self.strings[cast.typeclause])
+ return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
- def visit_function(self, func):
+ def visit_function(self, func, **kwargs):
if len(self.select_stack):
self.typemap.setdefault(func.name, func.type)
if not self.apply_function_parens(func):
- self.strings[func] = ".".join(func.packagenames + [func.name])
- self.froms[func] = self.strings[func]
+ return ".".join(func.packagenames + [func.name])
else:
- self.strings[func] = ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.get_str(func.clause_expr)
- self.froms[func] = self.strings[func]
+ return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr)
- def visit_compound_select(self, cs):
- text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ")
- group_by = self.get_str(cs.group_by_clause)
+ def visit_compound_select(self, cs, asfrom=False, **kwargs):
+ text = string.join([self.process(c) for c in cs.selects], " " + cs.keyword + " ")
+ group_by = self.process(cs._group_by_clause)
if group_by:
text += " GROUP BY " + group_by
text += self.order_by_clause(cs)
- text += self.visit_select_postclauses(cs)
- self.strings[cs] = text
- self.froms[cs] = "(" + text + ")"
+ text += (cs._limit or cs._offset) and self.limit_clause(cs) or ""
+
+ if asfrom:
+ return "(" + text + ")"
+ else:
+ return text
- def visit_unary(self, unary):
- s = self.get_str(unary.element)
+ def visit_unary(self, unary, **kwargs):
+ s = self.process(unary.element)
if unary.operator:
- s = unary.operator + " " + s
+ s = self.operator_string(unary.operator) + " " + s
if unary.modifier:
s = s + " " + unary.modifier
- self.strings[unary] = s
+ return s
- def visit_binary(self, binary):
- result = self.get_str(binary.left)
- if binary.operator is not None:
- result += " " + self.binary_operator_string(binary)
- result += " " + self.get_str(binary.right)
- self.strings[binary] = result
-
- def binary_operator_string(self, binary):
- return binary.operator
+ def visit_binary(self, binary, **kwargs):
+ op = self.operator_string(binary.operator)
+ if callable(op):
+ return op(self.process(binary.left), self.process(binary.right))
+ else:
+ return self.process(binary.left) + " " + op + " " + self.process(binary.right)
+
+ def operator_string(self, operator):
+ return self.operators.get(operator, str(operator))
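A standalone sketch of the dispatch above (names here are illustrative, not the compiler's own; the string key stands in for the real sql.ColumnOperators.concat_op): an entry in the operators map is either a plain SQL string or a callable that renders the whole expression, which is how a dialect can swap an infix operator for function-style SQL.

    import operator

    OPERATORS = {
        operator.add: "+",
        operator.eq: "=",
        # a callable renders function-style SQL instead of infix
        "concat_op": lambda left, right: "concat(%s, %s)" % (left, right),
    }

    def render_binary(left, op, right):
        sql_op = OPERATORS.get(op, str(op))
        if callable(sql_op):
            return sql_op(left, right)
        return "%s %s %s" % (left, sql_op, right)

    print(render_binary("x", operator.add, "1"))  # x + 1
    print(render_binary("a", "concat_op", "b"))   # concat(a, b)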
- def visit_bindparam(self, bindparam):
+ def visit_bindparam(self, bindparam, **kwargs):
# apply truncation to the ultimate generated name
if bindparam.shortname != bindparam.key:
@@ -378,7 +411,6 @@ class ANSICompiler(sql.Compiled):
if bindparam.unique:
count = 1
key = bindparam.key
-
# redefine the generated name of the bind param in the case
# that we have multiple conflicting bind parameters.
while self.binds.setdefault(key, bindparam) is not bindparam:
@@ -386,164 +418,167 @@ class ANSICompiler(sql.Compiled):
key = bindparam.key + tag
count += 1
bindparam.key = key
- self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam))
+ return self.bindparam_string(self._truncate_bindparam(bindparam))
else:
existing = self.binds.get(bindparam.key)
if existing is not None and existing.unique:
raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
- self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam))
self.binds[bindparam.key] = bindparam
+ return self.bindparam_string(self._truncate_bindparam(bindparam))
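A minimal sketch of the unique-key loop above: when a second "unique" bind parameter lands on an occupied key, setdefault() leaves the existing entry in place and the loop suffixes _1, _2, ... until a free key is found.

    binds = {}

    def assign_key(key, param):
        count = 1
        base = key
        while binds.setdefault(key, param) is not param:
            key = "%s_%d" % (base, count)
            count += 1
        return key

    p1, p2 = object(), object()
    print(assign_key("name", p1))  # name
    print(assign_key("name", p2))  # name_1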
def _truncate_bindparam(self, bindparam):
if bindparam in self.bind_names:
return self.bind_names[bindparam]
bind_name = bindparam.key
- if len(bind_name) > self.dialect.max_identifier_length():
- bind_name = self._truncated_identifier("bindparam", bind_name)
- # add to bind_names for translation
- self.bind_names[bindparam] = bind_name
+ bind_name = self._truncated_identifier("bindparam", bind_name)
+ # add to bind_names for translation
+ self.bind_names[bindparam] = bind_name
+
return bind_name
def _truncated_identifier(self, ident_class, name):
if (ident_class, name) in self.generated_ids:
return self.generated_ids[(ident_class, name)]
- if len(name) > self.dialect.max_identifier_length():
+
+ anonname = self._anonymize(name)
+ if len(anonname) > self.dialect.max_identifier_length():
counter = self.generated_ids.get(ident_class, 1)
truncname = name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(counter)[2:]
self.generated_ids[ident_class] = counter + 1
else:
- truncname = name
+ truncname = anonname
self.generated_ids[(ident_class, name)] = truncname
return truncname
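A sketch of the truncation scheme, assuming a 30-character limit as Oracle uses: overlong names are cut six characters short of the limit and given an underscore plus a hex counter, so successive truncations within the same identifier class stay distinct.

    MAX_IDENT = 30   # stand-in for dialect.max_identifier_length()
    counters = {}

    def truncate(ident_class, name):
        if len(name) > MAX_IDENT:
            counter = counters.get(ident_class, 1)
            counters[ident_class] = counter + 1
            return name[0:MAX_IDENT - 6] + "_" + hex(counter)[2:]
        return name

    print(truncate("colident", "a_very_long_generated_column_label"))
    # a_very_long_generated_co_1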
+
+ def _anonymize(self, name):
+ def anon(match):
+ (ident, derived) = match.group(1,2)
+ if ('anonymous', ident) in self.generated_ids:
+ return self.generated_ids[('anonymous', ident)]
+ else:
+ anonymous_counter = self.generated_ids.get('anonymous', 1)
+ newname = derived + "_" + str(anonymous_counter)
+ self.generated_ids['anonymous'] = anonymous_counter + 1
+ self.generated_ids[('anonymous', ident)] = newname
+ return newname
+ return re.sub(r'{ANON (-?\d+) (.*)}', anon, name)
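A worked example of the {ANON <id> <derived>} substitution above, reimplemented standalone: each distinct id receives a stable, numbered name derived from its label, so the same anonymous construct always renders to the same identifier within one compilation.

    import re

    generated = {}

    def anonymize(name):
        def anon(match):
            ident, derived = match.group(1, 2)
            if ident not in generated:
                generated[ident] = "%s_%d" % (derived, len(generated) + 1)
            return generated[ident]
        return re.sub(r'{ANON (-?\d+) (.*)}', anon, name)

    print(anonymize("{ANON -12 users}"))  # users_1
    print(anonymize("{ANON -12 users}"))  # users_1 (stable)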
def bindparam_string(self, name):
return self.bindtemplate % name
- def visit_alias(self, alias):
- self.froms[alias] = self.get_from_text(alias.original) + " AS " + self.preparer.format_alias(alias)
- self.strings[alias] = self.get_str(alias.original)
+ def visit_alias(self, alias, asfrom=False, **kwargs):
+ if asfrom:
+ return self.process(alias.original, asfrom=True, **kwargs) + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name))
+ else:
+ return self.process(alias.original, **kwargs)
- def visit_select(self, select):
- # the actual list of columns to print in the SELECT column list.
- inner_columns = util.OrderedDict()
+ def label_select_column(self, select, column):
+ """convert a column from a select's "columns" clause.
+
+ given a select() and a column element from its inner_columns collection, return a
+ Label object if this column should be labeled in the columns clause. Otherwise,
+ return None and the column will be used as-is.
+
+ The calling method will traverse the returned label to acquire its string
+ representation.
+ """
+
+ # SQLite doesn't like selecting from a subquery where the column
+ # names look like table.colname. So if the column is in a "selected from"
+ # subquery, label it synonymously with its column name
+ if \
+ self.correlate_state[select].get('is_selected_from', False) and \
+ isinstance(column, sql._ColumnClause) and \
+ not column.is_literal and \
+ column.table is not None and \
+ not isinstance(column.table, sql.Select):
+ return column.label(column.name)
+ else:
+ return None
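An illustration of what this buys on SQLite, using a hypothetical users table: when a SELECT is itself selected from, its inner columns are labeled with their bare names, so the outer query sees

    SELECT name FROM (SELECT users.name AS name FROM users)

rather than an inner result column literally named "users.name", which SQLite rejects.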
+
+ def visit_select(self, select, asfrom=False, **kwargs):
+ select._calculate_correlations(self.correlate_state)
self.select_stack.append(select)
- for c in select._raw_columns:
- if hasattr(c, '_selectable'):
- s = c._selectable()
- else:
- self.traverse(c)
- inner_columns[self.get_str(c)] = c
- continue
- for co in s.columns:
- if select.use_labels:
- labelname = co._label
- if labelname is not None:
- l = co.label(labelname)
- self.traverse(l)
- inner_columns[labelname] = l
- else:
- self.traverse(co)
- inner_columns[self.get_str(co)] = co
- # TODO: figure this out, a ColumnClause with a select as a parent
- # is different from any other kind of parent
- elif select.is_selected_from and isinstance(co, sql._ColumnClause) and not co.is_literal and co.table is not None and not isinstance(co.table, sql.Select):
- # SQLite doesnt like selecting from a subquery where the column
- # names look like table.colname, so add a label synonomous with
- # the column name
- l = co.label(co.name)
- self.traverse(l)
- inner_columns[self.get_str(l.obj)] = l
+
+ # the actual list of columns to print in the SELECT column list.
+ inner_columns = util.OrderedSet()
+
+ froms = select._get_display_froms(self.correlate_state)
+
+ for co in select.inner_columns:
+ if select.use_labels:
+ labelname = co._label
+ if labelname is not None:
+ l = co.label(labelname)
+ inner_columns.add(self.process(l))
else:
self.traverse(co)
- inner_columns[self.get_str(co)] = co
+ inner_columns.add(self.process(co))
+ else:
+ l = self.label_select_column(select, co)
+ if l is not None:
+ inner_columns.add(self.process(l))
+ else:
+ inner_columns.add(self.process(co))
+
self.select_stack.pop(-1)
- collist = string.join([self.get_str(v) for v in inner_columns.values()], ', ')
+ collist = string.join(inner_columns.difference(util.Set([None])), ', ')
- text = "SELECT "
- text += self.visit_select_precolumns(select)
+ text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " "
+ text += self.get_select_precolumns(select)
text += collist
- whereclause = select.whereclause
-
- froms = []
- for f in select.froms:
-
- if self.parameters is not None:
- # TODO: whack this feature in 0.4
- # look at our own parameters, see if they
- # are all present in the form of BindParamClauses. if
- # not, then append to the above whereclause column conditions
- # matching those keys
- for c in f.columns:
- if sql.is_column(c) and self.parameters.has_key(c.key) and not self.binds.has_key(c.key):
- value = self.parameters[c.key]
- else:
- continue
- clause = c==value
- if whereclause is not None:
- whereclause = self.traverse(sql.and_(clause, whereclause), stop_on=util.Set([whereclause]))
- else:
- whereclause = clause
- self.traverse(whereclause)
-
- # special thingy used by oracle to redefine a join
+ whereclause = select._whereclause
+
+ from_strings = []
+ for f in froms:
+ from_strings.append(self.process(f, asfrom=True))
+
w = self.get_whereclause(f)
if w is not None:
- # TODO: move this more into the oracle module
if whereclause is not None:
- whereclause = self.traverse(sql.and_(w, whereclause), stop_on=util.Set([whereclause, w]))
+ whereclause = sql.and_(w, whereclause)
else:
whereclause = w
- t = self.get_from_text(f)
- if t is not None:
- froms.append(t)
-
if len(froms):
text += " \nFROM "
- text += string.join(froms, ', ')
+ text += string.join(from_strings, ', ')
else:
text += self.default_from()
if whereclause is not None:
- t = self.get_str(whereclause)
+ t = self.process(whereclause)
if t:
text += " \nWHERE " + t
- group_by = self.get_str(select.group_by_clause)
+ group_by = self.process(select._group_by_clause)
if group_by:
text += " GROUP BY " + group_by
- if select.having is not None:
- t = self.get_str(select.having)
+ if select._having is not None:
+ t = self.process(select._having)
if t:
text += " \nHAVING " + t
text += self.order_by_clause(select)
- text += self.visit_select_postclauses(select)
+ text += (select._limit or select._offset) and self.limit_clause(select) or ""
text += self.for_update_clause(select)
- self.strings[select] = text
- self.froms[select] = "(" + text + ")"
+ if asfrom:
+ return "(" + text + ")"
+ else:
+ return text
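The asfrom flag threaded through these visitors controls subquery wrapping: the same select() renders bare at the top level and parenthesized when it appears in a FROM list, e.g. (names invented)

    SELECT ... FROM (SELECT users.id FROM users) ...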
- def visit_select_precolumns(self, select):
+ def get_select_precolumns(self, select):
"""Called when building a ``SELECT`` statement, position is just before column list."""
-
- return select.distinct and "DISTINCT " or ""
-
- def visit_select_postclauses(self, select):
- """Called when building a ``SELECT`` statement, position is after all other ``SELECT`` clauses.
-
- Most DB syntaxes put ``LIMIT``/``OFFSET`` here.
- """
-
- return (select.limit or select.offset) and self.limit_clause(select) or ""
+ return select._distinct and "DISTINCT " or ""
def order_by_clause(self, select):
- order_by = self.get_str(select.order_by_clause)
+ order_by = self.process(select._order_by_clause)
if order_by:
return " ORDER BY " + order_by
else:
@@ -557,175 +592,103 @@ class ANSICompiler(sql.Compiled):
def limit_clause(self, select):
text = ""
- if select.limit is not None:
- text += " \n LIMIT " + str(select.limit)
- if select.offset is not None:
- if select.limit is None:
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
text += " \n LIMIT -1"
- text += " OFFSET " + str(select.offset)
+ text += " OFFSET " + str(select._offset)
return text
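A runnable sketch of the rendering above: OFFSET without LIMIT gets the "LIMIT -1" placeholder that some databases require before an OFFSET clause may appear.

    def limit_clause(limit, offset):
        text = ""
        if limit is not None:
            text += " \n LIMIT " + str(limit)
        if offset is not None:
            if limit is None:
                text += " \n LIMIT -1"
            text += " OFFSET " + str(offset)
        return text

    print(limit_clause(None, 10))  # " \n LIMIT -1 OFFSET 10"
    print(limit_clause(5, 10))     # " \n LIMIT 5 OFFSET 10"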
- def visit_table(self, table):
- self.froms[table] = self.preparer.format_table(table)
- self.strings[table] = ""
-
- def visit_join(self, join):
- righttext = self.get_from_text(join.right)
- if join.right._group_parenthesized():
- righttext = "(" + righttext + ")"
- if join.isouter:
- self.froms[join] = (self.get_from_text(join.left) + " LEFT OUTER JOIN " + righttext +
- " ON " + self.get_str(join.onclause))
+ def visit_table(self, table, asfrom=False, **kwargs):
+ if asfrom:
+ return self.preparer.format_table(table)
else:
- self.froms[join] = (self.get_from_text(join.left) + " JOIN " + righttext +
- " ON " + self.get_str(join.onclause))
- self.strings[join] = self.froms[join]
-
- def visit_insert_column_default(self, column, default, parameters):
- """Called when visiting an ``Insert`` statement.
-
- For each column in the table that contains a ``ColumnDefault``
- object, add a blank *placeholder* parameter so the ``Insert``
- gets compiled with this column's name in its column and
- ``VALUES`` clauses.
- """
-
- parameters.setdefault(column.key, None)
-
- def visit_update_column_default(self, column, default, parameters):
- """Called when visiting an ``Update`` statement.
-
- For each column in the table that contains a ``ColumnDefault``
- object as an onupdate, add a blank *placeholder* parameter so
- the ``Update`` gets compiled with this column's name as one of
- its ``SET`` clauses.
- """
-
- parameters.setdefault(column.key, None)
-
- def visit_insert_sequence(self, column, sequence, parameters):
- """Called when visiting an ``Insert`` statement.
-
- This may be overridden compilers that support sequences to
- place a blank *placeholder* parameter for each column in the
- table that contains a Sequence object, so the Insert gets
- compiled with this column's name in its column and ``VALUES``
- clauses.
- """
-
- pass
-
- def visit_insert_column(self, column, parameters):
- """Called when visiting an ``Insert`` statement.
-
- This may be overridden by compilers who disallow NULL columns
- being set in an ``Insert`` where there is a default value on
- the column (i.e. postgres), to remove the column for which
- there is a NULL insert from the parameter list.
- """
+ return ""
- pass
+ def visit_join(self, join, asfrom=False, **kwargs):
+ return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \
+ self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause))
+ def uses_sequences_for_inserts(self):
+ return False
+
def visit_insert(self, insert_stmt):
- # scan the table's columns for defaults that have to be pre-set for an INSERT
- # add these columns to the parameter list via visit_insert_XXX methods
- default_params = {}
+
+ # search for columns that will be required to have an explicit bound value.
+ # for inserts, this includes Python-side defaults, columns with sequences for dialects
+ # that support sequences, and primary key columns for dialects that explicitly insert
+ # pre-generated primary key values
+ required_cols = util.Set()
class DefaultVisitor(schema.SchemaVisitor):
- def visit_column(s, c):
- self.visit_insert_column(c, default_params)
+ def visit_column(s, c):
+ if c.primary_key and self.uses_sequences_for_inserts():
+ required_cols.add(c)
def visit_column_default(s, cd):
- self.visit_insert_column_default(c, cd, default_params)
+ required_cols.add(c)
def visit_sequence(s, seq):
- self.visit_insert_sequence(c, seq, default_params)
+ if self.uses_sequences_for_inserts():
+ required_cols.add(c)
vis = DefaultVisitor()
for c in insert_stmt.table.c:
if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
vis.traverse(c)
self.isinsert = True
- colparams = self._get_colparams(insert_stmt, default_params)
+ colparams = self._get_colparams(insert_stmt, required_cols)
- self.inline_params = util.Set()
- def create_param(col, p):
- if isinstance(p, sql._BindParamClause):
- self.binds[p.key] = p
- if p.shortname is not None:
- self.binds[p.shortname] = p
- return self.bindparam_string(self._truncate_bindparam(p))
- else:
- self.inline_params.add(col)
- self.traverse(p)
- if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement):
- return "(" + self.get_str(p) + ")"
- else:
- return self.get_str(p)
-
- text = ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" +
- " VALUES (" + string.join([create_param(*c) for c in colparams], ', ') + ")")
-
- self.strings[insert_stmt] = text
+ return ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" +
+ " VALUES (" + string.join([c[1] for c in colparams], ', ') + ")")
def visit_update(self, update_stmt):
- # scan the table's columns for onupdates that have to be pre-set for an UPDATE
- # add these columns to the parameter list via visit_update_XXX methods
- default_params = {}
+ update_stmt._calculate_correlations(self.correlate_state)
+
+ # search for columns that will be required to have an explicit bound value.
+ # for updates, this includes Python-side "onupdate" defaults.
+ required_cols = util.Set()
class OnUpdateVisitor(schema.SchemaVisitor):
def visit_column_onupdate(s, cd):
- self.visit_update_column_default(c, cd, default_params)
+ required_cols.add(c)
vis = OnUpdateVisitor()
for c in update_stmt.table.c:
if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
vis.traverse(c)
self.isupdate = True
- colparams = self._get_colparams(update_stmt, default_params)
-
- self.inline_params = util.Set()
- def create_param(col, p):
- if isinstance(p, sql._BindParamClause):
- self.binds[p.key] = p
- self.binds[p.shortname] = p
- return self.bindparam_string(self._truncate_bindparam(p))
- else:
- self.traverse(p)
- self.inline_params.add(col)
- if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement):
- return "(" + self.get_str(p) + ")"
- else:
- return self.get_str(p)
-
- text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(*c)) for c in colparams], ', ')
-
- if update_stmt.whereclause:
- text += " WHERE " + self.get_str(update_stmt.whereclause)
+ colparams = self._get_colparams(update_stmt, required_cols)
- self.strings[update_stmt] = text
+ text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ')
+ if update_stmt._whereclause:
+ text += " WHERE " + self.process(update_stmt._whereclause)
- def _get_colparams(self, stmt, default_params):
- """Organize ``UPDATE``/``INSERT`` ``SET``/``VALUES`` parameters into a list of tuples.
-
- Each tuple will contain the ``Column`` and a ``ClauseElement``
- representing the value to be set (usually a ``_BindParamClause``,
- but could also be other SQL expressions.)
-
- The list of tuples will determine the columns that are
- actually rendered into the ``SET``/``VALUES`` clause of the
- rendered ``UPDATE``/``INSERT`` statement. It will also
- determine how to generate the list/dictionary of bind
- parameters at execution time (i.e. ``get_params()``).
+ return text
- This list takes into account the `values` keyword specified
- to the statement, the parameters sent to this Compiled
- instance, and the default bind parameter values corresponding
- to the dialect's behavior for otherwise unspecified primary
- key columns.
+ def _get_colparams(self, stmt, required_cols):
+ """create a set of tuples representing column/string pairs for use
+ in an INSERT or UPDATE statement.
+
+ This method may generate new bind params within this compiled
+ based on the given set of "required columns", which are required
+ to have a value set in the statement.
"""
+ def create_bind_param(col, value):
+ bindparam = sql.bindparam(col.key, value, type_=col.type, unique=True)
+ self.binds[col.key] = bindparam
+ return self.bindparam_string(self._truncate_bindparam(bindparam))
+
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if self.parameters is None and stmt.parameters is None:
- return [(c, sql.bindparam(c.key, type=c.type)) for c in stmt.table.columns]
+ return [(c, create_bind_param(c, None)) for c in stmt.table.columns]
+
+ def create_clause_param(col, value):
+ self.traverse(value)
+ self.inline_params.add(col)
+ return self.process(value)
+
+ self.inline_params = util.Set()
def to_col(key):
if not isinstance(key, sql._ColumnClause):
@@ -744,29 +707,43 @@ class ANSICompiler(sql.Compiled):
for k, v in stmt.parameters.iteritems():
parameters.setdefault(to_col(k), v)
- for k, v in default_params.iteritems():
- parameters.setdefault(to_col(k), v)
+ for col in required_cols:
+ parameters.setdefault(col, None)
# create a list of column assignment clauses as tuples
values = []
for c in stmt.table.columns:
- if parameters.has_key(c):
+ if c in parameters:
value = parameters[c]
if sql._is_literal(value):
- value = sql.bindparam(c.key, value, type=c.type, unique=True)
+ value = create_bind_param(c, value)
+ else:
+ value = create_clause_param(c, value)
values.append((c, value))
+
return values
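A hypothetical illustration of the (column, rendered-value) pairs this produces for an UPDATE carrying one literal and one SQL expression (table and column names invented):

    # users.update(values={'name': 'ed', 'updated': func.now()}) ->
    #   [(users.c.name,    ':name'),   # literal -> new unique bind param
    #    (users.c.updated, 'now()')]   # clause  -> rendered inline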
def visit_delete(self, delete_stmt):
+ delete_stmt._calculate_correlations(self.correlate_state)
+
text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
- if delete_stmt.whereclause:
- text += " WHERE " + self.get_str(delete_stmt.whereclause)
+ if delete_stmt._whereclause:
+ text += " WHERE " + self.process(delete_stmt._whereclause)
- self.strings[delete_stmt] = text
+ return text
+
+ def visit_savepoint(self, savepoint_stmt):
+ return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+ def visit_rollback_to_savepoint(self, savepoint_stmt):
+ return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+
+ def visit_release_savepoint(self, savepoint_stmt):
+ return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+
def __str__(self):
- return self.get_str(self.statement)
+ return self.string
class ANSISchemaBase(engine.SchemaIterator):
def find_alterables(self, tables):
@@ -795,7 +772,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
def visit_metadata(self, metadata):
collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))]
for table in collection:
- table.accept_visitor(self)
+ self.traverse_single(table)
if self.dialect.supports_alter():
for alterable in self.find_alterables(collection):
self.add_foreignkey(alterable)
@@ -803,9 +780,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
def visit_table(self, table):
for column in table.columns:
if column.default is not None:
- column.default.accept_visitor(self)
- #if column.onupdate is not None:
- # column.onupdate.accept_visitor(visitor)
+ self.traverse_single(column.default)
self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (")
@@ -820,20 +795,20 @@ class ANSISchemaGenerator(ANSISchemaBase):
if column.primary_key:
first_pk = True
for constraint in column.constraints:
- constraint.accept_visitor(self)
+ self.traverse_single(constraint)
# On some DB order is significant: visit PK first, then the
# other constraints (engine.ReflectionTest.testbasic failed on FB2)
if len(table.primary_key):
- table.primary_key.accept_visitor(self)
+ self.traverse_single(table.primary_key)
for constraint in [c for c in table.constraints if c is not table.primary_key]:
- constraint.accept_visitor(self)
+ self.traverse_single(constraint)
self.append("\n)%s\n\n" % self.post_create_table(table))
self.execute()
if hasattr(table, 'indexes'):
for index in table.indexes:
- index.accept_visitor(self)
+ self.traverse_single(index)
def post_create_table(self, table):
return ''
@@ -870,7 +845,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
if constraint.name is not None:
self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
self.append("PRIMARY KEY ")
- self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
+ self.append("(%s)" % ', '.join([self.preparer.format_column(c) for c in constraint]))
def visit_foreign_key_constraint(self, constraint):
if constraint.use_alter and self.dialect.supports_alter():
@@ -889,9 +864,9 @@ class ANSISchemaGenerator(ANSISchemaBase):
self.append("CONSTRAINT %s " %
preparer.format_constraint(constraint))
self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
- string.join([preparer.format_column(f.parent) for f in constraint.elements], ', '),
+ ', '.join([preparer.format_column(f.parent) for f in constraint.elements]),
preparer.format_table(list(constraint.elements)[0].column.table),
- string.join([preparer.format_column(f.column) for f in constraint.elements], ', ')
+ ', '.join([preparer.format_column(f.column) for f in constraint.elements])
))
if constraint.ondelete is not None:
self.append(" ON DELETE %s" % constraint.ondelete)
@@ -903,17 +878,17 @@ class ANSISchemaGenerator(ANSISchemaBase):
if constraint.name is not None:
self.append("CONSTRAINT %s " %
self.preparer.format_constraint(constraint))
- self.append(" UNIQUE (%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
+ self.append(" UNIQUE (%s)" % (', '.join([self.preparer.format_column(c) for c in constraint])))
def visit_column(self, column):
pass
def visit_index(self, index):
- preparer = self.preparer
- self.append('CREATE ')
+ preparer = self.preparer
+ self.append("CREATE ")
if index.unique:
- self.append('UNIQUE ')
- self.append('INDEX %s ON %s (%s)' \
+ self.append("UNIQUE ")
+ self.append("INDEX %s ON %s (%s)" \
% (preparer.format_index(index),
preparer.format_table(index.table),
string.join([preparer.format_column(c) for c in index.columns], ', ')))
@@ -933,7 +908,7 @@ class ANSISchemaDropper(ANSISchemaBase):
for alterable in self.find_alterables(collection):
self.drop_foreignkey(alterable)
for table in collection:
- table.accept_visitor(self)
+ self.traverse_single(table)
def visit_index(self, index):
self.append("\nDROP INDEX " + self.preparer.format_index(index))
@@ -948,7 +923,7 @@ class ANSISchemaDropper(ANSISchemaBase):
def visit_table(self, table):
for column in table.columns:
if column.default is not None:
- column.default.accept_visitor(self)
+ self.traverse_single(column.default)
self.append("\nDROP TABLE " + self.preparer.format_table(table))
self.execute()
@@ -1048,17 +1023,17 @@ class ANSIIdentifierPreparer(object):
def should_quote(self, object):
return object.quote or self._requires_quotes(object.name, object.case_sensitive)
- def is_natural_case(self, object):
- return object.quote or self._requires_quotes(object.name, object.case_sensitive)
-
def format_sequence(self, sequence):
return self.__generic_obj_format(sequence, sequence.name)
def format_label(self, label, name=None):
return self.__generic_obj_format(label, name or label.name)
- def format_alias(self, alias):
- return self.__generic_obj_format(alias, alias.name)
+ def format_alias(self, alias, name=None):
+ return self.__generic_obj_format(alias, name or alias.name)
+
+ def format_savepoint(self, savepoint):
+ return self.__generic_obj_format(savepoint, savepoint)
def format_constraint(self, constraint):
return self.__generic_obj_format(constraint, constraint.name)
@@ -1076,25 +1051,25 @@ class ANSIIdentifierPreparer(object):
result = self.__generic_obj_format(table, table.schema) + "." + result
return result
- def format_column(self, column, use_table=False, name=None):
+ def format_column(self, column, use_table=False, name=None, table_name=None):
"""Prepare a quoted column name."""
if name is None:
name = column.name
if not getattr(column, 'is_literal', False):
if use_table:
- return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, name)
+ return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.__generic_obj_format(column, name)
else:
return self.__generic_obj_format(column, name)
else:
# literal textual elements get stuck into ColumnClause a lot, which shouldn't get quoted
if use_table:
- return self.format_table(column.table, use_schema=False) + "." + name
+ return self.format_table(column.table, use_schema=False, name=table_name) + "." + name
else:
return name
- def format_column_with_table(self, column, column_name=None):
+ def format_column_with_table(self, column, column_name=None, table_name=None):
"""Prepare a quoted column name with table name."""
- return self.format_column(column, use_table=True, name=column_name)
+ return self.format_column(column, use_table=True, name=column_name, table_name=table_name)
dialect = ANSIDialect
diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py
index a02781c84..07f07644f 100644
--- a/lib/sqlalchemy/databases/firebird.py
+++ b/lib/sqlalchemy/databases/firebird.py
@@ -5,15 +5,11 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import sys, StringIO, string, types
+import warnings
-from sqlalchemy import util
+from sqlalchemy import util, sql, schema, ansisql, exceptions
import sqlalchemy.engine.default as default
-import sqlalchemy.sql as sql
-import sqlalchemy.schema as schema
-import sqlalchemy.ansisql as ansisql
import sqlalchemy.types as sqltypes
-import sqlalchemy.exceptions as exceptions
_initialized_kb = False
@@ -176,7 +172,7 @@ class FBDialect(ansisql.ANSIDialect):
else:
return False
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
#TODO: map these better
column_func = {
14 : lambda r: sqltypes.String(r['FLEN']), # TEXT
@@ -254,11 +250,20 @@ class FBDialect(ansisql.ANSIDialect):
while row:
name = row['FNAME']
- args = [lower_if_possible(name)]
+ python_name = lower_if_possible(name)
+ if include_columns and python_name not in include_columns:
+ continue
+ args = [python_name]
kw = {}
# get the data types and lengths
- args.append(column_func[row['FTYPE']](row))
+ coltype = column_func.get(row['FTYPE'], None)
+ if coltype is None:
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (str(row['FTYPE']), name)))
+ coltype = sqltypes.NULLTYPE
+ else:
+ coltype = coltype(row)
+ args.append(coltype)
# is it a primary key?
kw['primary_key'] = name in pkfields
@@ -301,39 +306,39 @@ class FBDialect(ansisql.ANSIDialect):
class FBCompiler(ansisql.ANSICompiler):
"""Firebird specific idiosincrasies"""
- def visit_alias(self, alias):
+ def visit_alias(self, alias, asfrom=False, **kwargs):
# Override to not use the AS keyword which FB 1.5 does not like
- self.froms[alias] = self.get_from_text(alias.original) + " " + self.preparer.format_alias(alias)
- self.strings[alias] = self.get_str(alias.original)
+ if asfrom:
+ return self.process(alias.original, asfrom=True) + " " + self.preparer.format_alias(alias)
+ else:
+ return self.process(alias.original, asfrom=True)
def visit_function(self, func):
if len(func.clauses):
- super(FBCompiler, self).visit_function(func)
+ return super(FBCompiler, self).visit_function(func)
else:
- self.strings[func] = func.name
+ return func.name
- def visit_insert_column(self, column, parameters):
- # all column primary key inserts must be explicitly present
- if column.primary_key:
- parameters[column.key] = None
+ def uses_sequences_for_inserts(self):
+ return True
- def visit_select_precolumns(self, select):
+ def get_select_precolumns(self, select):
"""Called when building a ``SELECT`` statement, position is just
before column list Firebird puts the limit and offset right
after the ``SELECT``...
"""
result = ""
- if select.limit:
- result += " FIRST %d " % select.limit
- if select.offset:
- result +=" SKIP %d " % select.offset
- if select.distinct:
+ if select._limit:
+ result += " FIRST %d " % select._limit
+ if select._offset:
+ result +=" SKIP %d " % select._offset
+ if select._distinct:
result += " DISTINCT "
return result
def limit_clause(self, select):
- """Already taken care of in the `visit_select_precolumns` method."""
+ """Already taken care of in the `get_select_precolumns` method."""
return ""
@@ -364,7 +369,7 @@ class FBSchemaDropper(ansisql.ANSISchemaDropper):
class FBDefaultRunner(ansisql.ANSIDefaultRunner):
def exec_default_sql(self, default):
- c = sql.select([default.arg], from_obj=["rdb$database"]).compile(engine=self.connection)
+ c = sql.select([default.arg], from_obj=["rdb$database"]).compile(bind=self.connection)
return self.connection.execute_compiled(c).scalar()
def visit_sequence(self, seq):
diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py
index 81c44dcaa..93f47de15 100644
--- a/lib/sqlalchemy/databases/information_schema.py
+++ b/lib/sqlalchemy/databases/information_schema.py
@@ -1,4 +1,6 @@
-from sqlalchemy import sql, schema, exceptions, select, MetaData, Table, Column, String, Integer
+import sqlalchemy.sql as sql
+import sqlalchemy.exceptions as exceptions
+from sqlalchemy import select, MetaData, Table, Column, String, Integer
from sqlalchemy.schema import PassiveDefault, ForeignKeyConstraint
ischema = MetaData()
@@ -96,8 +98,7 @@ class ISchema(object):
return self.cache[name]
-def reflecttable(connection, table, ischema_names):
-
+def reflecttable(connection, table, include_columns, ischema_names):
key_constraints = pg_key_constraints
if table.schema is not None:
@@ -128,7 +129,9 @@ def reflecttable(connection, table, ischema_names):
row[columns.c.numeric_scale],
row[columns.c.column_default]
)
-
+ if include_columns and name not in include_columns:
+ continue
+
args = []
for a in (charlen, numericprec, numericscale):
if a is not None:
@@ -139,7 +142,7 @@ def reflecttable(connection, table, ischema_names):
colargs= []
if default is not None:
colargs.append(PassiveDefault(sql.text(default)))
- table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
+ table.append_column(Column(name, coltype, nullable=nullable, *colargs))
if not found_table:
raise exceptions.NoSuchTableError(table.name)
diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py
index 2fb508280..f3a6cf60e 100644
--- a/lib/sqlalchemy/databases/informix.py
+++ b/lib/sqlalchemy/databases/informix.py
@@ -5,20 +5,11 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+import datetime, warnings
-import sys, StringIO, string , random
-import datetime
-from decimal import Decimal
-
-import sqlalchemy.util as util
-import sqlalchemy.sql as sql
-import sqlalchemy.engine as engine
+from sqlalchemy import sql, schema, ansisql, exceptions, pool
import sqlalchemy.engine.default as default
-import sqlalchemy.schema as schema
-import sqlalchemy.ansisql as ansisql
import sqlalchemy.types as sqltypes
-import sqlalchemy.exceptions as exceptions
-import sqlalchemy.pool as pool
# for offset
@@ -128,7 +119,7 @@ class InfoBoolean(sqltypes.Boolean):
elif value is None:
return None
else:
- return value and True or False
+ return value and True or False
colspecs = {
@@ -262,7 +253,7 @@ class InfoDialect(ansisql.ANSIDialect):
cursor = connection.execute("""select tabname from systables where tabname=?""", table_name.lower() )
return bool( cursor.fetchone() is not None )
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
c = connection.execute ("select distinct OWNER from systables where tabname=?", table.name.lower() )
rows = c.fetchall()
if not rows :
@@ -289,6 +280,10 @@ class InfoDialect(ansisql.ANSIDialect):
raise exceptions.NoSuchTableError(table.name)
for name , colattr , collength , default , colno in rows:
+ name = name.lower()
+ if include_columns and name not in include_columns:
+ continue
+
# in 7.31, coltype = 0x000
# ^^-- column type
# ^-- 1 not null , 0 null
@@ -306,14 +301,16 @@ class InfoDialect(ansisql.ANSIDialect):
scale = 0
coltype = InfoNumeric(precision, scale)
else:
- coltype = ischema_names.get(coltype)
+ try:
+ coltype = ischema_names[coltype]
+ except KeyError:
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, name)))
+ coltype = sqltypes.NULLTYPE
colargs = []
if default is not None:
colargs.append(schema.PassiveDefault(sql.text(default)))
- name = name.lower()
-
table.append_column(schema.Column(name, coltype, nullable = (nullable == 0), *colargs))
# FK
@@ -372,20 +369,20 @@ class InfoCompiler(ansisql.ANSICompiler):
def default_from(self):
return " from systables where tabname = 'systables' "
- def visit_select_precolumns( self , select ):
- s = select.distinct and "DISTINCT " or ""
+ def get_select_precolumns( self , select ):
+ s = select._distinct and "DISTINCT " or ""
# only has limit
- if select.limit:
- off = select.offset or 0
- s += " FIRST %s " % ( select.limit + off )
+ if select._limit:
+ off = select._offset or 0
+ s += " FIRST %s " % ( select._limit + off )
else:
s += ""
return s
def visit_select(self, select):
- if select.offset:
- self.offset = select.offset
- self.limit = select.limit or 0
+ if select._offset:
+ self.offset = select._offset
+ self.limit = select._limit or 0
# the column in order by clause must in select too
def __label( c ):
@@ -393,13 +390,14 @@ class InfoCompiler(ansisql.ANSICompiler):
return c._label.lower()
except:
return ''
-
+
+ # TODO: don't modify the original select, generate a new one
a = [ __label(c) for c in select._raw_columns ]
for c in select.order_by_clause.clauses:
if ( __label(c) not in a ) and getattr( c , 'name' , '' ) != 'oid':
select.append_column( c )
- ansisql.ANSICompiler.visit_select(self, select)
+ return ansisql.ANSICompiler.visit_select(self, select)
def limit_clause(self, select):
return ""
@@ -414,23 +412,20 @@ class InfoCompiler(ansisql.ANSICompiler):
def visit_function( self , func ):
if func.name.lower() == 'current_date':
- self.strings[func] = "today"
+ return "today"
elif func.name.lower() == 'current_time':
- self.strings[func] = "CURRENT HOUR TO SECOND"
+ return "CURRENT HOUR TO SECOND"
elif func.name.lower() in ( 'current_timestamp' , 'now' ):
- self.strings[func] = "CURRENT YEAR TO SECOND"
+ return "CURRENT YEAR TO SECOND"
else:
- ansisql.ANSICompiler.visit_function( self , func )
+ return ansisql.ANSICompiler.visit_function( self , func )
def visit_clauselist(self, list):
try:
li = [ c for c in list.clauses if c.name != 'oid' ]
except:
li = [ c for c in list.clauses ]
- if list.parens:
- self.strings[list] = "(" + string.join([s for s in [self.get_str(c) for c in li] if s is not None ], ', ') + ")"
- else:
- self.strings[list] = string.join([s for s in [self.get_str(c) for c in li] if s is not None], ', ')
+ return ', '.join([s for s in [self.process(c) for c in li] if s is not None])
class InfoSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, first_pk=False):
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
index ba1c0fd9d..206291404 100644
--- a/lib/sqlalchemy/databases/mssql.py
+++ b/lib/sqlalchemy/databases/mssql.py
@@ -25,7 +25,7 @@
* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on ``INSERT``
-* ``select.limit`` implemented as ``SELECT TOP n``
+* ``select._limit`` implemented as ``SELECT TOP n``
Known issues / TODO:
@@ -39,16 +39,11 @@ Known issues / TODO:
"""
-import sys, StringIO, string, types, re, datetime, random
+import datetime, random, warnings, operator
-import sqlalchemy.sql as sql
-import sqlalchemy.engine as engine
-import sqlalchemy.engine.default as default
-import sqlalchemy.schema as schema
-import sqlalchemy.ansisql as ansisql
+from sqlalchemy import sql, schema, ansisql, exceptions
import sqlalchemy.types as sqltypes
-import sqlalchemy.exceptions as exceptions
-
+from sqlalchemy.engine import default
class MSNumeric(sqltypes.Numeric):
def convert_result_value(self, value, dialect):
@@ -500,7 +495,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
row = c.fetchone()
return row is not None
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
import sqlalchemy.databases.information_schema as ischema
# Get base columns
@@ -532,16 +527,22 @@ class MSSQLDialect(ansisql.ANSIDialect):
row[columns.c.numeric_scale],
row[columns.c.column_default]
)
+ if include_columns and name not in include_columns:
+ continue
args = []
for a in (charlen, numericprec, numericscale):
if a is not None:
args.append(a)
- coltype = self.ischema_names[type]
+ coltype = self.ischema_names.get(type, None)
if coltype == MSString and charlen == -1:
coltype = MSText()
else:
- if coltype == MSNVarchar and charlen == -1:
+ if coltype is None:
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (type, name)))
+ coltype = sqltypes.NULLTYPE
+
+ elif coltype == MSNVarchar and charlen == -1:
charlen = None
coltype = coltype(*args)
colargs= []
@@ -812,12 +813,12 @@ class MSSQLCompiler(ansisql.ANSICompiler):
super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs)
self.tablealiases = {}
- def visit_select_precolumns(self, select):
+ def get_select_precolumns(self, select):
""" MS-SQL puts TOP, it's version of LIMIT here """
- s = select.distinct and "DISTINCT " or ""
- if select.limit:
- s += "TOP %s " % (select.limit,)
- if select.offset:
+ s = select._distinct and "DISTINCT " or ""
+ if select._limit:
+ s += "TOP %s " % (select._limit,)
+ if select._offset:
raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset')
return s
@@ -825,49 +826,50 @@ class MSSQLCompiler(ansisql.ANSICompiler):
# Limit in mssql is after the select keyword
return ""
- def visit_table(self, table):
+ def _schema_aliased_table(self, table):
+ if getattr(table, 'schema', None) is not None:
+ if not self.tablealiases.has_key(table):
+ self.tablealiases[table] = table.alias()
+ return self.tablealiases[table]
+ else:
+ return None
+
+ def visit_table(self, table, mssql_aliased=False, **kwargs):
+ if mssql_aliased:
+ return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
# alias schema-qualified tables
- if getattr(table, 'schema', None) is not None and not self.tablealiases.has_key(table):
- alias = table.alias()
- self.tablealiases[table] = alias
- self.traverse(alias)
- self.froms[('alias', table)] = self.froms[table]
- for c in alias.c:
- self.traverse(c)
- self.traverse(alias.oid_column)
- self.tablealiases[alias] = self.froms[table]
- self.froms[table] = self.froms[alias]
+ alias = self._schema_aliased_table(table)
+ if alias is not None:
+ return self.process(alias, mssql_aliased=True, **kwargs)
else:
- super(MSSQLCompiler, self).visit_table(table)
+ return super(MSSQLCompiler, self).visit_table(table, **kwargs)
- def visit_alias(self, alias):
+ def visit_alias(self, alias, **kwargs):
# translate for schema-qualified table aliases
- if self.froms.has_key(('alias', alias.original)):
- self.froms[alias] = self.froms[('alias', alias.original)] + " AS " + alias.name
- self.strings[alias] = ""
- else:
- super(MSSQLCompiler, self).visit_alias(alias)
+ self.tablealiases[alias.original] = alias
+ return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
def visit_column(self, column):
- # translate for schema-qualified table aliases
- super(MSSQLCompiler, self).visit_column(column)
- if column.table is not None and self.tablealiases.has_key(column.table):
- self.strings[column] = \
- self.strings[self.tablealiases[column.table].corresponding_column(column)]
+ if column.table is not None:
+ # translate for schema-qualified table aliases
+ t = self._schema_aliased_table(column.table)
+ if t is not None:
+ return self.process(t.corresponding_column(column))
+ return super(MSSQLCompiler, self).visit_column(column)
def visit_binary(self, binary):
"""Move bind parameters to the right-hand side of an operator, where possible."""
- if isinstance(binary.left, sql._BindParamClause) and binary.operator == '=':
- binary.left, binary.right = binary.right, binary.left
- super(MSSQLCompiler, self).visit_binary(binary)
-
- def visit_select(self, select):
- # label function calls, so they return a name in cursor.description
- for i,c in enumerate(select._raw_columns):
- if isinstance(c, sql._Function):
- select._raw_columns[i] = c.label(c.name + "_" + hex(random.randint(0, 65535))[2:])
+ if isinstance(binary.left, sql._BindParamClause) and binary.operator == operator.eq:
+ return self.process(sql._BinaryExpression(binary.right, binary.left, binary.operator))
+ else:
+ return super(MSSQLCompiler, self).visit_binary(binary)
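A before/after illustration of the rewrite above, with hypothetical names: an equality whose left side is a bind parameter is flipped so the bind lands on the right,

    :name = users.name   -->   users.name = :name

by constructing a new _BinaryExpression with the operands swapped rather than mutating the original clause.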
- super(MSSQLCompiler, self).visit_select(select)
+ def label_select_column(self, select, column):
+ if isinstance(column, sql._Function):
+ return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:])
+ else:
+ return super(MSSQLCompiler, self).label_select_column(select, column)
function_rewrites = {'current_date': 'getdate',
'length': 'len',
@@ -881,10 +883,10 @@ class MSSQLCompiler(ansisql.ANSICompiler):
return ''
def order_by_clause(self, select):
- order_by = self.get_str(select.order_by_clause)
+ order_by = self.process(select._order_by_clause)
# MSSQL only allows ORDER BY in subqueries if there is a LIMIT
- if order_by and (not select.is_subquery or select.limit):
+ if order_by and (not self.is_subquery(select) or select._limit):
return " ORDER BY " + order_by
else:
return ""
@@ -916,10 +918,12 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
class MSSQLSchemaDropper(ansisql.ANSISchemaDropper):
def visit_index(self, index):
self.append("\nDROP INDEX %s.%s" % (
- self.preparer.quote_identifier(index.table.name),
- self.preparer.quote_identifier(index.name)))
+ self.preparer.quote_identifier(index.table.name),
+ self.preparer.quote_identifier(index.name)
+ ))
self.execute()
+
class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner):
# TODO: does ms-sql have standalone sequences ?
pass
@@ -940,4 +944,3 @@ dialect = MSSQLDialect
-
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index bac0e5e12..26800e32b 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -4,7 +4,7 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import re, datetime, inspect, warnings, weakref
+import re, datetime, inspect, warnings, weakref, operator
from sqlalchemy import sql, schema, ansisql
from sqlalchemy.engine import default
@@ -12,13 +12,13 @@ import sqlalchemy.types as sqltypes
import sqlalchemy.exceptions as exceptions
import sqlalchemy.util as util
from array import array as _array
+from decimal import Decimal
try:
from threading import Lock
except ImportError:
from dummy_threading import Lock
-
RESERVED_WORDS = util.Set(
['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc',
'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both',
@@ -60,7 +60,6 @@ RESERVED_WORDS = util.Set(
'accessible', 'linear', 'master_ssl_verify_server_cert', 'range',
'read_only', 'read_write', # 5.1
])
-
_per_connection_mutex = Lock()
class _NumericType(object):
@@ -137,7 +136,7 @@ class _StringType(object):
class MSNumeric(sqltypes.Numeric, _NumericType):
"""MySQL NUMERIC type"""
- def __init__(self, precision = 10, length = 2, **kw):
+ def __init__(self, precision = 10, length = 2, asdecimal=True, **kw):
"""Construct a NUMERIC.
precision
@@ -157,18 +156,27 @@ class MSNumeric(sqltypes.Numeric, _NumericType):
"""
_NumericType.__init__(self, **kw)
- sqltypes.Numeric.__init__(self, precision, length)
-
+ sqltypes.Numeric.__init__(self, precision, length, asdecimal=asdecimal)
+
def get_col_spec(self):
if self.precision is None:
return self._extend("NUMERIC")
else:
return self._extend("NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length})
+ def convert_bind_param(self, value, dialect):
+ return value
+
+ def convert_result_value(self, value, dialect):
+ if not self.asdecimal and isinstance(value, Decimal):
+ return float(value)
+ else:
+ return value
+
class MSDecimal(MSNumeric):
"""MySQL DECIMAL type"""
- def __init__(self, precision=10, length=2, **kw):
+ def __init__(self, precision=10, length=2, asdecimal=True, **kw):
"""Construct a DECIMAL.
precision
@@ -187,7 +195,7 @@ class MSDecimal(MSNumeric):
underlying database API, which continue to be numeric.
"""
- super(MSDecimal, self).__init__(precision, length, **kw)
+ super(MSDecimal, self).__init__(precision, length, asdecimal=asdecimal, **kw)
def get_col_spec(self):
if self.precision is None:
@@ -200,7 +208,7 @@ class MSDecimal(MSNumeric):
class MSDouble(MSNumeric):
"""MySQL DOUBLE type"""
- def __init__(self, precision=10, length=2, **kw):
+ def __init__(self, precision=10, length=2, asdecimal=True, **kw):
"""Construct a DOUBLE.
precision
@@ -222,7 +230,7 @@ class MSDouble(MSNumeric):
if ((precision is None and length is not None) or
(precision is not None and length is None)):
raise exceptions.ArgumentError("You must specify both precision and length or omit both altogether.")
- super(MSDouble, self).__init__(precision, length, **kw)
+ super(MSDouble, self).__init__(precision, length, asdecimal=asdecimal, **kw)
def get_col_spec(self):
if self.precision is not None and self.length is not None:
@@ -235,7 +243,7 @@ class MSDouble(MSNumeric):
class MSFloat(sqltypes.Float, _NumericType):
"""MySQL FLOAT type"""
- def __init__(self, precision=10, length=None, **kw):
+ def __init__(self, precision=10, length=None, asdecimal=False, **kw):
"""Construct a FLOAT.
precision
@@ -257,7 +265,7 @@ class MSFloat(sqltypes.Float, _NumericType):
if length is not None:
self.length=length
_NumericType.__init__(self, **kw)
- sqltypes.Float.__init__(self, precision)
+ sqltypes.Float.__init__(self, precision, asdecimal=asdecimal)
def get_col_spec(self):
if hasattr(self, 'length') and self.length is not None:
@@ -267,6 +275,10 @@ class MSFloat(sqltypes.Float, _NumericType):
else:
return self._extend("FLOAT")
+ def convert_bind_param(self, value, dialect):
+ return value
+
+
class MSInteger(sqltypes.Integer, _NumericType):
"""MySQL INTEGER type"""
@@ -955,7 +967,10 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
if self.compiled.isinsert:
if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
-
+
+ def is_select(self):
+ return re.match(r'SELECT|SHOW|DESCRIBE|XA RECOVER', self.statement.lstrip(), re.I) is not None
+
class MySQLDialect(ansisql.ANSIDialect):
def __init__(self, **kwargs):
ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs)
@@ -1044,6 +1059,27 @@ class MySQLDialect(ansisql.ANSIDialect):
except:
pass
+ def do_begin_twophase(self, connection, xid):
+ connection.execute(sql.text("XA BEGIN :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+ def do_prepare_twophase(self, connection, xid):
+ connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)]))
+ connection.execute(sql.text("XA PREPARE :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+ def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+ if not is_prepared:
+ connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)]))
+ connection.execute(sql.text("XA ROLLBACK :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+ def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+ if not is_prepared:
+ self.do_prepare_twophase(connection, xid)
+ connection.execute(sql.text("XA COMMIT :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+ def do_recover_twophase(self, connection):
+ resultset = connection.execute(sql.text("XA RECOVER"))
+ return [row['data'][0:row['gtrid_length']] for row in resultset]
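A sketch, detached from any live connection, of the XA statement sequence the methods above issue for a two-phase commit of transaction id 'xid_1' (an invented id):

    for stmt in (
        "XA BEGIN :xid",    # do_begin_twophase
        "XA END :xid",      # do_prepare_twophase
        "XA PREPARE :xid",
        "XA COMMIT :xid",   # do_commit_twophase, is_prepared=True
    ):
        print(stmt.replace(":xid", "'xid_1'"))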
+
def is_disconnect(self, e):
return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2013, 2014, 2045, 2055)
@@ -1088,7 +1124,7 @@ class MySQLDialect(ansisql.ANSIDialect):
version.append(n)
return tuple(version)
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
"""Load column definitions from the server."""
decode_from = self._detect_charset(connection)
@@ -1111,6 +1147,9 @@ class MySQLDialect(ansisql.ANSIDialect):
# leave column names as unicode
name = name.decode(decode_from)
+
+ if include_columns and name not in include_columns:
+ continue
match = re.match(r'(\w+)(\(.*?\))?\s*(\w+)?\s*(\w+)?', type)
col_type = match.group(1)
@@ -1118,7 +1157,11 @@ class MySQLDialect(ansisql.ANSIDialect):
extra_1 = match.group(3)
extra_2 = match.group(4)
- coltype = ischema_names.get(col_type, MSString)
+ try:
+ coltype = ischema_names[col_type]
+ except KeyError:
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (col_type, name)))
+ coltype = sqltypes.NULLTYPE
kw = {}
if extra_1 is not None:
@@ -1156,7 +1199,6 @@ class MySQLDialect(ansisql.ANSIDialect):
if not row:
raise exceptions.NoSuchTableError(table.fullname)
desc = row[1].strip()
- row.close()
tabletype = ''
lastparen = re.search(r'\)[^\)]*\Z', desc)
@@ -1223,7 +1265,6 @@ class MySQLDialect(ansisql.ANSIDialect):
cs = True
else:
cs = row[1] in ('0', 'OFF', 'off')
- row.close()
cache['lower_case_table_names'] = cs
self.per_connection[raw_connection] = cache
return cache.get('lower_case_table_names')
@@ -1266,14 +1307,21 @@ class _MySQLPythonRowProxy(object):
class MySQLCompiler(ansisql.ANSICompiler):
- def visit_cast(self, cast):
-
+ operators = ansisql.ANSICompiler.operators.copy()
+ operators.update(
+ {
+ sql.ColumnOperators.concat_op : lambda x, y:"concat(%s, %s)" % (x, y),
+ operator.mod : '%%'
+ }
+ )
+
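A standalone look at the two overrides above: string concatenation renders as MySQL's concat() function, and the Python mod operator becomes a doubled "%%" so the literal percent survives the "format" paramstyle's %-interpolation.

    concat = lambda x, y: "concat(%s, %s)" % (x, y)
    print(concat("firstname", "lastname"))   # concat(firstname, lastname)
    # the emitted "%%" collapses to "%" once the DB-API applies %-formatting:
    print("value %% 10" % ())                # value % 10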
+ def visit_cast(self, cast, **kwargs):
if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)):
- return super(MySQLCompiler, self).visit_cast(cast)
+ return super(MySQLCompiler, self).visit_cast(cast, **kwargs)
else:
# so just skip the CAST altogether for now.
# TODO: put whatever MySQL does for CAST here.
- self.strings[cast] = self.strings[cast.clause]
+ return self.process(cast.clause)
def for_update_clause(self, select):
if select.for_update == 'read':
@@ -1283,20 +1331,15 @@ class MySQLCompiler(ansisql.ANSICompiler):
def limit_clause(self, select):
text = ""
- if select.limit is not None:
- text += " \n LIMIT " + str(select.limit)
- if select.offset is not None:
- if select.limit is None:
- # striaght from the MySQL docs, I kid you not
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
+ # straight from the MySQL docs, I kid you not
text += " \n LIMIT 18446744073709551615"
- text += " OFFSET " + str(select.offset)
+ text += " OFFSET " + str(select._offset)
return text
- def binary_operator_string(self, binary):
- if binary.operator == '%':
- return '%%'
- else:
- return ansisql.ANSICompiler.binary_operator_string(self, binary)
class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, first_pk=False):
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py
index 9d7d6a112..d3aa2e268 100644
--- a/lib/sqlalchemy/databases/oracle.py
+++ b/lib/sqlalchemy/databases/oracle.py
@@ -5,9 +5,9 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import sys, StringIO, string, re, warnings
+import re, warnings, operator, datetime
-from sqlalchemy import util, sql, engine, schema, ansisql, exceptions, logging
+from sqlalchemy import util, sql, schema, ansisql, exceptions, logging
from sqlalchemy.engine import default, base
import sqlalchemy.types as sqltypes
@@ -88,8 +88,11 @@ class OracleText(sqltypes.TEXT):
def convert_result_value(self, value, dialect):
if value is None:
return None
- else:
+ elif hasattr(value, 'read'):
+ # cx_oracle doesn't seem to be consistent with CLOB returning LOB or str
return super(OracleText, self).convert_result_value(value.read(), dialect)
+ else:
+ return super(OracleText, self).convert_result_value(value, dialect)
class OracleRaw(sqltypes.Binary):
@@ -178,25 +181,31 @@ class OracleExecutionContext(default.DefaultExecutionContext):
super(OracleExecutionContext, self).pre_exec()
if self.dialect.auto_setinputsizes:
self.set_input_sizes()
+ if self.compiled_parameters is not None and not isinstance(self.compiled_parameters, list):
+ for key in self.compiled_parameters:
+ (bindparam, name, value) = self.compiled_parameters.get_parameter(key)
+ if bindparam.isoutparam:
+ dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
+ if not hasattr(self, 'out_parameters'):
+ self.out_parameters = {}
+ self.out_parameters[name] = self.cursor.var(dbtype)
+ self.parameters[name] = self.out_parameters[name]
def get_result_proxy(self):
+ if hasattr(self, 'out_parameters'):
+ if self.compiled_parameters is not None:
+ for k in self.out_parameters:
+ type = self.compiled_parameters.get_type(k)
+ self.out_parameters[k] = type.dialect_impl(self.dialect).convert_result_value(self.out_parameters[k].getvalue(), self.dialect)
+ else:
+ for k in self.out_parameters:
+ self.out_parameters[k] = self.out_parameters[k].getvalue()
+
if self.cursor.description is not None:
- if self.dialect.auto_convert_lobs and self.typemap is None:
- typemap = {}
- binary = False
- for column in self.cursor.description:
- type_code = column[1]
- if type_code in self.dialect.ORACLE_BINARY_TYPES:
- binary = True
- typemap[column[0].lower()] = OracleBinary()
- self.typemap = typemap
- if binary:
+ for column in self.cursor.description:
+ type_code = column[1]
+ if type_code in self.dialect.ORACLE_BINARY_TYPES:
return base.BufferedColumnResultProxy(self)
- else:
- for column in self.cursor.description:
- type_code = column[1]
- if type_code in self.dialect.ORACLE_BINARY_TYPES:
- return base.BufferedColumnResultProxy(self)
return base.ResultProxy(self)
@@ -208,11 +217,26 @@ class OracleDialect(ansisql.ANSIDialect):
self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' )
self.auto_setinputsizes = auto_setinputsizes
self.auto_convert_lobs = auto_convert_lobs
+
if self.dbapi is not None:
self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(self.dbapi, k)]
else:
self.ORACLE_BINARY_TYPES = []
+ def dbapi_type_map(self):
+ if self.dbapi is None or not self.auto_convert_lobs:
+ return {}
+ else:
+ return {
+ self.dbapi.NUMBER: OracleInteger(),
+ self.dbapi.CLOB: OracleText(),
+ self.dbapi.BLOB: OracleBinary(),
+ self.dbapi.STRING: OracleString(),
+ self.dbapi.TIMESTAMP: OracleTimestamp(),
+ self.dbapi.BINARY: OracleRaw(),
+ datetime.datetime: OracleDate()
+ }
+
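
Editor's note: per the Dialect.dbapi_type_map() contract elsewhere in this patch, this map applies only to textual statements with no explicit typemap; constructed SQL carries its own type information. A hedged sketch of the effect (engine, table and column hypothetical):

    # with auto_convert_lobs=True (the default), a raw textual SELECT
    # against a CLOB column yields a plain string instead of a
    # cx_Oracle LOB handle:
    row = engine.execute("SELECT clob_col FROM some_table").fetchone()
    print type(row[0])   # str/unicode, not cx_Oracle.LOB
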
def dbapi(cls):
import cx_Oracle
return cx_Oracle
@@ -251,7 +275,7 @@ class OracleDialect(ansisql.ANSIDialect):
return 30
def oid_column_name(self, column):
- if not isinstance(column.table, sql.TableClause) and not isinstance(column.table, sql.Select):
+ if not isinstance(column.table, (sql.TableClause, sql.Select)):
return None
else:
return "rowid"
@@ -341,7 +365,7 @@ class OracleDialect(ansisql.ANSIDialect):
return name, owner, dblink
raise
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
preparer = self.identifier_preparer
if not preparer.should_quote(table):
name = table.name.upper()
@@ -363,6 +387,13 @@ class OracleDialect(ansisql.ANSIDialect):
#print "ROW:" , row
(colname, coltype, length, precision, scale, nullable, default) = (row[0], row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
+ # if the name comes back as all upper case, assume it's case folded
+ if (colname.upper() == colname):
+ colname = colname.lower()
+
+ if include_columns and colname not in include_columns:
+ continue
+
# INTEGER if the scale is 0 and precision is null
# NUMBER if the scale and precision are both null
# NUMBER(9,2) if the precision is 9 and the scale is 2
@@ -382,16 +413,13 @@ class OracleDialect(ansisql.ANSIDialect):
try:
coltype = ischema_names[coltype]
except KeyError:
- raise exceptions.AssertionError("Can't get coltype for type '%s' on colname '%s'" % (coltype, colname))
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, colname)))
+ coltype = sqltypes.NULLTYPE
colargs = []
if default is not None:
colargs.append(schema.PassiveDefault(sql.text(default)))
- # if name comes back as all upper, assume its case folded
- if (colname.upper() == colname):
- colname = colname.lower()
-
table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs))
if not len(table.columns):
@@ -458,16 +486,27 @@ class OracleDialect(ansisql.ANSIDialect):
OracleDialect.logger = logging.class_logger(OracleDialect)
+class _OuterJoinColumn(sql.ClauseElement):
+ __visit_name__ = 'outer_join_column'
+ def __init__(self, column):
+ self.column = column
+
class OracleCompiler(ansisql.ANSICompiler):
"""Oracle compiler modifies the lexical structure of Select
statements to work under non-ANSI configured Oracle databases, if
the use_ansi flag is False.
"""
+ operators = ansisql.ANSICompiler.operators.copy()
+ operators.update(
+ {
+ operator.mod : lambda x, y:"mod(%s, %s)" % (x, y)
+ }
+ )
+
def __init__(self, *args, **kwargs):
super(OracleCompiler, self).__init__(*args, **kwargs)
- # we have to modify SELECT objects a little bit, so store state here
- self._select_state = {}
+ self.__wheres = {}
def default_from(self):
"""Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
@@ -480,49 +519,46 @@ class OracleCompiler(ansisql.ANSICompiler):
def apply_function_parens(self, func):
return len(func.clauses) > 0
- def visit_join(self, join):
+ def visit_join(self, join, **kwargs):
if self.dialect.use_ansi:
- return ansisql.ANSICompiler.visit_join(self, join)
-
- self.froms[join] = self.get_from_text(join.left) + ", " + self.get_from_text(join.right)
- where = self.wheres.get(join.left, None)
+ return ansisql.ANSICompiler.visit_join(self, join, **kwargs)
+
+ (where, parentjoin) = self.__wheres.get(join, (None, None))
+
+ class VisitOn(sql.ClauseVisitor):
+ def visit_binary(s, binary):
+ if binary.operator == operator.eq:
+ if binary.left.table is join.right:
+ binary.left = _OuterJoinColumn(binary.left)
+ elif binary.right.table is join.right:
+ binary.right = _OuterJoinColumn(binary.right)
+
if where is not None:
- self.wheres[join] = sql.and_(where, join.onclause)
+ self.__wheres[join.left] = self.__wheres[parentjoin] = (sql.and_(VisitOn().traverse(join.onclause, clone=True), where), parentjoin)
else:
- self.wheres[join] = join.onclause
-# self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause)
- self.strings[join] = self.froms[join]
-
- if join.isouter:
- # if outer join, push on the right side table as the current "outertable"
- self._outertable = join.right
-
- # now re-visit the onclause, which will be used as a where clause
- # (the first visit occured via the Join object itself right before it called visit_join())
- self.traverse(join.onclause)
-
- self._outertable = None
-
- self.wheres[join].accept_visitor(self)
+ self.__wheres[join.left] = self.__wheres[join] = (VisitOn().traverse(join.onclause, clone=True), join)
- def visit_insert_sequence(self, column, sequence, parameters):
- """This is the `sequence` equivalent to ``ANSICompiler``'s
- `visit_insert_column_default` which ensures that the column is
- present in the generated column list.
- """
-
- parameters.setdefault(column.key, None)
+ return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
+
+ def get_whereclause(self, f):
+ if f in self.__wheres:
+ return self.__wheres[f][0]
+ else:
+ return None
+
+ def visit_outer_join_column(self, vc):
+ return self.process(vc.column) + "(+)"
+
+ def uses_sequences_for_inserts(self):
+ return True
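
Editor's note: in non-ANSI mode the join is rendered as a plain comma-separated FROM list, the ON clause is stashed in __wheres for get_whereclause(), and outer-joined columns get Oracle's (+) marker via visit_outer_join_column(). A hedged sketch (tables hypothetical):

    from sqlalchemy import MetaData, Table, Column, Integer, select
    from sqlalchemy.databases import oracle

    meta = MetaData()
    a = Table('a', meta, Column('id', Integer))
    b = Table('b', meta, Column('a_id', Integer))

    s = select([a.c.id], from_obj=[a.outerjoin(b, a.c.id == b.c.a_id)])
    # with use_ansi=False this renders roughly as:
    #   SELECT a.id FROM a, b WHERE a.id = b.a_id(+)
    print s.compile(dialect=oracle.OracleDialect(use_ansi=False))
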
- def visit_alias(self, alias):
+ def visit_alias(self, alias, asfrom=False, **kwargs):
"""Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??"""
-
- self.froms[alias] = self.get_from_text(alias.original) + " " + alias.name
- self.strings[alias] = self.get_str(alias.original)
-
- def visit_column(self, column):
- ansisql.ANSICompiler.visit_column(self, column)
- if not self.dialect.use_ansi and getattr(self, '_outertable', None) is not None and column.table is self._outertable:
- self.strings[column] = self.strings[column] + "(+)"
+
+ if asfrom:
+ return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + alias.name
+ else:
+ return self.process(alias.original, **kwargs)
def visit_insert(self, insert):
"""``INSERT`` s are required to have the primary keys be explicitly present.
@@ -539,76 +575,35 @@ class OracleCompiler(ansisql.ANSICompiler):
def _TODO_visit_compound_select(self, select):
"""Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
+ pass
- if getattr(select, '_oracle_visit', False):
- # cancel out the compiled order_by on the select
- if hasattr(select, "order_by_clause"):
- self.strings[select.order_by_clause] = ""
- ansisql.ANSICompiler.visit_compound_select(self, select)
- return
-
- if select.limit is not None or select.offset is not None:
- select._oracle_visit = True
- # to use ROW_NUMBER(), an ORDER BY is required.
- orderby = self.strings[select.order_by_clause]
- if not orderby:
- orderby = select.oid_column
- self.traverse(orderby)
- orderby = self.strings[orderby]
- class SelectVisitor(sql.NoColumnVisitor):
- def visit_select(self, select):
- select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn"))
- SelectVisitor().traverse(select)
- limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
- if select.offset is not None:
- limitselect.append_whereclause("ora_rn>%d" % select.offset)
- if select.limit is not None:
- limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset))
- else:
- limitselect.append_whereclause("ora_rn<=%d" % select.limit)
- self.traverse(limitselect)
- self.strings[select] = self.strings[limitselect]
- self.froms[select] = self.froms[limitselect]
- else:
- ansisql.ANSICompiler.visit_compound_select(self, select)
-
- def visit_select(self, select):
+ def visit_select(self, select, **kwargs):
"""Look for ``LIMIT`` and OFFSET in a select statement, and if
so tries to wrap it in a subquery with ``row_number()`` criterion.
"""
- # TODO: put a real copy-container on Select and copy, or somehow make this
- # not modify the Select statement
- if self._select_state.get((select, 'visit'), False):
- # cancel out the compiled order_by on the select
- if hasattr(select, "order_by_clause"):
- self.strings[select.order_by_clause] = ""
- ansisql.ANSICompiler.visit_select(self, select)
- return
-
- if select.limit is not None or select.offset is not None:
- self._select_state[(select, 'visit')] = True
+ if not getattr(select, '_oracle_visit', None) and (select._limit is not None or select._offset is not None):
# to use ROW_NUMBER(), an ORDER BY is required.
- orderby = self.strings[select.order_by_clause]
+ orderby = self.process(select._order_by_clause)
if not orderby:
orderby = select.oid_column
self.traverse(orderby)
- orderby = self.strings[orderby]
- if not hasattr(select, '_oracle_visit'):
- select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn"))
- select._oracle_visit = True
+ orderby = self.process(orderby)
+
+ oldselect = select
+ select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None)
+ select._oracle_visit = True
+
limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
- if select.offset is not None:
- limitselect.append_whereclause("ora_rn>%d" % select.offset)
- if select.limit is not None:
- limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset))
+ if select._offset is not None:
+ limitselect.append_whereclause("ora_rn>%d" % select._offset)
+ if select._limit is not None:
+ limitselect.append_whereclause("ora_rn<=%d" % (select._limit + select._offset))
else:
- limitselect.append_whereclause("ora_rn<=%d" % select.limit)
- self.traverse(limitselect)
- self.strings[select] = self.strings[limitselect]
- self.froms[select] = self.froms[limitselect]
+ limitselect.append_whereclause("ora_rn<=%d" % select._limit)
+ return self.process(limitselect)
else:
- ansisql.ANSICompiler.visit_select(self, select)
+ return ansisql.ANSICompiler.visit_select(self, select, **kwargs)
def limit_clause(self, select):
return ""
@@ -619,12 +614,6 @@ class OracleCompiler(ansisql.ANSICompiler):
else:
return super(OracleCompiler, self).for_update_clause(select)
- def visit_binary(self, binary):
- if binary.operator == '%':
- self.strings[binary] = ("MOD(%s,%s)"%(self.get_str(binary.left), self.get_str(binary.right)))
- else:
- return ansisql.ANSICompiler.visit_binary(self, binary)
-
class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
@@ -639,22 +628,22 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
return colspec
def visit_sequence(self, sequence):
- if not self.dialect.has_sequence(self.connection, sequence.name):
+ if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name):
self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
class OracleSchemaDropper(ansisql.ANSISchemaDropper):
def visit_sequence(self, sequence):
- if self.dialect.has_sequence(self.connection, sequence.name):
+ if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name):
self.append("DROP SEQUENCE %s" % sequence.name)
self.execute()
class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
def exec_default_sql(self, default):
- c = sql.select([default.arg], from_obj=["DUAL"]).compile(engine=self.connection)
- return self.connection.execute_compiled(c).scalar()
+ c = sql.select([default.arg], from_obj=["DUAL"]).compile(bind=self.connection)
+ return self.connection.execute(c).scalar()
def visit_sequence(self, seq):
- return self.connection.execute_text("SELECT " + seq.name + ".nextval FROM DUAL").scalar()
+ return self.connection.execute("SELECT " + seq.name + ".nextval FROM DUAL").scalar()
dialect = OracleDialect
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index d3726fc1f..b192c4778 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -4,12 +4,13 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import datetime, string, types, re, random, warnings
+import re, random, warnings, operator
-from sqlalchemy import util, sql, schema, ansisql, exceptions
+from sqlalchemy import sql, schema, ansisql, exceptions
from sqlalchemy.engine import base, default
import sqlalchemy.types as sqltypes
from sqlalchemy.databases import information_schema as ischema
+from decimal import Decimal
try:
import mx.DateTime.DateTime as mxDateTime
@@ -28,6 +29,15 @@ class PGNumeric(sqltypes.Numeric):
else:
return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
+ def convert_bind_param(self, value, dialect):
+ return value
+
+ def convert_result_value(self, value, dialect):
+ if not self.asdecimal and isinstance(value, Decimal):
+ return float(value)
+ else:
+ return value
+
class PGFloat(sqltypes.Float):
def get_col_spec(self):
if not self.precision:
@@ -35,6 +45,7 @@ class PGFloat(sqltypes.Float):
else:
return "FLOAT(%(precision)s)" % {'precision': self.precision}
+
class PGInteger(sqltypes.Integer):
def get_col_spec(self):
return "INTEGER"
@@ -47,74 +58,15 @@ class PGBigInteger(PGInteger):
def get_col_spec(self):
return "BIGINT"
-class PG2DateTime(sqltypes.DateTime):
+class PGDateTime(sqltypes.DateTime):
def get_col_spec(self):
return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
-class PG1DateTime(sqltypes.DateTime):
- def convert_bind_param(self, value, dialect):
- if value is not None:
- if isinstance(value, datetime.datetime):
- seconds = float(str(value.second) + "."
- + str(value.microsecond))
- mx_datetime = mxDateTime(value.year, value.month, value.day,
- value.hour, value.minute,
- seconds)
- return dialect.dbapi.TimestampFromMx(mx_datetime)
- return dialect.dbapi.TimestampFromMx(value)
- else:
- return None
-
- def convert_result_value(self, value, dialect):
- if value is None:
- return None
- second_parts = str(value.second).split(".")
- seconds = int(second_parts[0])
- microseconds = int(float(second_parts[1]))
- return datetime.datetime(value.year, value.month, value.day,
- value.hour, value.minute, seconds,
- microseconds)
-
- def get_col_spec(self):
- return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
-
-class PG2Date(sqltypes.Date):
- def get_col_spec(self):
- return "DATE"
-
-class PG1Date(sqltypes.Date):
- def convert_bind_param(self, value, dialect):
- # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
- # this one doesnt seem to work with the "emulation" mode
- if value is not None:
- return dialect.dbapi.DateFromMx(value)
- else:
- return None
-
- def convert_result_value(self, value, dialect):
- # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
- return value
-
+class PGDate(sqltypes.Date):
def get_col_spec(self):
return "DATE"
-class PG2Time(sqltypes.Time):
- def get_col_spec(self):
- return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
-
-class PG1Time(sqltypes.Time):
- def convert_bind_param(self, value, dialect):
- # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
- # this one doesnt seem to work with the "emulation" mode
- if value is not None:
- return psycopg.TimeFromMx(value)
- else:
- return None
-
- def convert_result_value(self, value, dialect):
- # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
- return value
-
+class PGTime(sqltypes.Time):
def get_col_spec(self):
return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
@@ -142,28 +94,55 @@ class PGBoolean(sqltypes.Boolean):
def get_col_spec(self):
return "BOOLEAN"
-pg2_colspecs = {
+class PGArray(sqltypes.TypeEngine, sqltypes.Concatenable):
+ def __init__(self, item_type):
+ if isinstance(item_type, type):
+ item_type = item_type()
+ self.item_type = item_type
+
+ def dialect_impl(self, dialect):
+ impl = self.__class__.__new__(self.__class__)
+ impl.__dict__.update(self.__dict__)
+ impl.item_type = self.item_type.dialect_impl(dialect)
+ return impl
+ def convert_bind_param(self, value, dialect):
+ if value is None:
+ return value
+ def convert_item(item):
+ if isinstance(item, (list,tuple)):
+ return [convert_item(child) for child in item]
+ else:
+ return self.item_type.convert_bind_param(item, dialect)
+ return [convert_item(item) for item in value]
+ def convert_result_value(self, value, dialect):
+ if value is None:
+ return value
+ def convert_item(item):
+ if isinstance(item, list):
+ return [convert_item(child) for child in item]
+ else:
+ return self.item_type.convert_result_value(item, dialect)
+ # Could special-case when item_type.convert_result_value is the default identity function
+ return [convert_item(item) for item in value]
+ def get_col_spec(self):
+ return self.item_type.get_col_spec() + '[]'
+
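
Editor's note: PGArray delegates bind and result conversion to its wrapped item type, recursing through nested lists so multi-dimensional arrays round-trip as plain Python lists. A hedged usage sketch:

    from sqlalchemy import MetaData, Table, Column, Integer
    from sqlalchemy.databases import postgres

    meta = MetaData()
    t = Table('t', meta,
              Column('id', Integer, primary_key=True),
              Column('scores', postgres.PGArray(Integer)))
    # DDL appends '[]' to the item type's column spec:  scores INTEGER[]
    # values bind as (possibly nested) lists, e.g. [[1, 2], [3, 4]]
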
+colspecs = {
sqltypes.Integer : PGInteger,
sqltypes.Smallinteger : PGSmallInteger,
sqltypes.Numeric : PGNumeric,
sqltypes.Float : PGFloat,
- sqltypes.DateTime : PG2DateTime,
- sqltypes.Date : PG2Date,
- sqltypes.Time : PG2Time,
+ sqltypes.DateTime : PGDateTime,
+ sqltypes.Date : PGDate,
+ sqltypes.Time : PGTime,
sqltypes.String : PGString,
sqltypes.Binary : PGBinary,
sqltypes.Boolean : PGBoolean,
sqltypes.TEXT : PGText,
sqltypes.CHAR: PGChar,
}
-pg1_colspecs = pg2_colspecs.copy()
-pg1_colspecs.update({
- sqltypes.DateTime : PG1DateTime,
- sqltypes.Date : PG1Date,
- sqltypes.Time : PG1Time
- })
-
-pg2_ischema_names = {
+
+ischema_names = {
'integer' : PGInteger,
'bigint' : PGBigInteger,
'smallint' : PGSmallInteger,
@@ -175,24 +154,17 @@ pg2_ischema_names = {
'real' : PGFloat,
'inet': PGInet,
'double precision' : PGFloat,
- 'timestamp' : PG2DateTime,
- 'timestamp with time zone' : PG2DateTime,
- 'timestamp without time zone' : PG2DateTime,
- 'time with time zone' : PG2Time,
- 'time without time zone' : PG2Time,
- 'date' : PG2Date,
- 'time': PG2Time,
+ 'timestamp' : PGDateTime,
+ 'timestamp with time zone' : PGDateTime,
+ 'timestamp without time zone' : PGDateTime,
+ 'time with time zone' : PGTime,
+ 'time without time zone' : PGTime,
+ 'date' : PGDate,
+ 'time': PGTime,
'bytea' : PGBinary,
'boolean' : PGBoolean,
'interval':PGInterval,
}
-pg1_ischema_names = pg2_ischema_names.copy()
-pg1_ischema_names.update({
- 'timestamp with time zone' : PG1DateTime,
- 'timestamp without time zone' : PG1DateTime,
- 'date' : PG1Date,
- 'time' : PG1Time
- })
def descriptor():
return {'name':'postgres',
@@ -206,11 +178,11 @@ def descriptor():
class PGExecutionContext(default.DefaultExecutionContext):
- def is_select(self):
- return re.match(r'SELECT', self.statement.lstrip(), re.I) and not re.search(r'FOR UPDATE\s*$', self.statement, re.I)
-
+ def _is_server_side(self):
+ return self.dialect.server_side_cursors and self.is_select() and not re.search(r'FOR UPDATE(?: NOWAIT)?\s*$', self.statement, re.I)
+
def create_cursor(self):
- if self.dialect.server_side_cursors and self.is_select():
+ if self._is_server_side():
# use server-side cursors:
# http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
ident = "c" + hex(random.randint(0, 65535))[2:]
@@ -219,7 +191,7 @@ class PGExecutionContext(default.DefaultExecutionContext):
return self.connection.connection.cursor()
def get_result_proxy(self):
- if self.dialect.server_side_cursors and self.is_select():
+ if self._is_server_side():
return base.BufferedRowResultProxy(self)
else:
return base.ResultProxy(self)
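
Editor's note: _is_server_side() now gates both cursor creation and the buffered result proxy, and skips FOR UPDATE [NOWAIT] statements, which must run on the default cursor. A hedged sketch of enabling the feature (DSN hypothetical):

    from sqlalchemy import create_engine

    e = create_engine('postgres://user:pw@localhost/db',
                      server_side_cursors=True)
    # SELECTs run on named psycopg2 cursors; rows are fetched
    # incrementally through BufferedRowResultProxy
    for row in e.execute("SELECT * FROM big_table"):
        pass
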
@@ -242,31 +214,18 @@ class PGDialect(ansisql.ANSIDialect):
ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs)
self.use_oids = use_oids
self.server_side_cursors = server_side_cursors
- if self.dbapi is None or not hasattr(self.dbapi, '__version__') or self.dbapi.__version__.startswith('2'):
- self.version = 2
- else:
- self.version = 1
self.use_information_schema = use_information_schema
self.paramstyle = 'pyformat'
def dbapi(cls):
- try:
- import psycopg2 as psycopg
- except ImportError, e:
- try:
- import psycopg
- except ImportError, e2:
- raise e
+ import psycopg2 as psycopg
return psycopg
dbapi = classmethod(dbapi)
def create_connect_args(self, url):
opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
if opts.has_key('port'):
- if self.version == 2:
- opts['port'] = int(opts['port'])
- else:
- opts['port'] = str(opts['port'])
+ opts['port'] = int(opts['port'])
opts.update(url.query)
return ([], opts)
@@ -278,10 +237,7 @@ class PGDialect(ansisql.ANSIDialect):
return 63
def type_descriptor(self, typeobj):
- if self.version == 2:
- return sqltypes.adapt_type(typeobj, pg2_colspecs)
- else:
- return sqltypes.adapt_type(typeobj, pg1_colspecs)
+ return sqltypes.adapt_type(typeobj, colspecs)
def compiler(self, statement, bindparams, **kwargs):
return PGCompiler(self, statement, bindparams, **kwargs)
@@ -292,8 +248,36 @@ class PGDialect(ansisql.ANSIDialect):
def schemadropper(self, *args, **kwargs):
return PGSchemaDropper(self, *args, **kwargs)
- def defaultrunner(self, connection, **kwargs):
- return PGDefaultRunner(connection, **kwargs)
+ def do_begin_twophase(self, connection, xid):
+ self.do_begin(connection.connection)
+
+ def do_prepare_twophase(self, connection, xid):
+ connection.execute(sql.text("PREPARE TRANSACTION %(tid)s", bindparams=[sql.bindparam('tid', xid)]))
+
+ def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+ if is_prepared:
+ if recover:
+ # FIXME: ugly hack to get out of transaction context when committing recoverable transactions.
+ # Must find a way to keep the dbapi from opening a transaction.
+ connection.execute(sql.text("ROLLBACK"))
+ connection.execute(sql.text("ROLLBACK PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)]))
+ else:
+ self.do_rollback(connection.connection)
+
+ def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+ if is_prepared:
+ if recover:
+ connection.execute(sql.text("ROLLBACK"))
+ connection.execute(sql.text("COMMIT PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)]))
+ else:
+ self.do_commit(connection.connection)
+
+ def do_recover_twophase(self, connection):
+ resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts"))
+ return [row[0] for row in resultset]
+
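
Editor's note: these hooks implement PostgreSQL's PREPARE TRANSACTION / COMMIT PREPARED / ROLLBACK PREPARED statements, and are driven by the new Connection.begin_twophase() API added later in this changeset. A hedged sketch (engine and table hypothetical):

    conn = engine.connect()
    xa = conn.begin_twophase()        # generates an xid if none is given
    conn.execute(t.insert(), {'id': 1})
    xa.prepare()                      # emits PREPARE TRANSACTION '<xid>'
    xa.commit()                       # emits COMMIT PREPARED '<xid>'
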
+ def defaultrunner(self, context, **kwargs):
+ return PGDefaultRunner(context, **kwargs)
def preparer(self):
return PGIdentifierPreparer(self)
@@ -351,14 +335,9 @@ class PGDialect(ansisql.ANSIDialect):
else:
return False
- def reflecttable(self, connection, table):
- if self.version == 2:
- ischema_names = pg2_ischema_names
- else:
- ischema_names = pg1_ischema_names
-
+ def reflecttable(self, connection, table, include_columns):
if self.use_information_schema:
- ischema.reflecttable(connection, table, ischema_names)
+ ischema.reflecttable(connection, table, include_columns, ischema_names)
else:
preparer = self.identifier_preparer
if table.schema is not None:
@@ -387,7 +366,7 @@ class PGDialect(ansisql.ANSIDialect):
ORDER BY a.attnum
""" % schema_where_clause
- s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type=sqltypes.Unicode), sql.bindparam('schema', type=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode})
+ s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode})
c = connection.execute(s, table_name=table.name,
schema=table.schema)
rows = c.fetchall()
@@ -398,9 +377,13 @@ class PGDialect(ansisql.ANSIDialect):
domains = self._load_domains(connection)
for name, format_type, default, notnull, attnum, table_oid in rows:
+ if include_columns and name not in include_columns:
+ continue
+
## strip (30) from character varying(30)
- attype = re.search('([^\(]+)', format_type).group(1)
+ attype = re.search('([^\([]+)', format_type).group(1)
nullable = not notnull
+ is_array = format_type.endswith('[]')
try:
charlen = re.search('\(([\d,]+)\)', format_type).group(1)
@@ -453,6 +436,8 @@ class PGDialect(ansisql.ANSIDialect):
if coltype:
coltype = coltype(*args, **kwargs)
+ if is_array:
+ coltype = PGArray(coltype)
else:
warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (attype, name)))
coltype = sqltypes.NULLTYPE
@@ -517,7 +502,6 @@ class PGDialect(ansisql.ANSIDialect):
table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname))
def _load_domains(self, connection):
-
## Load data types for domains:
SQL_DOMAINS = """
SELECT t.typname as "name",
@@ -554,49 +538,46 @@ class PGDialect(ansisql.ANSIDialect):
-
class PGCompiler(ansisql.ANSICompiler):
- def visit_insert_column(self, column, parameters):
- # all column primary key inserts must be explicitly present
- if column.primary_key:
- parameters[column.key] = None
+ operators = ansisql.ANSICompiler.operators.copy()
+ operators.update(
+ {
+ operator.mod : '%%'
+ }
+ )
- def visit_insert_sequence(self, column, sequence, parameters):
- """this is the 'sequence' equivalent to ANSICompiler's 'visit_insert_column_default' which ensures
- that the column is present in the generated column list"""
- parameters.setdefault(column.key, None)
+ def uses_sequences_for_inserts(self):
+ return True
def limit_clause(self, select):
text = ""
- if select.limit is not None:
- text += " \n LIMIT " + str(select.limit)
- if select.offset is not None:
- if select.limit is None:
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
text += " \n LIMIT ALL"
- text += " OFFSET " + str(select.offset)
+ text += " OFFSET " + str(select._offset)
return text
- def visit_select_precolumns(self, select):
- if select.distinct:
- if type(select.distinct) == bool:
+ def get_select_precolumns(self, select):
+ if select._distinct:
+ if type(select._distinct) == bool:
return "DISTINCT "
- if type(select.distinct) == list:
+ if type(select._distinct) == list:
dist_set = "DISTINCT ON ("
- for col in select.distinct:
+ for col in select._distinct:
dist_set += self.strings[col] + ", "
dist_set = dist_set[:-2] + ") "
return dist_set
- return "DISTINCT ON (" + str(select.distinct) + ") "
+ return "DISTINCT ON (" + str(select._distinct) + ") "
else:
return ""
- def binary_operator_string(self, binary):
- if isinstance(binary.type, sqltypes.String) and binary.operator == '+':
- return '||'
- elif binary.operator == '%':
- return '%%'
+ def for_update_clause(self, select):
+ if select.for_update == 'nowait':
+ return " FOR UPDATE NOWAIT"
else:
- return ansisql.ANSICompiler.binary_operator_string(self, binary)
+ return super(PGCompiler, self).for_update_clause(select)
class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
@@ -617,13 +598,13 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
return colspec
def visit_sequence(self, sequence):
- if not sequence.optional and (not self.dialect.has_sequence(self.connection, sequence.name)):
+ if not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)):
self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
class PGSchemaDropper(ansisql.ANSISchemaDropper):
def visit_sequence(self, sequence):
- if not sequence.optional and (self.dialect.has_sequence(self.connection, sequence.name)):
+ if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)):
self.append("DROP SEQUENCE %s" % sequence.name)
self.execute()
@@ -632,7 +613,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
if column.primary_key:
# passive defaults on primary keys have to be overridden
if isinstance(column.default, schema.PassiveDefault):
- return self.connection.execute_text("select %s" % column.default.arg).scalar()
+ return self.connection.execute("select %s" % column.default.arg).scalar()
elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
sch = column.table.schema
# TODO: this has to build into the Sequence object so we can get the quoting
@@ -641,7 +622,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
else:
exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
- return self.connection.execute_text(exc).scalar()
+ return self.connection.execute(exc).scalar()
return super(ansisql.ANSIDefaultRunner, self).get_column_default(column)
diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py
index 816b1b76a..725ea23e2 100644
--- a/lib/sqlalchemy/databases/sqlite.py
+++ b/lib/sqlalchemy/databases/sqlite.py
@@ -5,9 +5,9 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import sys, StringIO, string, types, re
+import re
-from sqlalchemy import sql, engine, schema, ansisql, exceptions, pool, PassiveDefault
+from sqlalchemy import schema, ansisql, exceptions, pool, PassiveDefault
import sqlalchemy.engine.default as default
import sqlalchemy.types as sqltypes
import datetime,time, warnings
@@ -126,6 +126,7 @@ colspecs = {
pragma_names = {
'INTEGER' : SLInteger,
+ 'INT' : SLInteger,
'SMALLINT' : SLSmallInteger,
'VARCHAR' : SLString,
'CHAR' : SLChar,
@@ -150,8 +151,9 @@ class SQLiteExecutionContext(default.DefaultExecutionContext):
if self.compiled.isinsert:
if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
-
- super(SQLiteExecutionContext, self).post_exec()
+
+ def is_select(self):
+ return re.match(r'SELECT|PRAGMA', self.statement.lstrip(), re.I) is not None
class SQLiteDialect(ansisql.ANSIDialect):
@@ -233,7 +235,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
return (row is not None)
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
c = connection.execute("PRAGMA table_info(%s)" % self.preparer().format_table(table), {})
found_table = False
while True:
@@ -244,6 +246,8 @@ class SQLiteDialect(ansisql.ANSIDialect):
found_table = True
(name, type, nullable, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4] is not None, row[5])
name = re.sub(r'^\"|\"$', '', name)
+ if include_columns and name not in include_columns:
+ continue
match = re.match(r'(\w+)(\(.*?\))?', type)
if match:
coltype = match.group(1)
@@ -253,7 +257,12 @@ class SQLiteDialect(ansisql.ANSIDialect):
args = ''
#print "coltype: " + repr(coltype) + " args: " + repr(args)
- coltype = pragma_names.get(coltype, SLString)
+ try:
+ coltype = pragma_names[coltype]
+ except KeyError:
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, name)))
+ coltype = sqltypes.NULLTYPE
+
if args is not None:
args = re.findall(r'(\d+)', args)
#print "args! " +repr(args)
@@ -318,21 +327,21 @@ class SQLiteDialect(ansisql.ANSIDialect):
class SQLiteCompiler(ansisql.ANSICompiler):
def visit_cast(self, cast):
if self.dialect.supports_cast:
- super(SQLiteCompiler, self).visit_cast(cast)
+ return super(SQLiteCompiler, self).visit_cast(cast)
else:
if len(self.select_stack):
# not sure if we want to set the typemap here...
self.typemap.setdefault("CAST", cast.type)
- self.strings[cast] = self.strings[cast.clause]
+ return self.process(cast.clause)
def limit_clause(self, select):
text = ""
- if select.limit is not None:
- text += " \n LIMIT " + str(select.limit)
- if select.offset is not None:
- if select.limit is None:
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
text += " \n LIMIT -1"
- text += " OFFSET " + str(select.offset)
+ text += " OFFSET " + str(select._offset)
else:
text += " OFFSET 0"
return text
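
Editor's note: SQLite, like MySQL, cannot express OFFSET without LIMIT; here the placeholder is LIMIT -1, which SQLite treats as unbounded. A hedged sketch (column hypothetical):

    from sqlalchemy import select, literal_column
    from sqlalchemy.databases import sqlite

    s = select([literal_column("x")]).offset(5)
    # renders roughly as:
    #   SELECT x
    #    LIMIT -1 OFFSET 5
    print s.compile(dialect=sqlite.dialect())
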
@@ -341,12 +350,6 @@ class SQLiteCompiler(ansisql.ANSICompiler):
# sqlite has no "FOR UPDATE" AFAICT
return ''
- def binary_operator_string(self, binary):
- if isinstance(binary.type, sqltypes.String) and binary.operator == '+':
- return '||'
- else:
- return ansisql.ANSICompiler.binary_operator_string(self, binary)
-
class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py
index 5a4865de0..50d03ea91 100644
--- a/lib/sqlalchemy/engine/__init__.py
+++ b/lib/sqlalchemy/engine/__init__.py
@@ -48,7 +48,6 @@ The package is represented among several individual modules, including:
from sqlalchemy import databases
from sqlalchemy.engine.base import *
from sqlalchemy.engine import strategies
-import re
def engine_descriptors():
"""Provide a listing of all the database implementations supported.
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index d0ca36515..fc4433a47 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -9,33 +9,10 @@ higher-level statement-construction, connection-management,
execution and result contexts."""
from sqlalchemy import exceptions, sql, schema, util, types, logging
-import StringIO, sys, re
+import StringIO, sys, re, random
-class ConnectionProvider(object):
- """Define an interface that returns raw Connection objects (or compatible)."""
-
- def get_connection(self):
- """Return a Connection or compatible object from a DBAPI which also contains a close() method.
-
- It is not defined what context this connection belongs to. It
- may be newly connected, returned from a pool, part of some
- other kind of context such as thread-local, or can be a fixed
- member of this object.
- """
-
- raise NotImplementedError()
-
- def dispose(self):
- """Release all resources corresponding to this ConnectionProvider.
-
- This includes any underlying connection pools.
- """
-
- raise NotImplementedError()
-
-
-class Dialect(sql.AbstractDialect):
+class Dialect(object):
"""Define the behavior of a specific database/DBAPI.
Any aspect of metadata definition, SQL query generation, execution,
@@ -70,11 +47,14 @@ class Dialect(sql.AbstractDialect):
raise NotImplementedError()
- def convert_compiled_params(self, parameters):
- """Build DBAPI execute arguments from a [sqlalchemy.sql#ClauseParameters] instance.
-
- Returns an array or dictionary suitable to pass directly to this ``Dialect`` instance's DBAPI's
- execute method.
+ def dbapi_type_map(self):
+ """return a mapping of DBAPI type objects present in this Dialect's DBAPI
+ mapped to TypeEngine implementations used by the dialect.
+
+ This is used to apply types to result sets based on the DBAPI types
+ present in cursor.description; it only takes effect for result sets against
+ textual statements where no explicit typemap was present. Constructed SQL statements
+ always have type information explicitly embedded.
"""
raise NotImplementedError()
@@ -149,11 +129,11 @@ class Dialect(sql.AbstractDialect):
raise NotImplementedError()
- def defaultrunner(self, connection, **kwargs):
+ def defaultrunner(self, execution_context):
"""Return a [sqlalchemy.schema#SchemaVisitor] instance that can execute defaults.
- connection
- a [sqlalchemy.engine#Connection] to use for statement execution
+ execution_context
+ a [sqlalchemy.engine#ExecutionContext] to use for statement execution
"""
@@ -168,11 +148,12 @@ class Dialect(sql.AbstractDialect):
raise NotImplementedError()
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns=None):
"""Load table description from the database.
Given a [sqlalchemy.engine#Connection] and a [sqlalchemy.schema#Table] object, reflect its
- columns and properties from the database.
+ columns and properties from the database. If include_columns (a list or set) is specified, limit the autoload
+ to the given column names.
"""
raise NotImplementedError()
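
Editor's note: include_columns threads from here through Connection.reflecttable() into each dialect's reflection loop (see the per-dialect changes above), and surfaces on Table as an autoload filter. A hedged sketch (engine and table names hypothetical):

    from sqlalchemy import MetaData, Table

    meta = MetaData(engine)   # engine: a hypothetical bound engine
    # reflect only the named columns of a wide table
    users = Table('users', meta, autoload=True,
                  include_columns=['id', 'name'])
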
@@ -222,6 +203,46 @@ class Dialect(sql.AbstractDialect):
raise NotImplementedError()
+ def do_savepoint(self, connection, name):
+ """Create a savepoint with the given name on a SQLAlchemy connection."""
+
+ raise NotImplementedError()
+
+ def do_rollback_to_savepoint(self, connection, name):
+ """Rollback a SQL Alchemy connection to the named savepoint."""
+
+ raise NotImplementedError()
+
+ def do_release_savepoint(self, connection, name):
+ """Release the named savepoint on a SQL Alchemy connection."""
+
+ raise NotImplementedError()
+
+ def do_begin_twophase(self, connection, xid):
+ """Begin a two phase transaction on the given connection."""
+
+ raise NotImplementedError()
+
+ def do_prepare_twophase(self, connection, xid):
+ """Prepare a two phase transaction on the given connection."""
+
+ raise NotImplementedError()
+
+ def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+ """Rollback a two phase transaction on the given connection."""
+
+ raise NotImplementedError()
+
+ def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+ """Commit a two phase transaction on the given connection."""
+
+ raise NotImplementedError()
+
+ def do_recover_twophase(self, connection):
+ """Recover list of uncommited prepared two phase transaction identifiers on the given connection."""
+
+ raise NotImplementedError()
+
def do_executemany(self, cursor, statement, parameters):
"""Provide an implementation of *cursor.executemany(statement, parameters)*."""
@@ -266,19 +287,18 @@ class ExecutionContext(object):
compiled
if passed to constructor, sql.Compiled object being executed
- compiled_parameters
- if passed to constructor, sql.ClauseParameters object
-
statement
string version of the statement to be executed. Is either
passed to the constructor, or must be created from the
sql.Compiled object by the time pre_exec() has completed.
parameters
- "raw" parameters suitable for direct execution by the
- dialect. Either passed to the constructor, or must be
- created from the sql.ClauseParameters object by the time
- pre_exec() has completed.
+ bind parameters passed to the execute() method. for
+ compiled statements, this is a dictionary or list
+ of dictionaries. for textual statements, it should
+ be in a format suitable for the dialect's paramstyle
+ (i.e. dict or list of dicts for non positional,
+ list or list of lists/tuples for positional).
The Dialect should provide an ExecutionContext via the
@@ -288,24 +308,28 @@ class ExecutionContext(object):
"""
def create_cursor(self):
- """Return a new cursor generated this ExecutionContext's connection."""
+ """Return a new cursor generated from this ExecutionContext's connection.
+
+ Some dialects may wish to change the behavior of connection.cursor(),
+ such as postgres, which may return a PG "server side" cursor.
+ """
raise NotImplementedError()
- def pre_exec(self):
+ def pre_execution(self):
"""Called before an execution of a compiled statement.
- If compiled and compiled_parameters were passed to this
+ If a compiled statement was passed to this
ExecutionContext, the `statement` and `parameters` datamembers
must be initialized after this statement is complete.
"""
raise NotImplementedError()
- def post_exec(self):
+ def post_execution(self):
"""Called after the execution of a compiled statement.
- If compiled was passed to this ExecutionContext,
+ If a compiled statement was passed to this ExecutionContext,
the `last_insert_ids`, `last_inserted_params`, etc.
datamembers should be available after this method
completes.
@@ -313,8 +337,11 @@ class ExecutionContext(object):
raise NotImplementedError()
- def get_result_proxy(self):
- """return a ResultProxy corresponding to this ExecutionContext."""
+ def result(self):
+ """return a result object corresponding to this ExecutionContext.
+
+ Returns a ResultProxy."""
+
raise NotImplementedError()
def get_rowcount(self):
@@ -361,8 +388,88 @@ class ExecutionContext(object):
raise NotImplementedError()
+class Compiled(object):
+ """Represent a compiled SQL expression.
+
+ The ``__str__`` method of the ``Compiled`` object should produce
+ the actual text of the statement. ``Compiled`` objects are
+ specific to their underlying database dialect, and also may
+ or may not be specific to the columns referenced within a
+ particular set of bind parameters. In no case should the
+ ``Compiled`` object be dependent on the actual values of those
+ bind parameters, even though it may reference those values as
+ defaults.
+ """
+
+ def __init__(self, dialect, statement, parameters, bind=None):
+ """Construct a new ``Compiled`` object.
+
+ statement
+ ``ClauseElement`` to be compiled.
+
+ parameters
+ Optional dictionary indicating a set of bind parameters
+ specified with this ``Compiled`` object. These parameters
+ are the *default* values corresponding to the
+ ``ClauseElement``'s ``_BindParamClauses`` when the
+ ``Compiled`` is executed. In the case of an ``INSERT`` or
+ ``UPDATE`` statement, these parameters will also result in
+ the creation of new ``_BindParamClause`` objects for each
+ key and will also affect the generated column list in an
+ ``INSERT`` statement and the ``SET`` clauses of an
+ ``UPDATE`` statement. The keys of the parameter dictionary
+ can either be the string names of columns or
+ ``_ColumnClause`` objects.
+
+ bind
+ Optional Engine or Connection to compile this statement against.
+ """
+ self.dialect = dialect
+ self.statement = statement
+ self.parameters = parameters
+ self.bind = bind
+ self.can_execute = statement.supports_execution()
+
+ def compile(self):
+ """Produce the internal string representation of this element."""
+
+ raise NotImplementedError()
+
+ def __str__(self):
+ """Return the string text of the generated SQL statement."""
+
+ raise NotImplementedError()
+
+ def get_params(self, **params):
+ """Deprecated. use construct_params(). (supports unicode names)
+ """
+
+ return self.construct_params(params)
+
+ def construct_params(self, params):
+ """Return the bind params for this compiled object.
+
+ params is a dict of string/object pairs whose
+ values will override bind values compiled into
+ the statement.
+ """
+ raise NotImplementedError()
+
+ def execute(self, *multiparams, **params):
+ """Execute this compiled object."""
+
+ e = self.bind
+ if e is None:
+ raise exceptions.InvalidRequestError("This Compiled object is not bound to any Engine or Connection.")
+ return e._execute_compiled(self, multiparams, params)
+
+ def scalar(self, *multiparams, **params):
+ """Execute this compiled object and return the result's scalar value."""
+
+ return self.execute(*multiparams, **params).scalar()
+
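
Editor's note: Compiled objects are now directly executable against their bind, so a statement can be compiled once and executed repeatedly. A hedged sketch (table and engine hypothetical):

    ins = t.insert().compile(bind=engine)
    ins.execute(id=1, name='ed')      # routed through bind._execute_compiled()
    ins.execute(id=2, name='wendy')
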
-class Connectable(sql.Executor):
+class Connectable(object):
"""Interface for an object that can provide an Engine and a Connection object which correponds to that Engine."""
def contextual_connect(self):
@@ -401,6 +508,7 @@ class Connection(Connectable):
self.__connection = connection or engine.raw_connection()
self.__transaction = None
self.__close_with_result = close_with_result
+ self.__savepoint_seq = 0
def _get_connection(self):
try:
@@ -408,13 +516,18 @@ class Connection(Connectable):
except AttributeError:
raise exceptions.InvalidRequestError("This Connection is closed")
+ def _branch(self):
+ """return a new Connection which references this Connection's
+ engine and connection; but does not have close_with_result enabled."""
+
+ return Connection(self.__engine, self.__connection)
+
engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated.")
dialect = property(lambda s:s.__engine.dialect, doc="Dialect used by this Connection.")
connection = property(_get_connection, doc="The underlying DBAPI connection managed by this Connection.")
should_close_with_result = property(lambda s:s.__close_with_result, doc="Indicates if this Connection should be closed when a corresponding ResultProxy is closed; this is essentially an auto-release mode.")
-
- def _create_transaction(self, parent):
- return Transaction(self, parent)
+ properties = property(lambda s: s._get_connection().properties,
+ doc="A set of per-DBAPI connection properties.")
def connect(self):
"""connect() is implemented to return self so that an incoming Engine or Connection object can be treated similarly."""
@@ -448,12 +561,34 @@ class Connection(Connectable):
self.__connection.detach()
- def begin(self):
+ def begin(self, nested=False):
if self.__transaction is None:
- self.__transaction = self._create_transaction(None)
- return self.__transaction
+ self.__transaction = RootTransaction(self)
+ elif nested:
+ self.__transaction = NestedTransaction(self, self.__transaction)
else:
- return self._create_transaction(self.__transaction)
+ return Transaction(self, self.__transaction)
+ return self.__transaction
+
+ def begin_nested(self):
+ return self.begin(nested=True)
+
+ def begin_twophase(self, xid=None):
+ if self.__transaction is not None:
+ raise exceptions.InvalidRequestError("Cannot start a two phase transaction when a transaction is already started.")
+ if xid is None:
+ xid = "_sa_%032x" % random.randint(0,2**128)
+ self.__transaction = TwoPhaseTransaction(self, xid)
+ return self.__transaction
+
+ def recover_twophase(self):
+ return self.__engine.dialect.do_recover_twophase(self)
+
+ def rollback_prepared(self, xid, recover=False):
+ self.__engine.dialect.do_rollback_twophase(self, xid, recover=recover)
+
+ def commit_prepared(self, xid, recover=False):
+ self.__engine.dialect.do_commit_twophase(self, xid, recover=recover)
def in_transaction(self):
return self.__transaction is not None
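
Editor's note: begin(nested=True), or the begin_nested() shorthand, opens a SAVEPOINT-backed NestedTransaction on top of the current transaction (the Transaction subclasses appear further down). A hedged sketch (engine and table hypothetical):

    conn = engine.connect()
    trans = conn.begin()
    conn.execute(t.insert(), {'id': 1})
    nested = conn.begin_nested()   # emits SAVEPOINT
    conn.execute(t.insert(), {'id': 2})
    nested.rollback()              # ROLLBACK TO SAVEPOINT; row 2 discarded
    trans.commit()                 # row 1 committed
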
@@ -485,6 +620,45 @@ class Connection(Connectable):
raise exceptions.SQLError(None, None, e)
self.__transaction = None
+ def _savepoint_impl(self, name=None):
+ if name is None:
+ self.__savepoint_seq += 1
+ name = '__sa_savepoint_%s' % self.__savepoint_seq
+ if self.__connection.is_valid:
+ self.__engine.dialect.do_savepoint(self, name)
+ return name
+
+ def _rollback_to_savepoint_impl(self, name, context):
+ if self.__connection.is_valid:
+ self.__engine.dialect.do_rollback_to_savepoint(self, name)
+ self.__transaction = context
+
+ def _release_savepoint_impl(self, name, context):
+ if self.__connection.is_valid:
+ self.__engine.dialect.do_release_savepoint(self, name)
+ self.__transaction = context
+
+ def _begin_twophase_impl(self, xid):
+ if self.__connection.is_valid:
+ self.__engine.dialect.do_begin_twophase(self, xid)
+
+ def _prepare_twophase_impl(self, xid):
+ if self.__connection.is_valid:
+ assert isinstance(self.__transaction, TwoPhaseTransaction)
+ self.__engine.dialect.do_prepare_twophase(self, xid)
+
+ def _rollback_twophase_impl(self, xid, is_prepared):
+ if self.__connection.is_valid:
+ assert isinstance(self.__transaction, TwoPhaseTransaction)
+ self.__engine.dialect.do_rollback_twophase(self, xid, is_prepared)
+ self.__transaction = None
+
+ def _commit_twophase_impl(self, xid, is_prepared):
+ if self.__connection.is_valid:
+ assert isinstance(self.__transaction, TwoPhaseTransaction)
+ self.__engine.dialect.do_commit_twophase(self, xid, is_prepared)
+ self.__transaction = None
+
def _autocommit(self, statement):
"""When no Transaction is present, this is called after executions to provide "autocommit" behavior."""
# TODO: have the dialect determine if autocommit can be set on the connection directly without this
@@ -495,7 +669,7 @@ class Connection(Connectable):
def _autorollback(self):
if not self.in_transaction():
self._rollback_impl()
-
+
def close(self):
try:
c = self.__connection
@@ -514,74 +688,66 @@ class Connection(Connectable):
def execute(self, object, *multiparams, **params):
for c in type(object).__mro__:
if c in Connection.executors:
- return Connection.executors[c](self, object, *multiparams, **params)
+ return Connection.executors[c](self, object, multiparams, params)
else:
raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object)))
- def execute_default(self, default, **kwargs):
- return default.accept_visitor(self.__engine.dialect.defaultrunner(self))
+ def _execute_default(self, default, multiparams=None, params=None):
+ return self.__engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default)
+
+ def _execute_text(self, statement, multiparams, params):
+ parameters = self.__distill_params(multiparams, params)
+ context = self.__create_execution_context(statement=statement, parameters=parameters)
+ self.__execute_raw(context)
+ return context.result()
- def execute_text(self, statement, *multiparams, **params):
- if len(multiparams) == 0:
+ def __distill_params(self, multiparams, params):
+ if multiparams is None or len(multiparams) == 0:
parameters = params or None
- elif len(multiparams) == 1 and (isinstance(multiparams[0], list) or isinstance(multiparams[0], tuple) or isinstance(multiparams[0], dict)):
+ elif len(multiparams) == 1 and isinstance(multiparams[0], (list, tuple, dict)):
parameters = multiparams[0]
else:
parameters = list(multiparams)
- context = self._create_execution_context(statement=statement, parameters=parameters)
- self._execute_raw(context)
- return context.get_result_proxy()
-
- def _params_to_listofdicts(self, *multiparams, **params):
- if len(multiparams) == 0:
- return [params]
- elif len(multiparams) == 1:
- if multiparams[0] == None:
- return [{}]
- elif isinstance (multiparams[0], list) or isinstance (multiparams[0], tuple):
- return multiparams[0]
- else:
- return [multiparams[0]]
- else:
- return multiparams
-
- def execute_function(self, func, *multiparams, **params):
- return self.execute_clauseelement(func.select(), *multiparams, **params)
+ return parameters
+
+ def _execute_function(self, func, multiparams, params):
+ return self._execute_clauseelement(func.select(), multiparams, params)
- def execute_clauseelement(self, elem, *multiparams, **params):
- executemany = len(multiparams) > 0
+ def _execute_clauseelement(self, elem, multiparams=None, params=None):
+ executemany = multiparams is not None and len(multiparams) > 0
if executemany:
param = multiparams[0]
else:
param = params
- return self.execute_compiled(elem.compile(dialect=self.dialect, parameters=param), *multiparams, **params)
+ return self._execute_compiled(elem.compile(dialect=self.dialect, parameters=param), multiparams, params)
- def execute_compiled(self, compiled, *multiparams, **params):
+ def _execute_compiled(self, compiled, multiparams=None, params=None):
"""Execute a sql.Compiled object."""
if not compiled.can_execute:
raise exceptions.ArgumentError("Not an executable clause: %s" % (str(compiled)))
- parameters = [compiled.construct_params(m) for m in self._params_to_listofdicts(*multiparams, **params)]
- if len(parameters) == 1:
- parameters = parameters[0]
- context = self._create_execution_context(compiled=compiled, compiled_parameters=parameters)
- context.pre_exec()
- self._execute_raw(context)
- context.post_exec()
- return context.get_result_proxy()
-
- def _create_execution_context(self, **kwargs):
+
+ params = self.__distill_params(multiparams, params)
+ context = self.__create_execution_context(compiled=compiled, parameters=params)
+
+ context.pre_execution()
+ self.__execute_raw(context)
+ context.post_execution()
+ return context.result()
+
+ def __create_execution_context(self, **kwargs):
return self.__engine.dialect.create_execution_context(connection=self, **kwargs)
- def _execute_raw(self, context):
- self.__engine.logger.info(context.statement)
- self.__engine.logger.info(repr(context.parameters))
- if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and (isinstance(context.parameters[0], list) or isinstance(context.parameters[0], tuple) or isinstance(context.parameters[0], dict)):
- self._executemany(context)
+ def __execute_raw(self, context):
+ if logging.is_info_enabled(self.__engine.logger):
+ self.__engine.logger.info(context.statement)
+ self.__engine.logger.info(repr(context.parameters))
+ if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and isinstance(context.parameters[0], (list, tuple, dict)):
+ self.__executemany(context)
else:
- self._execute(context)
+ self.__execute(context)
self._autocommit(context.statement)
- def _execute(self, context):
+ def __execute(self, context):
if context.parameters is None:
if context.dialect.positional:
context.parameters = ()
@@ -592,19 +758,19 @@ class Connection(Connectable):
except Exception, e:
if self.dialect.is_disconnect(e):
self.__connection.invalidate(e=e)
- self.engine.connection_provider.dispose()
+ self.engine.dispose()
self._autorollback()
if self.__close_with_result:
self.close()
raise exceptions.SQLError(context.statement, context.parameters, e)
- def _executemany(self, context):
+ def __executemany(self, context):
try:
context.dialect.do_executemany(context.cursor, context.statement, context.parameters, context=context)
except Exception, e:
if self.dialect.is_disconnect(e):
self.__connection.invalidate(e=e)
- self.engine.connection_provider.dispose()
+ self.engine.dispose()
self._autorollback()
if self.__close_with_result:
self.close()
@@ -612,11 +778,11 @@ class Connection(Connectable):
# poor man's multimethod/generic function thingy
executors = {
- sql._Function : execute_function,
- sql.ClauseElement : execute_clauseelement,
- sql.ClauseVisitor : execute_compiled,
- schema.SchemaItem:execute_default,
- str.__mro__[-2] : execute_text
+ sql._Function : _execute_function,
+ sql.ClauseElement : _execute_clauseelement,
+ sql.ClauseVisitor : _execute_compiled,
+ schema.SchemaItem:_execute_default,
+ str.__mro__[-2] : _execute_text
}
def create(self, entity, **kwargs):
@@ -629,10 +795,10 @@ class Connection(Connectable):
return self.__engine.drop(entity, connection=self, **kwargs)
- def reflecttable(self, table, **kwargs):
+ def reflecttable(self, table, include_columns=None):
"""Reflect the columns in the given string table name from the database."""
- return self.__engine.reflecttable(table, connection=self, **kwargs)
+ return self.__engine.reflecttable(table, self, include_columns)
def default_schema_name(self):
return self.__engine.dialect.get_default_schema_name(self)
@@ -647,39 +813,90 @@ class Transaction(object):
"""
def __init__(self, connection, parent):
- self.__connection = connection
- self.__parent = parent or self
- self.__is_active = True
- if self.__parent is self:
- self.__connection._begin_impl()
+ self._connection = connection
+ self._parent = parent or self
+ self._is_active = True
- connection = property(lambda s:s.__connection, doc="The Connection object referenced by this Transaction")
- is_active = property(lambda s:s.__is_active)
+ connection = property(lambda s:s._connection, doc="The Connection object referenced by this Transaction")
+ is_active = property(lambda s:s._is_active)
def rollback(self):
- if not self.__parent.__is_active:
+ if not self._parent._is_active:
return
- if self.__parent is self:
- self.__connection._rollback_impl()
- self.__is_active = False
- else:
- self.__parent.rollback()
+ self._is_active = False
+ self._do_rollback()
+
+ def _do_rollback(self):
+ self._parent.rollback()
def commit(self):
- if not self.__parent.__is_active:
+ if not self._parent._is_active:
raise exceptions.InvalidRequestError("This transaction is inactive")
- if self.__parent is self:
- self.__connection._commit_impl()
- self.__is_active = False
+ self._is_active = False
+ self._do_commit()
+
+ def _do_commit(self):
+ pass
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ if type is None and self._is_active:
+ self.commit()
+ else:
+ self.rollback()
+
+class RootTransaction(Transaction):
+ def __init__(self, connection):
+ super(RootTransaction, self).__init__(connection, None)
+ self._connection._begin_impl()
+
+ def _do_rollback(self):
+ self._connection._rollback_impl()
+
+ def _do_commit(self):
+ self._connection._commit_impl()
+
+class NestedTransaction(Transaction):
+ def __init__(self, connection, parent):
+ super(NestedTransaction, self).__init__(connection, parent)
+ self._savepoint = self._connection._savepoint_impl()
+
+ def _do_rollback(self):
+ self._connection._rollback_to_savepoint_impl(self._savepoint, self._parent)
+
+ def _do_commit(self):
+ self._connection._release_savepoint_impl(self._savepoint, self._parent)
+
+class TwoPhaseTransaction(Transaction):
+ def __init__(self, connection, xid):
+ super(TwoPhaseTransaction, self).__init__(connection, None)
+ self._is_prepared = False
+ self.xid = xid
+ self._connection._begin_twophase_impl(self.xid)
+
+ def prepare(self):
+ if not self._parent._is_active:
+ raise exceptions.InvalidRequestError("This transaction is inactive")
+ self._connection._prepare_twophase_impl(self.xid)
+ self._is_prepared = True
+
+ def _do_rollback(self):
+ self._connection._rollback_twophase_impl(self.xid, self._is_prepared)
+
+ def commit(self):
+ self._connection._commit_twophase_impl(self.xid, self._is_prepared)
class Engine(Connectable):
"""
- Connects a ConnectionProvider, a Dialect and a CompilerFactory together to
+ Connects a Pool, a Dialect and a CompilerFactory together to
provide a default implementation of SchemaEngine.
"""
- def __init__(self, connection_provider, dialect, echo=None):
- self.connection_provider = connection_provider
+ def __init__(self, pool, dialect, url, echo=None):
+ self.pool = pool
+ self.url = url
self._dialect=dialect
self.echo = echo
self.logger = logging.instance_logger(self)
@@ -688,10 +905,13 @@ class Engine(Connectable):
engine = property(lambda s:s)
dialect = property(lambda s:s._dialect, doc="the [sqlalchemy.engine#Dialect] in use by this engine.")
echo = logging.echo_property()
- url = property(lambda s:s.connection_provider.url, doc="The [sqlalchemy.engine.url#URL] object representing this ``Engine`` object's datasource.")
+
+ def __repr__(self):
+ return 'Engine(%s)' % str(self.url)
def dispose(self):
- self.connection_provider.dispose()
+ self.pool.dispose()
+ self.pool = self.pool.recreate()
def create(self, entity, connection=None, **kwargs):
"""Create a table or index within this engine's database connection given a schema.Table object."""
@@ -703,22 +923,22 @@ class Engine(Connectable):
self._run_visitor(self.dialect.schemadropper, entity, connection=connection, **kwargs)
- def execute_default(self, default, **kwargs):
+ def _execute_default(self, default):
connection = self.contextual_connect()
try:
- return connection.execute_default(default, **kwargs)
+ return connection._execute_default(default)
finally:
connection.close()
def _func(self):
- return sql._FunctionGenerator(engine=self)
+ return sql._FunctionGenerator(bind=self)
func = property(_func)
def text(self, text, *args, **kwargs):
"""Return a sql.text() object for performing literal queries."""
- return sql.text(text, engine=self, *args, **kwargs)
+ return sql.text(text, bind=self, *args, **kwargs)
def _run_visitor(self, visitorcallable, element, connection=None, **kwargs):
if connection is None:
@@ -726,7 +946,7 @@ class Engine(Connectable):
else:
conn = connection
try:
- element.accept_visitor(visitorcallable(conn, **kwargs))
+ visitorcallable(conn, **kwargs).traverse(element)
finally:
if connection is None:
conn.close()
@@ -775,12 +995,12 @@ class Engine(Connectable):
def scalar(self, statement, *multiparams, **params):
return self.execute(statement, *multiparams, **params).scalar()
- def execute_compiled(self, compiled, *multiparams, **params):
+ def _execute_compiled(self, compiled, multiparams, params):
connection = self.contextual_connect(close_with_result=True)
- return connection.execute_compiled(compiled, *multiparams, **params)
+ return connection._execute_compiled(compiled, multiparams, params)
def compiler(self, statement, parameters, **kwargs):
- return self.dialect.compiler(statement, parameters, engine=self, **kwargs)
+ return self.dialect.compiler(statement, parameters, bind=self, **kwargs)
def connect(self, **kwargs):
"""Return a newly allocated Connection object."""
@@ -795,7 +1015,7 @@ class Engine(Connectable):
return Connection(self, close_with_result=close_with_result, **kwargs)
- def reflecttable(self, table, connection=None):
+ def reflecttable(self, table, connection=None, include_columns=None):
"""Given a Table object, reflects its columns and properties from the database."""
if connection is None:
@@ -803,7 +1023,7 @@ class Engine(Connectable):
else:
conn = connection
try:
- self.dialect.reflecttable(conn, table)
+ self.dialect.reflecttable(conn, table, include_columns)
finally:
if connection is None:
conn.close()
@@ -814,7 +1034,7 @@ class Engine(Connectable):
def raw_connection(self):
"""Return a DBAPI connection."""
- return self.connection_provider.get_connection()
+ return self.pool.connect()
def log(self, msg):
"""Log a message using this SQLEngine's logger stream."""
@@ -858,28 +1078,42 @@ class ResultProxy(object):
self.closed = False
self.cursor = context.cursor
self.__echo = logging.is_debug_enabled(context.engine.logger)
- self._init_metadata()
-
- rowcount = property(lambda s:s.context.get_rowcount())
- connection = property(lambda s:s.context.connection)
+ if context.is_select():
+ self._init_metadata()
+ self._rowcount = None
+ else:
+ self._rowcount = context.get_rowcount()
+ self.close()
+
+ connection = property(lambda self:self.context.connection)
+ def _get_rowcount(self):
+ if self._rowcount is not None:
+ return self._rowcount
+ else:
+ return self.context.get_rowcount()
+ rowcount = property(_get_rowcount)
lastrowid = property(lambda s:s.cursor.lastrowid)
+ out_parameters = property(lambda s:s.context.out_parameters)
def _init_metadata(self):
if hasattr(self, '_ResultProxy__props'):
return
- self.__key_cache = {}
self.__props = {}
+ self._key_cache = self._create_key_cache()
self.__keys = []
metadata = self.cursor.description
if metadata is not None:
+ typemap = self.dialect.dbapi_type_map()
+
for i, item in enumerate(metadata):
# sqlite may prepend the table name to column names, so strip it off
- colname = item[0].split('.')[-1]
+ colname = self.dialect.decode_result_columnname(item[0].split('.')[-1])
if self.context.typemap is not None:
- type = self.context.typemap.get(colname.lower(), types.NULLTYPE)
+ type = self.context.typemap.get(colname.lower(), typemap.get(item[1], types.NULLTYPE))
else:
- type = types.NULLTYPE
+ type = typemap.get(item[1], types.NULLTYPE)
+
rec = (type, type.dialect_impl(self.dialect), i)
if rec[0] is None:
@@ -889,6 +1123,33 @@ class ResultProxy(object):
self.__keys.append(colname)
self.__props[i] = rec
+ if self.__echo:
+ self.context.engine.logger.debug("Col " + repr(tuple([x[0] for x in metadata])))
+
+ def _create_key_cache(self):
+ # local copies to avoid circular ref against 'self'
+ props = self.__props
+ context = self.context
+ def lookup_key(key):
+ """Given a key, which could be a ColumnElement, string, etc.,
+ matches it to the appropriate key we got from the result set's
+ metadata; then cache it locally for quick re-access."""
+
+ if isinstance(key, int) and key in props:
+ rec = props[key]
+ elif isinstance(key, basestring) and key.lower() in props:
+ rec = props[key.lower()]
+ elif isinstance(key, sql.ColumnElement):
+ label = context.column_labels.get(key._label, key.name).lower()
+ if label in props:
+ rec = props[label]
+
+ if not "rec" in locals():
+ raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key)))
+
+ return rec
+ return util.PopulateDict(lookup_key)
+
def close(self):
"""Close this ResultProxy, and the underlying DBAPI cursor corresponding to the execution.
@@ -904,38 +1165,12 @@ class ResultProxy(object):
self.cursor.close()
if self.connection.should_close_with_result:
self.connection.close()
-
- def _convert_key(self, key):
- """Convert and cache a key.
-
- Given a key, which could be a ColumnElement, string, etc.,
- matches it to the appropriate key we got from the result set's
- metadata; then cache it locally for quick re-access.
- """
-
- if key in self.__key_cache:
- return self.__key_cache[key]
- else:
- if isinstance(key, int) and key in self.__props:
- rec = self.__props[key]
- elif isinstance(key, basestring) and key.lower() in self.__props:
- rec = self.__props[key.lower()]
- elif isinstance(key, sql.ColumnElement):
- label = self.context.column_labels.get(key._label, key.name).lower()
- if label in self.__props:
- rec = self.__props[label]
-
- if not "rec" in locals():
- raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key)))
-
- self.__key_cache[key] = rec
- return rec
keys = property(lambda s:s.__keys)
def _has_key(self, row, key):
try:
- self._convert_key(key)
+ self._key_cache[key]
return True
except KeyError:
return False
@@ -989,7 +1224,7 @@ class ResultProxy(object):
return self.context.supports_sane_rowcount()
def _get_col(self, row, key):
- rec = self._convert_key(key)
+ rec = self._key_cache[key]
return rec[1].convert_result_value(row[rec[2]], self.dialect)
def _fetchone_impl(self):
@@ -1101,7 +1336,7 @@ class BufferedColumnResultProxy(ResultProxy):
"""
def _get_col(self, row, key):
- rec = self._convert_key(key)
+ rec = self._key_cache[key]
return row[rec[2]]
def _process_row(self, row):
@@ -1152,6 +1387,9 @@ class RowProxy(object):
self.__parent.close()
+ def __contains__(self, key):
+ return self.__parent._has_key(self.__row, key)
+
def __iter__(self):
for i in range(0, len(self.__row)):
yield self.__parent._get_col(self.__row, i)
@@ -1168,7 +1406,11 @@ class RowProxy(object):
return self.__parent._has_key(self.__row, key)
def __getitem__(self, key):
- return self.__parent._get_col(self.__row, key)
+ if isinstance(key, slice):
+ indices = key.indices(len(self))
+ return tuple([self.__parent._get_col(self.__row, i) for i in range(*indices)])
+ else:
+ return self.__parent._get_col(self.__row, key)
def __getattr__(self, name):
try:
@@ -1226,19 +1468,22 @@ class DefaultRunner(schema.SchemaVisitor):
DefaultRunner to allow database-specific behavior.
"""
- def __init__(self, connection):
- self.connection = connection
- self.dialect = connection.dialect
+ def __init__(self, context):
+ self.context = context
+ # branch the connection so it doesn't close after result
+ self.connection = context.connection._branch()
+ dialect = property(lambda self:self.context.dialect)
+
def get_column_default(self, column):
if column.default is not None:
- return column.default.accept_visitor(self)
+ return self.traverse_single(column.default)
else:
return None
def get_column_onupdate(self, column):
if column.onupdate is not None:
- return column.onupdate.accept_visitor(self)
+ return self.traverse_single(column.onupdate)
else:
return None
@@ -1260,14 +1505,14 @@ class DefaultRunner(schema.SchemaVisitor):
return None
def exec_default_sql(self, default):
- c = sql.select([default.arg]).compile(engine=self.connection)
- return self.connection.execute_compiled(c).scalar()
+ c = sql.select([default.arg]).compile(bind=self.connection)
+ return self.connection._execute_compiled(c).scalar()
def visit_column_onupdate(self, onupdate):
if isinstance(onupdate.arg, sql.ClauseElement):
return self.exec_default_sql(onupdate)
elif callable(onupdate.arg):
- return onupdate.arg()
+ return onupdate.arg(self.context)
else:
return onupdate.arg
@@ -1275,6 +1520,6 @@ class DefaultRunner(schema.SchemaVisitor):
if isinstance(default.arg, sql.ClauseElement):
return self.exec_default_sql(default)
elif callable(default.arg):
- return default.arg()
+ return default.arg(self.context)
else:
return default.arg
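
Taken together, the reworked Transaction hierarchy gives three begin styles,
and the new __enter__/__exit__ methods make any of them usable as a context
manager (commit on clean exit, rollback on exception). A usage sketch; it
assumes a 'users' table and the Connection-level begin_nested() and
begin_twophase() helpers that this patch introduces elsewhere:

    from __future__ import with_statement   # Python 2.5

    conn = engine.connect()

    # RootTransaction: commits on clean exit, rolls back on exception
    with conn.begin():
        conn.execute(users.insert(), name='ed')

    # NestedTransaction: SAVEPOINT / ROLLBACK TO SAVEPOINT underneath
    trans = conn.begin()
    nested = conn.begin_nested()
    conn.execute(users.insert(), name='wendy')
    nested.rollback()            # undoes only the work since the savepoint
    trans.commit()

    # TwoPhaseTransaction adds an explicit prepare() step before commit()
    xa = conn.begin_twophase()
    conn.execute(users.insert(), name='mary')
    xa.prepare()
    xa.commit()
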
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 95f6566e3..962e2ab60 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -4,25 +4,13 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+"""Provide default implementations of per-dialect sqlalchemy.engine classes"""
-from sqlalchemy import schema, exceptions, util, sql, types
-import StringIO, sys, re
+from sqlalchemy import schema, exceptions, sql, types
+import sys, re
from sqlalchemy.engine import base
-"""Provide default implementations of the engine interfaces"""
-class PoolConnectionProvider(base.ConnectionProvider):
- def __init__(self, url, pool):
- self.url = url
- self._pool = pool
-
- def get_connection(self):
- return self._pool.connect()
-
- def dispose(self):
- self._pool.dispose()
- self._pool = self._pool.recreate()
-
class DefaultDialect(base.Dialect):
"""Default implementation of Dialect"""
@@ -33,7 +21,18 @@ class DefaultDialect(base.Dialect):
self._ischema = None
self.dbapi = dbapi
self._figure_paramstyle(paramstyle=paramstyle, default=default_paramstyle)
-
+
+ def decode_result_columnname(self, name):
+ """decode a name found in cursor.description to a unicode object."""
+
+ return name.decode(self.encoding)
+
+ def dbapi_type_map(self):
+ # most DBAPIs have problems with this (for example, psycopg2's types
+ # are unhashable). So far Oracle can return it.
+
+ return {}
+
def create_execution_context(self, **kwargs):
return DefaultExecutionContext(self, **kwargs)
@@ -88,6 +87,15 @@ class DefaultDialect(base.Dialect):
#print "ENGINE COMMIT ON ", connection.connection
connection.commit()
+
+ def do_savepoint(self, connection, name):
+ connection.execute(sql.SavepointClause(name))
+
+ def do_rollback_to_savepoint(self, connection, name):
+ connection.execute(sql.RollbackToSavepointClause(name))
+
+ def do_release_savepoint(self, connection, name):
+ connection.execute(sql.ReleaseSavepointClause(name))
def do_executemany(self, cursor, statement, parameters, **kwargs):
cursor.executemany(statement, parameters)
@@ -95,8 +103,8 @@ class DefaultDialect(base.Dialect):
def do_execute(self, cursor, statement, parameters, **kwargs):
cursor.execute(statement, parameters)
- def defaultrunner(self, connection):
- return base.DefaultRunner(connection)
+ def defaultrunner(self, context):
+ return base.DefaultRunner(context)
def is_disconnect(self, e):
return False
@@ -107,23 +115,6 @@ class DefaultDialect(base.Dialect):
paramstyle = property(lambda s:s._paramstyle, _set_paramstyle)
- def convert_compiled_params(self, parameters):
- executemany = parameters is not None and isinstance(parameters, list)
- # the bind params are a CompiledParams object. but all the DBAPI's hate
- # that object (or similar). so convert it to a clean
- # dictionary/list/tuple of dictionary/tuple of list
- if parameters is not None:
- if self.positional:
- if executemany:
- parameters = [p.get_raw_list() for p in parameters]
- else:
- parameters = parameters.get_raw_list()
- else:
- if executemany:
- parameters = [p.get_raw_dict() for p in parameters]
- else:
- parameters = parameters.get_raw_dict()
- return parameters
def _figure_paramstyle(self, paramstyle=None, default='named'):
if paramstyle is not None:
@@ -152,29 +143,38 @@ class DefaultDialect(base.Dialect):
ischema = property(_get_ischema, doc="""returns an ISchema object for this engine, which allows access to information_schema tables (if supported)""")
class DefaultExecutionContext(base.ExecutionContext):
- def __init__(self, dialect, connection, compiled=None, compiled_parameters=None, statement=None, parameters=None):
+ def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None):
self.dialect = dialect
self.connection = connection
self.compiled = compiled
- self.compiled_parameters = compiled_parameters
if compiled is not None:
self.typemap = compiled.typemap
self.column_labels = compiled.column_labels
self.statement = unicode(compiled)
- else:
+ if parameters is None:
+ self.compiled_parameters = compiled.construct_params({})
+ elif not isinstance(parameters, (list, tuple)):
+ self.compiled_parameters = compiled.construct_params(parameters)
+ else:
+ self.compiled_parameters = [compiled.construct_params(m or {}) for m in parameters]
+ if len(self.compiled_parameters) == 1:
+ self.compiled_parameters = self.compiled_parameters[0]
+ elif statement is not None:
self.typemap = self.column_labels = None
- self.parameters = self._encode_param_keys(parameters)
+ self.parameters = self.__encode_param_keys(parameters)
self.statement = statement
-
- if not dialect.supports_unicode_statements():
+ else:
+ self.statement = None
+
+ if self.statement is not None and not dialect.supports_unicode_statements():
self.statement = self.statement.encode(self.dialect.encoding)
self.cursor = self.create_cursor()
engine = property(lambda s:s.connection.engine)
- def _encode_param_keys(self, params):
+ def __encode_param_keys(self, params):
"""apply string encoding to the keys of dictionary-based bind parameters"""
if self.dialect.positional or self.dialect.supports_unicode_statements():
return params
@@ -189,16 +189,46 @@ class DefaultExecutionContext(base.ExecutionContext):
return [proc(d) for d in params]
else:
return proc(params)
+
+ def __convert_compiled_params(self, parameters):
+ executemany = parameters is not None and isinstance(parameters, list)
+ encode = not self.dialect.supports_unicode_statements()
+ # the bind params are a CompiledParams object, but all the DBAPIs hate
+ # that object (or anything similar), so convert it to a clean
+ # dictionary / list of dictionaries / tuple / list of tuples
+ if parameters is not None:
+ if self.dialect.positional:
+ if executemany:
+ parameters = [p.get_raw_list() for p in parameters]
+ else:
+ parameters = parameters.get_raw_list()
+ else:
+ if executemany:
+ parameters = [p.get_raw_dict(encode_keys=encode) for p in parameters]
+ else:
+ parameters = parameters.get_raw_dict(encode_keys=encode)
+ return parameters
def is_select(self):
- return re.match(r'SELECT', self.statement.lstrip(), re.I)
+ """return TRUE if the statement is expected to have result rows."""
+
+ return re.match(r'SELECT', self.statement.lstrip(), re.I) is not None
def create_cursor(self):
return self.connection.connection.cursor()
-
+
+ def pre_execution(self):
+ self.pre_exec()
+
+ def post_execution(self):
+ self.post_exec()
+
+ def result(self):
+ return self.get_result_proxy()
+
def pre_exec(self):
self._process_defaults()
- self.parameters = self._encode_param_keys(self.dialect.convert_compiled_params(self.compiled_parameters))
+ self.parameters = self.__convert_compiled_params(self.compiled_parameters)
def post_exec(self):
pass
@@ -241,7 +271,7 @@ class DefaultExecutionContext(base.ExecutionContext):
inputsizes = []
for params in plist[0:1]:
for key in params.positional:
- typeengine = params.binds[key].type
+ typeengine = params.get_type(key)
dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
if dbtype is not None:
inputsizes.append(dbtype)
@@ -250,36 +280,23 @@ class DefaultExecutionContext(base.ExecutionContext):
inputsizes = {}
for params in plist[0:1]:
for key in params.keys():
- typeengine = params.binds[key].type
+ typeengine = params.get_type(key)
dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
if dbtype is not None:
inputsizes[key] = dbtype
self.cursor.setinputsizes(**inputsizes)
def _process_defaults(self):
- """``INSERT`` and ``UPDATE`` statements, when compiled, may
- have additional columns added to their ``VALUES`` and ``SET``
- lists corresponding to column defaults/onupdates that are
- present on the ``Table`` object (i.e. ``ColumnDefault``,
- ``Sequence``, ``PassiveDefault``). This method pre-execs
- those ``DefaultGenerator`` objects that require pre-execution
- and sets their values within the parameter list, and flags this
- ExecutionContext about ``PassiveDefault`` objects that may
- require post-fetching the row after it is inserted/updated.
-
- This method relies upon logic within the ``ANSISQLCompiler``
- in its `visit_insert` and `visit_update` methods that add the
- appropriate column clauses to the statement when its being
- compiled, so that these parameters can be bound to the
- statement.
- """
+ """generate default values for compiled insert/update statements,
+ and generate last_inserted_ids() collection."""
+ # TODO: cleanup
if self.compiled.isinsert:
if isinstance(self.compiled_parameters, list):
plist = self.compiled_parameters
else:
plist = [self.compiled_parameters]
- drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection))
+ drunner = self.dialect.defaultrunner(self)
self._lastrow_has_defaults = False
for param in plist:
last_inserted_ids = []
@@ -319,7 +336,7 @@ class DefaultExecutionContext(base.ExecutionContext):
plist = self.compiled_parameters
else:
plist = [self.compiled_parameters]
- drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection))
+ drunner = self.dialect.defaultrunner(self)
self._lastrow_has_defaults = False
for param in plist:
# check the "onupdate" status of each column in the table
diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py
index 7d85de9ad..0c59ee8eb 100644
--- a/lib/sqlalchemy/engine/strategies.py
+++ b/lib/sqlalchemy/engine/strategies.py
@@ -4,13 +4,11 @@ By default there are two, one which is the "thread-local" strategy,
one which is the "plain" strategy.
New strategies can be added via constructing a new EngineStrategy
-object which will add itself to the list of available strategies here,
-or replace one of the existing name. this can be accomplished via a
-mod; see the sqlalchemy/mods package for details.
+object which will add itself to the list of available strategies.
"""
-from sqlalchemy.engine import base, default, threadlocal, url
+from sqlalchemy.engine import base, threadlocal, url
from sqlalchemy import util, exceptions
from sqlalchemy import pool as poollib
@@ -92,8 +90,6 @@ class DefaultEngineStrategy(EngineStrategy):
else:
pool = pool
- provider = self.get_pool_provider(u, pool)
-
# create engine.
engineclass = self.get_engine_cls()
engine_args = {}
@@ -105,14 +101,11 @@ class DefaultEngineStrategy(EngineStrategy):
if len(kwargs):
raise TypeError("Invalid argument(s) %s sent to create_engine(), using configuration %s/%s/%s. Please check that the keyword arguments are appropriate for this combination of components." % (','.join(["'%s'" % k for k in kwargs]), dialect.__class__.__name__, pool.__class__.__name__, engineclass.__name__))
- return engineclass(provider, dialect, **engine_args)
+ return engineclass(pool, dialect, u, **engine_args)
def pool_threadlocal(self):
raise NotImplementedError()
- def get_pool_provider(self, url, pool):
- raise NotImplementedError()
-
def get_engine_cls(self):
raise NotImplementedError()
@@ -123,9 +116,6 @@ class PlainEngineStrategy(DefaultEngineStrategy):
def pool_threadlocal(self):
return False
- def get_pool_provider(self, url, pool):
- return default.PoolConnectionProvider(url, pool)
-
def get_engine_cls(self):
return base.Engine
@@ -138,9 +128,6 @@ class ThreadLocalEngineStrategy(DefaultEngineStrategy):
def pool_threadlocal(self):
return True
- def get_pool_provider(self, url, pool):
- return threadlocal.TLocalConnectionProvider(url, pool)
-
def get_engine_cls(self):
return threadlocal.TLEngine
@@ -195,4 +182,4 @@ class MockEngineStrategy(EngineStrategy):
def execute(self, object, *multiparams, **params):
raise NotImplementedError()
-MockEngineStrategy() \ No newline at end of file
+MockEngineStrategy()
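
With the ConnectionProvider indirection gone, the strategy now passes the Pool
and URL straight into the Engine constructor, so both are available as plain
attributes on the result of create_engine(). A small sketch (the in-memory
SQLite URL is illustrative):

    from sqlalchemy import create_engine

    engine = create_engine('sqlite://')
    print engine.url    # the URL object, now stored directly on the Engine
    print engine.pool   # the Pool; Engine.dispose() swaps in pool.recreate()
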
diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py
index 2bbb1ed43..b6ba54ea5 100644
--- a/lib/sqlalchemy/engine/threadlocal.py
+++ b/lib/sqlalchemy/engine/threadlocal.py
@@ -1,8 +1,7 @@
-from sqlalchemy import schema, exceptions, util, sql, types
-import StringIO, sys, re
-from sqlalchemy.engine import base, default
+from sqlalchemy import util
+from sqlalchemy.engine import base
-"""Provide a thread-local transactional wrapper around the basic ComposedSQLEngine.
+"""Provide a thread-local transactional wrapper around the root Engine class.
Multiple calls to engine.connect() will return the same connection for
the same thread. Also provides begin/commit methods on the engine
@@ -70,11 +69,8 @@ class TLConnection(base.Connection):
self.__opencount += 1
return self
- def _create_transaction(self, parent):
- return TLTransaction(self, parent)
-
def _begin(self):
- return base.Connection.begin(self)
+ return TLTransaction(self)
def in_transaction(self):
return self.session.in_transaction()
@@ -91,7 +87,7 @@ class TLConnection(base.Connection):
self.__opencount = 0
base.Connection.close(self)
-class TLTransaction(base.Transaction):
+class TLTransaction(base.RootTransaction):
def _commit_impl(self):
base.Transaction.commit(self)
@@ -112,7 +108,7 @@ class TLEngine(base.Engine):
"""
def __init__(self, *args, **kwargs):
- """The TLEngine relies upon the ConnectionProvider having
+ """The TLEngine relies upon the Pool having
"threadlocal" behavior, so that once a connection is checked out
for the current thread, you get that same connection
repeatedly.
@@ -124,7 +120,7 @@ class TLEngine(base.Engine):
def raw_connection(self):
"""Return a DBAPI connection."""
- return self.connection_provider.get_connection()
+ return self.pool.connect()
def connect(self, **kwargs):
"""Return a Connection that is not thread-locally scoped.
@@ -133,7 +129,7 @@ class TLEngine(base.Engine):
ComposedSQLEngine.
"""
- return base.Connection(self, self.connection_provider.unique_connection())
+ return base.Connection(self, self.pool.unique_connection())
def _session(self):
if not hasattr(self.context, 'session'):
@@ -156,6 +152,3 @@ class TLEngine(base.Engine):
def rollback(self):
self.session.rollback()
-class TLocalConnectionProvider(default.PoolConnectionProvider):
- def unique_connection(self):
- return self._pool.unique_connection()
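
Under the threadlocal strategy the TLEngine likewise talks to the Pool
directly: contextual_connect() keeps handing back the thread's connection,
while connect() deliberately returns an unscoped one. A hedged sketch:

    engine = create_engine('sqlite://', strategy='threadlocal')

    c1 = engine.contextual_connect()
    c2 = engine.contextual_connect()
    # within a single thread, both wrappers share one pooled DBAPI connection
    assert c1.connection is c2.connection

    c3 = engine.connect()   # explicitly not thread-locally scoped
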
diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py
index c5ad90ee9..1da76d7b2 100644
--- a/lib/sqlalchemy/engine/url.py
+++ b/lib/sqlalchemy/engine/url.py
@@ -1,10 +1,8 @@
-import re
-import cgi
-import sys
-import urllib
+"""Provide the URL object as well as the make_url parsing function."""
+
+import re, cgi, sys, urllib
from sqlalchemy import exceptions
-"""Provide the URL object as well as the make_url parsing function."""
class URL(object):
"""Represent the components of a URL used to connect to a database.
diff --git a/lib/sqlalchemy/ext/activemapper.py b/lib/sqlalchemy/ext/activemapper.py
index 2fcf44f61..fa32a5fc3 100644
--- a/lib/sqlalchemy/ext/activemapper.py
+++ b/lib/sqlalchemy/ext/activemapper.py
@@ -1,11 +1,10 @@
-from sqlalchemy import create_session, relation, mapper, \
- join, ThreadLocalMetaData, class_mapper, \
- util, Integer
-from sqlalchemy import and_, or_
+from sqlalchemy import ThreadLocalMetaData, util, Integer
from sqlalchemy import Table, Column, ForeignKey
+from sqlalchemy.orm import class_mapper, relation, create_session
+
from sqlalchemy.ext.sessioncontext import SessionContext
from sqlalchemy.ext.assignmapper import assign_mapper
-from sqlalchemy import backref as create_backref
+from sqlalchemy.orm import backref as create_backref
import sqlalchemy
import inspect
@@ -14,7 +13,7 @@ import sys
#
# the "proxy" to the database engine... this can be swapped out at runtime
#
-metadata = ThreadLocalMetaData("activemapper")
+metadata = ThreadLocalMetaData()
try:
objectstore = sqlalchemy.objectstore
diff --git a/lib/sqlalchemy/ext/assignmapper.py b/lib/sqlalchemy/ext/assignmapper.py
index 4708afd8d..238041702 100644
--- a/lib/sqlalchemy/ext/assignmapper.py
+++ b/lib/sqlalchemy/ext/assignmapper.py
@@ -1,51 +1,50 @@
-from sqlalchemy import mapper, util, Query, exceptions
+from sqlalchemy import util, exceptions
import types
-
-def monkeypatch_query_method(ctx, class_, name):
- def do(self, *args, **kwargs):
- query = Query(class_, session=ctx.current)
- return getattr(query, name)(*args, **kwargs)
- try:
- do.__name__ = name
- except:
- pass
- setattr(class_, name, classmethod(do))
-
-def monkeypatch_objectstore_method(ctx, class_, name):
+from sqlalchemy.orm import mapper
+
+def _monkeypatch_session_method(name, ctx, class_):
def do(self, *args, **kwargs):
session = ctx.current
- if name == "flush":
- # flush expects a list of objects
- self = [self]
return getattr(session, name)(self, *args, **kwargs)
try:
do.__name__ = name
except:
pass
- setattr(class_, name, do)
-
+ if not hasattr(class_, name):
+ setattr(class_, name, do)
+
def assign_mapper(ctx, class_, *args, **kwargs):
+ extension = kwargs.pop('extension', None)
+ if extension is not None:
+ extension = util.to_list(extension)
+ extension.append(ctx.mapper_extension)
+ else:
+ extension = ctx.mapper_extension
+
validate = kwargs.pop('validate', False)
+
if not isinstance(getattr(class_, '__init__'), types.MethodType):
def __init__(self, **kwargs):
for key, value in kwargs.items():
if validate:
- if not key in self.mapper.props:
+ if not self.mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
setattr(self, key, value)
class_.__init__ = __init__
- extension = kwargs.pop('extension', None)
- if extension is not None:
- extension = util.to_list(extension)
- extension.append(ctx.mapper_extension)
- else:
- extension = ctx.mapper_extension
+
+ class query(object):
+ def __getattr__(self, key):
+ return getattr(ctx.current.query(class_), key)
+ def __call__(self):
+ return ctx.current.query(class_)
+
+ if not hasattr(class_, 'query'):
+ class_.query = query()
+
+ for name in ['refresh', 'expire', 'delete', 'expunge', 'update']:
+ _monkeypatch_session_method(name, ctx, class_)
+
m = mapper(class_, extension=extension, *args, **kwargs)
class_.mapper = m
- class_.query = classmethod(lambda cls: Query(class_, session=ctx.current))
- for name in ['get', 'filter', 'filter_by', 'select', 'select_by', 'selectfirst', 'selectfirst_by', 'selectone', 'selectone_by', 'get_by', 'join_to', 'join_via', 'count', 'count_by', 'options', 'instances']:
- monkeypatch_query_method(ctx, class_, name)
- for name in ['flush', 'delete', 'expire', 'refresh', 'expunge', 'merge', 'save', 'update', 'save_or_update']:
- monkeypatch_objectstore_method(ctx, class_, name)
return m
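
The net effect is that the long list of monkeypatched Query methods collapses
into a single class-level 'query' object: attribute access proxies to a fresh
contextual Query, and calling it returns the Query itself. A sketch of the
resulting idiom (the User class and users Table are assumed to be defined):

    from sqlalchemy.ext.sessioncontext import SessionContext
    from sqlalchemy.ext.assignmapper import assign_mapper

    ctx = SessionContext()
    assign_mapper(ctx, User, users)

    ed = User.query.get(1)                   # proxies to ctx.current.query(User)
    eds = list(User.query.filter_by(name='ed'))
    q = User.query()                         # __call__ returns the Query itself
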
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
index cdb814702..2dd807222 100644
--- a/lib/sqlalchemy/ext/associationproxy.py
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -6,11 +6,10 @@ transparent proxied access to the endpoint of an association object.
See the example ``examples/association/proxied_association.py``.
"""
-from sqlalchemy.orm.attributes import InstrumentedList
+import weakref, itertools
import sqlalchemy.exceptions as exceptions
import sqlalchemy.orm as orm
import sqlalchemy.util as util
-import weakref
def association_proxy(targetcollection, attr, **kw):
"""Convenience function for use in mapped classes. Implements a Python
@@ -109,7 +108,7 @@ class AssociationProxy(object):
self.collection_class = None
def _get_property(self):
- return orm.class_mapper(self.owning_class).props[self.target_collection]
+ return orm.class_mapper(self.owning_class).get_property(self.target_collection)
def _target_class(self):
return self._get_property().mapper.class_
@@ -168,15 +167,7 @@ class AssociationProxy(object):
def _new(self, lazy_collection):
creator = self.creator and self.creator or self.target_class
-
- # Prefer class typing here to spot dicts with the required append()
- # method.
- collection = lazy_collection()
- if isinstance(collection.data, dict):
- self.collection_class = dict
- else:
- self.collection_class = util.duck_type_collection(collection.data)
- del collection
+ self.collection_class = util.duck_type_collection(lazy_collection())
if self.proxy_factory:
return self.proxy_factory(lazy_collection, creator, self.value_attr)
@@ -269,7 +260,33 @@ class _AssociationList(object):
return self._get(self.col[index])
def __setitem__(self, index, value):
- self._set(self.col[index], value)
+ if not isinstance(index, slice):
+ self._set(self.col[index], value)
+ else:
+ if index.stop is None:
+ stop = len(self)
+ elif index.stop < 0:
+ stop = len(self) + index.stop
+ else:
+ stop = index.stop
+ step = index.step or 1
+
+ rng = range(index.start or 0, stop, step)
+ if step == 1:
+ for i in rng:
+ del self[index.start]
+ i = index.start
+ for item in value:
+ self.insert(i, item)
+ i += 1
+ else:
+ if len(value) != len(rng):
+ raise ValueError(
+ "attempt to assign sequence of size %s to "
+ "extended slice of size %s" % (len(value),
+ len(rng)))
+ for i, item in zip(rng, value):
+ self._set(self.col[i], item)
def __delitem__(self, index):
del self.col[index]
@@ -291,9 +308,13 @@ class _AssociationList(object):
del self.col[start:end]
def __iter__(self):
- """Iterate over proxied values. For the actual domain objects,
- iterate over .col instead or just use the underlying collection
- directly from its property on the parent."""
+ """Iterate over proxied values.
+
+ For the actual domain objects, iterate over .col instead or
+ just use the underlying collection directly from its property
+ on the parent.
+ """
+
for member in self.col:
yield self._get(member)
raise StopIteration
@@ -304,6 +325,10 @@ class _AssociationList(object):
item = self._create(value, **kw)
self.col.append(item)
+ def count(self, value):
+ return sum([1 for _ in
+ itertools.ifilter(lambda v: v == value, iter(self))])
+
def extend(self, values):
for v in values:
self.append(v)
@@ -311,6 +336,26 @@ class _AssociationList(object):
def insert(self, index, value):
self.col[index:index] = [self._create(value)]
+ def pop(self, index=-1):
+ return self.getter(self.col.pop(index))
+
+ def remove(self, value):
+ for i, val in enumerate(self):
+ if val == value:
+ del self.col[i]
+ return
+ raise ValueError("value not in list")
+
+ def reverse(self):
+ """Not supported, use reversed(mylist)"""
+
+ raise NotImplementedError
+
+ def sort(self):
+ """Not supported, use sorted(mylist)"""
+
+ raise NotImplementedError
+
def clear(self):
del self.col[0:len(self.col)]
@@ -545,9 +590,7 @@ class _AssociationSet(object):
def add(self, value):
if value not in self:
- # must shove this through InstrumentedList.append() which will
- # eventually call the collection_class .add()
- self.col.append(self._create(value))
+ self.col.add(self._create(value))
# for discard and remove, choosing a more expensive check strategy rather
# than call self.creator()
@@ -567,12 +610,7 @@ class _AssociationSet(object):
def pop(self):
if not self.col:
raise KeyError('pop from an empty set')
- # grumble, pop() is borked on InstrumentedList (#548)
- if isinstance(self.col, InstrumentedList):
- member = list(self.col)[0]
- self.col.remove(member)
- else:
- member = self.col.pop()
+ member = self.col.pop()
return self._get(member)
def update(self, other):
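
With InstrumentedList gone, the proxy collections implement the native list
and set protocols themselves. Assuming a mapped User whose 'keywords'
attribute is an association proxy over Keyword objects (illustrative names),
the list flavor now supports, for example:

    u = User()
    u.keywords.extend(['red', 'green', 'blue'])
    u.keywords[0:2] = ['cyan', 'magenta']   # slice assignment, new in this patch
    u.keywords.remove('blue')
    print u.keywords.count('cyan')          # 1
    u.keywords.pop()                        # returns and removes 'magenta'
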
diff --git a/lib/sqlalchemy/ext/proxy.py b/lib/sqlalchemy/ext/proxy.py
deleted file mode 100644
index b81702fc4..000000000
--- a/lib/sqlalchemy/ext/proxy.py
+++ /dev/null
@@ -1,113 +0,0 @@
-try:
- from threading import local
-except ImportError:
- from sqlalchemy.util import ThreadLocal as local
-
-from sqlalchemy import sql
-from sqlalchemy.engine import create_engine, Engine
-
-__all__ = ['BaseProxyEngine', 'AutoConnectEngine', 'ProxyEngine']
-
-class BaseProxyEngine(sql.Executor):
- """Basis for all proxy engines."""
-
- def get_engine(self):
- raise NotImplementedError
-
- def set_engine(self, engine):
- raise NotImplementedError
-
- engine = property(lambda s:s.get_engine(), lambda s,e:s.set_engine(e))
-
- def execute_compiled(self, *args, **kwargs):
- """Override superclass behaviour.
-
- This method is required to be present as it overrides the
- `execute_compiled` present in ``sql.Engine``.
- """
-
- return self.get_engine().execute_compiled(*args, **kwargs)
-
- def compiler(self, *args, **kwargs):
- """Override superclass behaviour.
-
- This method is required to be present as it overrides the
- `compiler` method present in ``sql.Engine``.
- """
-
- return self.get_engine().compiler(*args, **kwargs)
-
- def __getattr__(self, attr):
- """Provide proxying for methods that are not otherwise present on this ``BaseProxyEngine``.
-
- Note that methods which are present on the base class
- ``sql.Engine`` will **not** be proxied through this, and must
- be explicit on this class.
- """
-
- # call get_engine() to give subclasses a chance to change
- # connection establishment behavior
- e = self.get_engine()
- if e is not None:
- return getattr(e, attr)
- raise AttributeError("No connection established in ProxyEngine: "
- " no access to %s" % attr)
-
-class AutoConnectEngine(BaseProxyEngine):
- """An SQLEngine proxy that automatically connects when necessary."""
-
- def __init__(self, dburi, **kwargs):
- BaseProxyEngine.__init__(self)
- self.dburi = dburi
- self.kwargs = kwargs
- self._engine = None
-
- def get_engine(self):
- if self._engine is None:
- if callable(self.dburi):
- dburi = self.dburi()
- else:
- dburi = self.dburi
- self._engine = create_engine(dburi, **self.kwargs)
- return self._engine
-
-
-class ProxyEngine(BaseProxyEngine):
- """Engine proxy for lazy and late initialization.
-
- This engine will delegate access to a real engine set with connect().
- """
-
- def __init__(self, **kwargs):
- BaseProxyEngine.__init__(self)
- # create the local storage for uri->engine map and current engine
- self.storage = local()
- self.kwargs = kwargs
-
- def connect(self, *args, **kwargs):
- """Establish connection to a real engine."""
-
- kwargs.update(self.kwargs)
- if not kwargs:
- key = repr(args)
- else:
- key = "%s, %s" % (repr(args), repr(sorted(kwargs.items())))
- try:
- map = self.storage.connection
- except AttributeError:
- self.storage.connection = {}
- self.storage.engine = None
- map = self.storage.connection
- try:
- self.storage.engine = map[key]
- except KeyError:
- map[key] = create_engine(*args, **kwargs)
- self.storage.engine = map[key]
-
- def get_engine(self):
- if not hasattr(self.storage, 'engine') or self.storage.engine is None:
- raise AttributeError("No connection established")
- return self.storage.engine
-
- def set_engine(self, engine):
- self.storage.engine = engine
diff --git a/lib/sqlalchemy/ext/selectresults.py b/lib/sqlalchemy/ext/selectresults.py
index 68538f3cb..1920b6f92 100644
--- a/lib/sqlalchemy/ext/selectresults.py
+++ b/lib/sqlalchemy/ext/selectresults.py
@@ -1,212 +1,28 @@
+"""SelectResults has been rolled into Query. This class is now just a placeholder."""
+
import sqlalchemy.sql as sql
import sqlalchemy.orm as orm
class SelectResultsExt(orm.MapperExtension):
"""a MapperExtension that provides SelectResults functionality for the
results of query.select_by() and query.select()"""
+
def select_by(self, query, *args, **params):
- return SelectResults(query, query.join_by(*args, **params))
+ q = query
+ for a in args:
+ q = q.filter(a)
+ return q.filter_by(**params)
+
def select(self, query, arg=None, **kwargs):
if isinstance(arg, sql.FromClause) and arg.supports_execution():
return orm.EXT_PASS
else:
- return SelectResults(query, arg, ops=kwargs)
-
-class SelectResults(object):
- """Build a query one component at a time via separate method
- calls, each call transforming the previous ``SelectResults``
- instance into a new ``SelectResults`` instance with further
- limiting criterion added. When interpreted in an iterator context
- (such as via calling ``list(selectresults)``), executes the query.
- """
-
- def __init__(self, query, clause=None, ops={}, joinpoint=None):
- """Construct a new ``SelectResults`` using the given ``Query``
- object and optional ``WHERE`` clause. `ops` is an optional
- dictionary of bind parameter values.
- """
-
- self._query = query
- self._clause = clause
- self._ops = {}
- self._ops.update(ops)
- self._joinpoint = joinpoint or (self._query.table, self._query.mapper)
-
- def options(self,*args, **kwargs):
- """Apply mapper options to the underlying query.
-
- See also ``Query.options``.
- """
-
- new = self.clone()
- new._query = new._query.options(*args, **kwargs)
- return new
-
- def count(self):
- """Execute the SQL ``count()`` function against the ``SelectResults`` criterion."""
-
- return self._query.count(self._clause, **self._ops)
-
- def _col_aggregate(self, col, func):
- """Execute ``func()`` function against the given column.
-
- For performance, only use subselect if `order_by` attribute is set.
- """
-
- if self._ops.get('order_by'):
- s1 = sql.select([col], self._clause, **self._ops).alias('u')
- return sql.select([func(s1.corresponding_column(col))]).scalar()
- else:
- return sql.select([func(col)], self._clause, **self._ops).scalar()
-
- def min(self, col):
- """Execute the SQL ``min()`` function against the given column."""
-
- return self._col_aggregate(col, sql.func.min)
-
- def max(self, col):
- """Execute the SQL ``max()`` function against the given column."""
-
- return self._col_aggregate(col, sql.func.max)
-
- def sum(self, col):
- """Execute the SQL ``sum()`` function against the given column."""
-
- return self._col_aggregate(col, sql.func.sum)
-
- def avg(self, col):
- """Execute the SQL ``avg()`` function against the given column."""
-
- return self._col_aggregate(col, sql.func.avg)
-
- def clone(self):
- """Create a copy of this ``SelectResults``."""
-
- return SelectResults(self._query, self._clause, self._ops.copy(), self._joinpoint)
-
- def filter(self, clause):
- """Apply an additional ``WHERE`` clause against the query."""
-
- new = self.clone()
- new._clause = sql.and_(self._clause, clause)
- return new
-
- def select(self, clause):
- return self.filter(clause)
-
- def select_by(self, *args, **kwargs):
- return self.filter(self._query._join_by(args, kwargs, start=self._joinpoint[1]))
-
- def order_by(self, order_by):
- """Apply an ``ORDER BY`` to the query."""
-
- new = self.clone()
- new._ops['order_by'] = order_by
- return new
-
- def limit(self, limit):
- """Apply a ``LIMIT`` to the query."""
-
- return self[:limit]
-
- def offset(self, offset):
- """Apply an ``OFFSET`` to the query."""
-
- return self[offset:]
-
- def distinct(self):
- """Apply a ``DISTINCT`` to the query."""
-
- new = self.clone()
- new._ops['distinct'] = True
- return new
-
- def list(self):
- """Return the results represented by this ``SelectResults`` as a list.
-
- This results in an execution of the underlying query.
- """
-
- return list(self)
-
- def select_from(self, from_obj):
- """Set the `from_obj` parameter of the query.
-
- `from_obj` is a list of one or more tables.
- """
-
- new = self.clone()
- new._ops['from_obj'] = from_obj
- return new
-
- def join_to(self, prop):
- """Join the table of this ``SelectResults`` to the table located against the given property name.
-
- Subsequent calls to join_to or outerjoin_to will join against
- the rightmost table located from the previous `join_to` or
- `outerjoin_to` call, searching for the property starting with
- the rightmost mapper last located.
- """
-
- new = self.clone()
- (clause, mapper) = self._join_to(prop, outerjoin=False)
- new._ops['from_obj'] = [clause]
- new._joinpoint = (clause, mapper)
- return new
-
- def outerjoin_to(self, prop):
- """Outer join the table of this ``SelectResults`` to the
- table located against the given property name.
-
- Subsequent calls to join_to or outerjoin_to will join against
- the rightmost table located from the previous ``join_to`` or
- ``outerjoin_to`` call, searching for the property starting with
- the rightmost mapper last located.
- """
-
- new = self.clone()
- (clause, mapper) = self._join_to(prop, outerjoin=True)
- new._ops['from_obj'] = [clause]
- new._joinpoint = (clause, mapper)
- return new
-
- def _join_to(self, prop, outerjoin=False):
- [keys,p] = self._query._locate_prop(prop, start=self._joinpoint[1])
- clause = self._joinpoint[0]
- mapper = self._joinpoint[1]
- for key in keys:
- prop = mapper.props[key]
- if outerjoin:
- clause = clause.outerjoin(prop.select_table, prop.get_join(mapper))
- else:
- clause = clause.join(prop.select_table, prop.get_join(mapper))
- mapper = prop.mapper
- return (clause, mapper)
-
- def compile(self):
- return self._query.compile(self._clause, **self._ops)
-
- def __getitem__(self, item):
- if isinstance(item, slice):
- start = item.start
- stop = item.stop
- if (isinstance(start, int) and start < 0) or \
- (isinstance(stop, int) and stop < 0):
- return list(self)[item]
- else:
- res = self.clone()
- if start is not None and stop is not None:
- res._ops.update(dict(offset=self._ops.get('offset', 0)+start, limit=stop-start))
- elif start is None and stop is not None:
- res._ops.update(dict(limit=stop))
- elif start is not None and stop is None:
- res._ops.update(dict(offset=self._ops.get('offset', 0)+start))
- if item.step is not None:
- return list(res)[None:None:item.step]
- else:
- return res
- else:
- return list(self[item:item+1])[0]
-
- def __iter__(self):
- return iter(self._query.select_whereclause(self._clause, **self._ops))
+ if arg is not None:
+ query = query.filter(arg)
+ return query._legacy_select_kwargs(**kwargs)
+
+def SelectResults(query, clause=None, ops={}):
+ if clause is not None:
+ query = query.filter(clause)
+ query = query.options(orm.extension(SelectResultsExt()))
+ return query._legacy_select_kwargs(**ops)
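
Since the generative behavior now lives on Query itself, old SelectResults
chains translate almost mechanically; a migration sketch (session, User and
the users Table are assumed):

    # 0.3:  SelectResults(session.query(User)).select(users.c.name == 'ed') \
    #                                         .order_by(users.c.id)[1:3]
    # 0.4:  Query is generative on its own
    q = session.query(User).filter(users.c.name == 'ed')
    q = q.order_by(users.c.id)[1:3]   # slicing applies OFFSET/LIMIT
    results = list(q)
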
diff --git a/lib/sqlalchemy/ext/sessioncontext.py b/lib/sqlalchemy/ext/sessioncontext.py
index 2f81e55d2..fcbf29c3f 100644
--- a/lib/sqlalchemy/ext/sessioncontext.py
+++ b/lib/sqlalchemy/ext/sessioncontext.py
@@ -1,5 +1,5 @@
from sqlalchemy.util import ScopedRegistry
-from sqlalchemy.orm.mapper import MapperExtension
+from sqlalchemy.orm import create_session, object_session, MapperExtension, EXT_PASS
__all__ = ['SessionContext', 'SessionContextExt']
@@ -15,16 +15,18 @@ class SessionContext(object):
engine = create_engine(...)
def session_factory():
- return Session(bind_to=engine)
+ return Session(bind=engine)
context = SessionContext(session_factory)
s = context.current # get thread-local session
- context.current = Session(bind_to=other_engine) # set current session
+ context.current = Session(bind=other_engine) # set current session
del context.current # discard the thread-local session (a new one will
# be created on the next call to context.current)
"""
- def __init__(self, session_factory, scopefunc=None):
+ def __init__(self, session_factory=None, scopefunc=None):
+ if session_factory is None:
+ session_factory = create_session
self.registry = ScopedRegistry(session_factory, scopefunc)
super(SessionContext, self).__init__()
@@ -60,3 +62,21 @@ class SessionContextExt(MapperExtension):
def get_session(self):
return self.context.current
+
+ def init_instance(self, mapper, class_, instance, args, kwargs):
+ session = kwargs.pop('_sa_session', self.context.current)
+ session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None))
+ return EXT_PASS
+
+ def init_failed(self, mapper, class_, instance, args, kwargs):
+ object_session(instance).expunge(instance)
+ return EXT_PASS
+
+ def dispose_class(self, mapper, class_):
+ if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'):
+ if class_.__init__._oldinit is not None:
+ class_.__init__ = class_.__init__._oldinit
+ else:
+ delattr(class_, '__init__')
+
+
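
The constructor's session_factory argument is now optional, defaulting to
create_session, and the new init_instance()/init_failed() hooks on
SessionContextExt save newly constructed instances into the contextual session
automatically. A minimal sketch (User and users are assumed mappings):

    from sqlalchemy.orm import mapper
    from sqlalchemy.ext.sessioncontext import SessionContext

    ctx = SessionContext()    # no factory given: uses create_session
    mapper(User, users, extension=ctx.mapper_extension)

    u = User()                # init_instance() saves u into ctx.current
    ctx.current.flush()       # flushes the thread-local session
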
diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py
index 04e5b49f7..756b5e1e7 100644
--- a/lib/sqlalchemy/ext/sqlsoup.py
+++ b/lib/sqlalchemy/ext/sqlsoup.py
@@ -310,8 +310,8 @@ Boring tests here. Nothing of real expository value.
"""
from sqlalchemy import *
+from sqlalchemy.orm import *
from sqlalchemy.ext.sessioncontext import SessionContext
-from sqlalchemy.ext.assignmapper import assign_mapper
from sqlalchemy.exceptions import *
@@ -392,7 +392,7 @@ class SelectableClassType(type):
def update(cls, whereclause=None, values=None, **kwargs):
_ddl_error(cls)
- def _selectable(cls):
+ def __selectable__(cls):
return cls._table
def __getattr__(cls, attr):
@@ -434,9 +434,7 @@ def _selectable_name(selectable):
return x
def class_for_table(selectable, **mapper_kwargs):
- if not hasattr(selectable, '_selectable') \
- or selectable._selectable() != selectable:
- raise ArgumentError('class_for_table requires a selectable as its argument')
+ selectable = sql._selectable(selectable)
mapname = 'Mapped' + _selectable_name(selectable)
if isinstance(selectable, Table):
klass = TableClassType(mapname, (object,), {})
@@ -520,7 +518,7 @@ class SqlSoup:
def with_labels(self, item):
# TODO give meaningful aliases
- return self.map(item._selectable().select(use_labels=True).alias('foo'))
+ return self.map(sql._selectable(item).select(use_labels=True).alias('foo'))
def join(self, *args, **kwargs):
j = join(*args, **kwargs)
@@ -539,6 +537,9 @@ class SqlSoup:
t = None
self._cache[attr] = t
return t
+
+ def __repr__(self):
+ return 'SqlSoup(%r)' % self._metadata
if __name__ == '__main__':
import doctest
diff --git a/lib/sqlalchemy/mods/legacy_session.py b/lib/sqlalchemy/mods/legacy_session.py
deleted file mode 100644
index e21a5634b..000000000
--- a/lib/sqlalchemy/mods/legacy_session.py
+++ /dev/null
@@ -1,176 +0,0 @@
-"""A plugin that emulates 0.1 Session behavior."""
-
-import sqlalchemy.orm.objectstore as objectstore
-import sqlalchemy.orm.unitofwork as unitofwork
-import sqlalchemy.util as util
-import sqlalchemy
-
-import sqlalchemy.mods.threadlocal
-
-class LegacySession(objectstore.Session):
- def __init__(self, nest_on=None, hash_key=None, **kwargs):
- super(LegacySession, self).__init__(**kwargs)
- self.parent_uow = None
- self.begin_count = 0
- self.nest_on = util.to_list(nest_on)
- self.__pushed_count = 0
-
- def was_pushed(self):
- if self.nest_on is None:
- return
- self.__pushed_count += 1
- if self.__pushed_count == 1:
- for n in self.nest_on:
- n.push_session()
-
- def was_popped(self):
- if self.nest_on is None or self.__pushed_count == 0:
- return
- self.__pushed_count -= 1
- if self.__pushed_count == 0:
- for n in self.nest_on:
- n.pop_session()
-
- class SessionTrans(object):
- """Returned by ``Session.begin()``, denotes a
- transactionalized UnitOfWork instance. Call ``commit()`
- on this to commit the transaction.
- """
-
- def __init__(self, parent, uow, isactive):
- self.__parent = parent
- self.__isactive = isactive
- self.__uow = uow
-
- isactive = property(lambda s:s.__isactive, doc="True if this SessionTrans is the 'active' transaction marker, else its a no-op.")
- parent = property(lambda s:s.__parent, doc="The parent Session of this SessionTrans object.")
- uow = property(lambda s:s.__uow, doc="The parent UnitOfWork corresponding to this transaction.")
-
- def begin(self):
- """Call ``begin()`` on the underlying ``Session`` object,
- returning a new no-op ``SessionTrans`` object.
- """
-
- if self.parent.uow is not self.uow:
- raise InvalidRequestError("This SessionTrans is no longer valid")
- return self.parent.begin()
-
- def commit(self):
- """Commit the transaction noted by this ``SessionTrans`` object."""
-
- self.__parent._trans_commit(self)
- self.__isactive = False
-
- def rollback(self):
- """Roll back the current UnitOfWork transaction, in the
- case that ``begin()`` has been called.
-
- The changes logged since the begin() call are discarded.
- """
-
- self.__parent._trans_rollback(self)
- self.__isactive = False
-
- def begin(self):
- """Begin a new UnitOfWork transaction and return a
- transaction-holding object.
-
- ``commit()`` or ``rollback()`` should be called on the returned object.
-
- ``commit()`` on the ``Session`` will do nothing while a
- transaction is pending, and further calls to ``begin()`` will
- return no-op transactional objects.
- """
-
- if self.parent_uow is not None:
- return LegacySession.SessionTrans(self, self.uow, False)
- self.parent_uow = self.uow
- self.uow = unitofwork.UnitOfWork(identity_map = self.uow.identity_map)
- return LegacySession.SessionTrans(self, self.uow, True)
-
- def commit(self, *objects):
- """Commit the current UnitOfWork transaction.
-
- Called with no arguments, this is only used for *implicit*
- transactions when there was no ``begin()``.
-
- If individual objects are submitted, then only those objects
- are committed, and the begin/commit cycle is not affected.
- """
-
- # if an object list is given, commit just those but dont
- # change begin/commit status
- if len(objects):
- self._commit_uow(*objects)
- self.uow.flush(self, *objects)
- return
- if self.parent_uow is None:
- self._commit_uow()
-
- def _trans_commit(self, trans):
- if trans.uow is self.uow and trans.isactive:
- try:
- self._commit_uow()
- finally:
- self.uow = self.parent_uow
- self.parent_uow = None
-
- def _trans_rollback(self, trans):
- if trans.uow is self.uow:
- self.uow = self.parent_uow
- self.parent_uow = None
-
- def _commit_uow(self, *obj):
- self.was_pushed()
- try:
- self.uow.flush(self, *obj)
- finally:
- self.was_popped()
-
-def begin():
- """Deprecated. Use ``s = Session(new_imap=False)``."""
-
- return objectstore.get_session().begin()
-
-def commit(*obj):
- """Deprecated. Use ``flush(*obj)``."""
-
- objectstore.get_session().flush(*obj)
-
-def uow():
- return objectstore.get_session()
-
-def push_session(sess):
- old = get_session()
- if getattr(sess, '_previous', None) is not None:
- raise InvalidRequestError("Given Session is already pushed onto some thread's stack")
- sess._previous = old
- session_registry.set(sess)
- sess.was_pushed()
-
-def pop_session():
- sess = get_session()
- old = sess._previous
- sess._previous = None
- session_registry.set(old)
- sess.was_popped()
- return old
-
-def using_session(sess, func):
- push_session(sess)
- try:
- return func()
- finally:
- pop_session()
-
-def install_plugin():
- objectstore.Session = LegacySession
- objectstore.session_registry = util.ScopedRegistry(objectstore.Session)
- objectstore.begin = begin
- objectstore.commit = commit
- objectstore.uow = uow
- objectstore.push_session = push_session
- objectstore.pop_session = pop_session
- objectstore.using_session = using_session
-
-install_plugin()
diff --git a/lib/sqlalchemy/mods/selectresults.py b/lib/sqlalchemy/mods/selectresults.py
index ac8de9b06..25bfa2840 100644
--- a/lib/sqlalchemy/mods/selectresults.py
+++ b/lib/sqlalchemy/mods/selectresults.py
@@ -1,4 +1,4 @@
-from sqlalchemy.ext.selectresults import *
+from sqlalchemy.ext.selectresults import SelectResultsExt
from sqlalchemy.orm.mapper import global_extensions
def install_plugin():
diff --git a/lib/sqlalchemy/mods/threadlocal.py b/lib/sqlalchemy/mods/threadlocal.py
deleted file mode 100644
index c8043bc62..000000000
--- a/lib/sqlalchemy/mods/threadlocal.py
+++ /dev/null
@@ -1,53 +0,0 @@
-"""This plugin installs thread-local behavior at the ``Engine`` and ``Session`` level.
-
-The default ``Engine`` strategy will be *threadlocal*, producing
-``TLocalEngine`` instances for create_engine by default.
-
-With this engine, ``connect()`` method will return the same connection
-on the same thread, if it is already checked out from the pool. This
-greatly helps functions that call multiple statements to be able to
-easily use just one connection without explicit ``close`` statements
-on result handles.
-
-On the ``Session`` side, module-level methods will be installed within
-the objectstore module, such as ``flush()``, ``delete()``, etc. which
-call this method on the thread-local session.
-
-Note: this mod creates a global, thread-local session context named
-``sqlalchemy.objectstore``. All mappers created while this mod is
-installed will reference this global context when creating new mapped
-object instances.
-"""
-
-from sqlalchemy import util, engine, mapper
-from sqlalchemy.ext.sessioncontext import SessionContext
-import sqlalchemy.ext.assignmapper as assignmapper
-from sqlalchemy.orm.mapper import global_extensions
-from sqlalchemy.orm.session import Session
-import sqlalchemy
-import sys, types
-
-__all__ = ['Objectstore', 'assign_mapper']
-
-class Objectstore(object):
- def __init__(self, *args, **kwargs):
- self.context = SessionContext(*args, **kwargs)
- def __getattr__(self, name):
- return getattr(self.context.current, name)
- session = property(lambda s:s.context.current)
-
-def assign_mapper(class_, *args, **kwargs):
- assignmapper.assign_mapper(objectstore.context, class_, *args, **kwargs)
-
-objectstore = Objectstore(Session)
-def install_plugin():
- sqlalchemy.objectstore = objectstore
- global_extensions.append(objectstore.context.mapper_extension)
- engine.default_strategy = 'threadlocal'
- sqlalchemy.assign_mapper = assign_mapper
-
-def uninstall_plugin():
- engine.default_strategy = 'plain'
- global_extensions.remove(objectstore.context.mapper_extension)
-
-install_plugin()
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index 7ef2da897..1982a94f7 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -11,58 +11,229 @@ packages and tying operations to class properties and constructors.
from sqlalchemy import exceptions
from sqlalchemy import util as sautil
-from sqlalchemy.orm.mapper import *
+from sqlalchemy.orm.mapper import Mapper, object_mapper, class_mapper, mapper_registry
+from sqlalchemy.orm.interfaces import SynonymProperty, MapperExtension, EXT_PASS, ExtensionOption, PropComparator
+from sqlalchemy.orm.properties import PropertyLoader, ColumnProperty, CompositeProperty, BackRef
from sqlalchemy.orm import mapper as mapperlib
+from sqlalchemy.orm import collections, strategies
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.util import polymorphic_union
-from sqlalchemy.orm import properties, strategies, interfaces
from sqlalchemy.orm.session import Session as create_session
from sqlalchemy.orm.session import object_session, attribute_manager
-__all__ = ['relation', 'column_property', 'backref', 'eagerload', 'lazyload', 'noload', 'deferred', 'defer', 'undefer', 'extension',
- 'mapper', 'clear_mappers', 'compile_mappers', 'clear_mapper', 'class_mapper', 'object_mapper', 'MapperExtension', 'Query',
- 'cascade_mappers', 'polymorphic_union', 'create_session', 'synonym', 'contains_alias', 'contains_eager', 'EXT_PASS', 'object_session'
- ]
+__all__ = ['relation', 'column_property', 'composite', 'backref', 'eagerload',
+ 'eagerload_all', 'lazyload', 'noload', 'deferred', 'defer', 'undefer',
+ 'undefer_group', 'extension', 'mapper', 'clear_mappers',
+ 'compile_mappers', 'class_mapper', 'object_mapper',
+ 'MapperExtension', 'Query', 'polymorphic_union', 'create_session',
+ 'synonym', 'contains_alias', 'contains_eager', 'EXT_PASS',
+ 'object_session', 'PropComparator'
+ ]
-def relation(*args, **kwargs):
+def relation(argument, secondary=None, **kwargs):
"""Provide a relationship of a primary Mapper to a secondary Mapper.
- This corresponds to a parent-child or associative table relationship.
+ This corresponds to a parent-child or associative table relationship.
+ The constructed class is an instance of [sqlalchemy.orm.properties#PropertyLoader].
+
+ argument
+ a class or Mapper instance, representing the target of the relation.
+
+ secondary
+ for a many-to-many relationship, specifies the intermediary table. The
+ ``secondary`` keyword argument should generally only be used for a table
+ that is not otherwise expressed in any class mapping. In particular,
+ using the Association Object Pattern is
+ generally mutually exclusive against using the ``secondary`` keyword
+ argument.
+
+ \**kwargs follow:
+
+ association
+ Deprecated; as of version 0.3.0 the association keyword is synonymous
+ with applying the "all, delete-orphan" cascade to a "one-to-many"
+ relationship. SA can now automatically reconcile a "delete" and
+ "insert" operation of two objects with the same "identity" in a flush()
+ operation into a single "update" statement, which is the pattern that
+ "association" used to indicate.
+
+ backref
+ indicates the name of a property to be placed on the related mapper's
+ class that will handle this relationship in the other direction,
+ including synchronizing the object attributes on both sides of the
+ relation. Can also point to a ``backref()`` construct for more
+ configurability.
+
+ cascade
+ a string list of cascade rules which determines how persistence
+ operations should be "cascaded" from parent to child.
+
+ collection_class
+ a class or function that returns a new list-holding object, which will
+ be used in place of a plain list for storing elements.
+
+ foreign_keys
+ a list of columns which are to be used as "foreign key" columns.
+ this parameter should be used in conjunction with explicit
+ ``primaryjoin`` and ``secondaryjoin`` (if needed) arguments, and the
+ columns within the ``foreign_keys`` list should be present within
+ those join conditions. Normally, ``relation()`` will inspect the
+ columns within the join conditions to determine which columns are
+ the "foreign key" columns, based on information in the ``Table``
+ metadata. Use this argument when no ``ForeignKey`` objects are present in the
+ join condition, or to override the table-defined foreign keys.
+
+ foreignkey
+ deprecated. use the ``foreign_keys`` argument for foreign key
+ specification, or ``remote_side`` for "directional" logic.
+
+ lazy=True
+ specifies how the related items should be loaded. a value of True
+ indicates they should be loaded lazily when the property is first
+ accessed. A value of False indicates they should be loaded by joining
+ against the parent object query, so parent and child are loaded in one
+ round trip (i.e. eagerly). A value of None indicates the related items
+ are not loaded by the mapper in any case; the application will manually
+ insert items into the list in some other way. In all cases, items added
+ or removed to the parent object's collection (or scalar attribute) will
+ cause the appropriate updates and deletes upon flush(), i.e. this
+ option only affects load operations, not save operations.
+
+ order_by
+ indicates the ordering that should be applied when loading these items.
+
+ passive_deletes=False
+ When True, indicates that lazy-loaders should not be executed during
+ the ``flush()`` process; such loads normally occur in order to locate
+ all existing child items when a parent item is to be deleted. Setting this flag to True is
+ appropriate when ``ON DELETE CASCADE`` rules have been set up on the
+ actual tables so that the database may handle cascading deletes
+ automatically. This strategy is useful particularly for handling the
+ deletion of objects that have very large (and/or deep) child-object
+ collections.
+
+ post_update
+ this indicates that the relationship should be handled by a second
+ UPDATE statement after an INSERT or before a DELETE. Currently, it also
+ will issue an UPDATE after the instance was UPDATEd as well, although
+ this technically should be improved. This flag is used to handle saving
+ bi-directional dependencies between two individual rows (i.e. each row
+ references the other), where it would otherwise be impossible to INSERT
+ or DELETE both rows fully since one row exists before the other. Use
+ this flag when a particular mapping arrangement will incur two rows
+ that are dependent on each other, such as a table that has a
+ one-to-many relationship to a set of child rows, and also has a column
+ that references a single child row within that list (i.e. both tables
+ contain a foreign key to each other). If a ``flush()`` operation returns
+ an error that a "cyclical dependency" was detected, this is a cue that
+ you might want to use ``post_update`` to "break" the cycle.
+
+ primaryjoin
+ a ClauseElement that will be used as the primary join of this child
+ object against the parent object, or in a many-to-many relationship the
+ join of the primary object to the association table. By default, this
+ value is computed based on the foreign key relationships of the parent
+ and child tables (or association table).
+
+ private=False
+ deprecated. setting ``private=True`` is the equivalent of setting
+ ``cascade="all, delete-orphan"``, and indicates the lifecycle of child
+ objects should be contained within that of the parent.
+
+ remote_side
+ used for self-referential relationships, indicates the column or list
+ of columns that form the "remote side" of the relationship.
+
+ secondaryjoin
+ a ClauseElement that will be used as the join of an association table
+ to the child object. By default, this value is computed based on the
+ foreign key relationships of the association and child tables.
+
+ uselist=(True|False)
+ a boolean that indicates if this property should be loaded as a list or
+ a scalar. In most cases, this value is determined automatically by
+ ``relation()``, based on the type and direction of the relationship - one
+ to many forms a list, many to one forms a scalar, many to many is a
+ list. If a scalar is desired where normally a list would be present,
+ such as a bi-directional one-to-one relationship, set uselist to False.
+
+ viewonly=False
+ when set to True, the relation is used only for loading objects within
+ the relationship, and has no effect on the unit-of-work flush process.
+ Relations with viewonly can specify any kind of join conditions to
+ provide additional views of related objects onto a parent object. Note
+ that the functionality of a viewonly relationship has its limits -
+ complicated join conditions may not compile into eager or lazy loaders
+ properly. If this is the case, use an alternative method.
+
"""
- if len(args) > 1 and isinstance(args[0], type):
- raise exceptions.ArgumentError("relation(class, table, **kwargs) is deprecated. Please use relation(class, **kwargs) or relation(mapper, **kwargs).")
- return _relation_loader(*args, **kwargs)
+ return PropertyLoader(argument, secondary=secondary, **kwargs)
+
+# return _relation_loader(argument, secondary=secondary, **kwargs)
+
+#def _relation_loader(mapper, secondary=None, primaryjoin=None, secondaryjoin=None, lazy=True, **kwargs):
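# Illustrative sketch (not part of this changeset): minimal relation() usage
# per the docstring above. The tables and classes are assumptions.
from sqlalchemy import MetaData, Table, Column, Integer, String, ForeignKey
from sqlalchemy.orm import mapper, relation

metadata = MetaData()
users = Table('users', metadata,
    Column('id', Integer, primary_key=True),
    Column('name', String(50)))
addresses = Table('addresses', metadata,
    Column('id', Integer, primary_key=True),
    Column('user_id', Integer, ForeignKey('users.id')),
    Column('email', String(100)))

class User(object): pass
class Address(object): pass

mapper(Address, addresses)
mapper(User, users, properties={
    # one-to-many; the backref adds a scalar 'user' attribute to Address
    'addresses': relation(Address, backref='user',
                          cascade="all, delete-orphan")
})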
def column_property(*args, **kwargs):
"""Provide a column-level property for use with a Mapper.
+
+ Column-based properties can normally be applied to the mapper's
+ ``properties`` dictionary using the ``schema.Column`` element directly.
+ Use this function when the given column is not directly present within
+ the mapper's selectable; examples include SQL expressions, functions,
+ and scalar SELECT queries.
+
+ Columns that aren't present in the mapper's selectable won't be persisted
+ by the mapper and are effectively "read-only" attributes.
+
+ \*cols
+ list of Column objects to be mapped.
- Normally, custom column-level properties that represent columns
- directly or indirectly present within the mapped selectable
- can just be added to the ``properties`` dictionary directly,
- in which case this function's usage is not necessary.
-
- In the case of a ``ColumnElement`` directly present within the
- ``properties`` dictionary, the given column is converted to be the exact column
- located within the mapped selectable, in the case that the mapped selectable
- is not the exact parent selectable of the given column, but shares a common
- base table relationship with that column.
+ group
+ a group name for this property when marked as deferred.
+
+ deferred
+ when True, the column property is "deferred", meaning that
+ it does not load immediately, and is instead loaded when the
+ attribute is first accessed on an instance. See also
+ [sqlalchemy.orm#deferred()].
+
+ """
- Use this function when the column expression being added does not
- correspond to any single column within the mapped selectable,
- such as a labeled function or scalar-returning subquery, to force the element
- to become a mapped property regardless of it not being present within the
- mapped selectable.
+ return ColumnProperty(*args, **kwargs)
+
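# Illustrative sketch (not part of this changeset): a read-only attribute
# computed from a scalar SELECT, reusing the users/addresses tables assumed
# in the relation() sketch above.
from sqlalchemy import select, func
from sqlalchemy.orm import column_property

mapper(User, users, properties={
    # absent from the mapped table, hence never persisted by the mapper
    'address_count': column_property(
        select([func.count(addresses.c.id)],
               addresses.c.user_id == users.c.id).label('address_count'))
})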
+def composite(class_, *cols, **kwargs):
+ """Return a composite column-based property for use with a Mapper.
+
+ This is very much like a column-based property except the given class
+ is used to construct values composed of one or more columns. The class must
+ implement a constructor with positional arguments matching the order of
+ columns given, as well as a __colset__() method which returns its attributes
+ in column order.
- Note that persistence of instances is driven from the collection of columns
- within the mapped selectable, so column properties attached to a Mapper which have
- no direct correspondence to the mapped selectable will effectively be non-persisted
- attributes.
+ class\_
+ the "composite type" class.
+
+ \*cols
+ list of Column objects to be mapped.
+
+ group
+ a group name for this property when marked as deferred.
+
+ deferred
+ when True, the column property is "deferred", meaning that
+ it does not load immediately, and is instead loaded when the
+ attribute is first accessed on an instance. See also
+ [sqlalchemy.orm#deferred()].
+
+ comparator
+ an optional instance of [sqlalchemy.orm#PropComparator] which
+ provides SQL expression generation functions for this composite
+ type.
"""
- return properties.ColumnProperty(*args, **kwargs)
-def _relation_loader(mapper, secondary=None, primaryjoin=None, secondaryjoin=None, lazy=True, **kwargs):
- return properties.PropertyLoader(mapper, secondary, primaryjoin, secondaryjoin, lazy=lazy, **kwargs)
+ return CompositeProperty(class_, *cols, **kwargs)
+
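# Illustrative sketch (not part of this changeset): a composite type
# following the contract documented above, namely a positional constructor
# plus __colset__() returning attributes in column order. Vertex and the
# vertices table are assumptions.
class Point(object):
    def __init__(self, x, y):
        self.x = x
        self.y = y
    def __colset__(self):
        # attribute values in the same order as the mapped columns
        return [self.x, self.y]

mapper(Vertex, vertices, properties={
    'start': composite(Point, vertices.c.x1, vertices.c.y1),
    'end': composite(Point, vertices.c.x2, vertices.c.y2),
})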
def backref(name, **kwargs):
"""Create a BackRef object with explicit arguments, which are the same arguments one
@@ -72,7 +243,7 @@ def backref(name, **kwargs):
place of a string argument.
"""
- return properties.BackRef(name, **kwargs)
+ return BackRef(name, **kwargs)
def deferred(*columns, **kwargs):
"""Return a ``DeferredColumnProperty``, which indicates this
@@ -82,15 +253,141 @@ def deferred(*columns, **kwargs):
Used with the `properties` dictionary sent to ``mapper()``.
"""
- return properties.ColumnProperty(deferred=True, *columns, **kwargs)
-
-def mapper(class_, table=None, *args, **params):
- """Return a new ``Mapper`` object.
-
- See the ``Mapper`` class for a description of arguments.
+ return ColumnProperty(deferred=True, *columns, **kwargs)
+
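# Illustrative sketch (not part of this changeset): deferring a large column
# under a group name, usable with undefer_group() later in this module.
# Document and the documents table are assumptions.
mapper(Document, documents, properties={
    # loaded only when first accessed on an instance
    'body': deferred(documents.c.body, group='content'),
})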
+def mapper(class_, local_table=None, *args, **params):
+ """Return a new [sqlalchemy.orm#Mapper] object.
+
+ class\_
+ The class to be mapped.
+
+ local_table
+ The table to which the class is mapped, or None if this
+ mapper inherits from another mapper using concrete table
+ inheritance.
+
+ entity_name
+ A name to be associated with the `class`, to allow alternate
+ mappings for a single class.
+
+ always_refresh
+ If True, all query operations for this mapped class will
+ overwrite all data within object instances that already
+ exist within the session, erasing any in-memory changes with
+ whatever information was loaded from the database. Usage
+ of this flag is highly discouraged; as an alternative,
+ see the method `populate_existing()` on [sqlalchemy.orm.query#Query].
+
+ allow_column_override
+ If True, allows the usage of a ``relation()`` which has the
+ same name as a column in the mapped table. The table column
+ will no longer be mapped.
+
+ allow_null_pks
+ Indicates that composite primary keys where one or more (but
+ not all) columns contain NULL is a valid primary key.
+ Primary keys which contain NULL values usually indicate that
+ a result row does not contain an entity and should be
+ skipped.
+
+ batch
+ Indicates that save operations of multiple entities can be
+ batched together for efficiency. Setting this to False indicates
+ that an instance will be fully saved before saving the next
+ instance, which includes inserting/updating all table rows
+ corresponding to the entity as well as calling all
+ ``MapperExtension`` methods corresponding to the save
+ operation.
+
+ column_prefix
+ A string which will be prepended to the `key` name of all
+ Columns when creating column-based properties from the given
+ Table. Does not affect explicitly specified column-based
+ properties.
+
+ concrete
+ If True, indicates this mapper should use concrete table
+ inheritance with its parent mapper.
+
+ extension
+ A [sqlalchemy.orm#MapperExtension] instance or list of
+ ``MapperExtension`` instances which will be applied to all
+ operations by this ``Mapper``.
+
+ inherits
+ Another ``Mapper`` from which this ``Mapper`` will
+ inherit.
+
+ inherit_condition
+ For joined table inheritance, a SQL expression (constructed
+ ``ClauseElement``) which will define how the two tables are
+ joined; defaults to a natural join between the two tables.
+
+ order_by
+ A single ``Column`` or list of ``Columns`` which
+ selection operations should use as the default ordering for
+ entities. Defaults to the OID/ROWID of the table if any, or
+ the first primary key column of the table.
+
+ non_primary
+ Construct a ``Mapper`` that will define only the selection
+ of instances, not their persistence. Any number of non_primary
+ mappers may be created for a particular class.
+
+ polymorphic_on
+ Used with mappers in an inheritance relationship, a ``Column``
+ which will identify the class/mapper combination to be used
+ with a particular row. Requires the polymorphic_identity
+ value to be set for all mappers in the inheritance
+ hierarchy.
+
+ _polymorphic_map
+ Used internally to propagate the full map of polymorphic
+ identifiers to surrogate mappers.
+
+ polymorphic_identity
+ A value which will be stored in the Column denoted by
+ polymorphic_on, corresponding to the *class identity* of
+ this mapper.
+
+ polymorphic_fetch
+ Specifies how subclasses mapped through joined-table
+ inheritance will be fetched. Options are 'union',
+ 'select', and 'deferred'. If the select_table argument
+ is present, defaults to 'union'; otherwise defaults to
+ 'select'.
+
+ properties
+ A dictionary mapping the string names of object attributes
+ to ``MapperProperty`` instances, which define the
+ persistence behavior of that attribute. Note that the
+ columns in the mapped table are automatically converted into
+ ``ColumnProperty`` instances based on the `key` property of
+ each ``Column`` (although they can be overridden using this
+ dictionary).
+
+ primary_key
+ A list of ``Column`` objects which define the *primary key*
+ to be used against this mapper's selectable unit. This is
+ normally simply the primary key of the `local_table`, but
+ can be overridden here.
+
+ select_table
+ A [sqlalchemy.schema#Table] or any [sqlalchemy.sql#Selectable]
+ which will be used to select instances of this mapper's class.
+ Usually used to provide polymorphic loading among several
+ classes in an inheritance hierarchy.
+
+ version_id_col
+ A ``Column`` which must have an integer type that will be
+ used to keep a running *version id* of mapped entities in
+ the database. This is used during save operations to ensure
+ that no other thread or process has updated the instance
+ during the lifetime of the entity; otherwise a
+ ``ConcurrentModificationError`` exception is raised.
"""
- return Mapper(class_, table, *args, **params)
+ return Mapper(class_, local_table, *args, **params)
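# Illustrative sketch (not part of this changeset): joined-table inheritance
# combining the inherits, polymorphic_on, and polymorphic_identity arguments.
# Employee/Manager and their tables are assumptions.
mapper(Employee, employees,
       polymorphic_on=employees.c.type,
       polymorphic_identity='employee')
mapper(Manager, managers,
       inherits=Employee,
       polymorphic_identity='manager')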
def synonym(name, proxy=False):
"""Set up `name` as a synonym to another ``MapperProperty``.
@@ -98,7 +395,7 @@ def synonym(name, proxy=False):
Used with the `properties` dictionary sent to ``mapper()``.
"""
- return interfaces.SynonymProperty(name, proxy=proxy)
+ return SynonymProperty(name, proxy=proxy)
def compile_mappers():
"""Compile all mappers that have been defined.
@@ -120,32 +417,13 @@ def clear_mappers():
mapperlib._COMPILE_MUTEX.acquire()
try:
for mapper in mapper_registry.values():
- attribute_manager.reset_class_managed(mapper.class_)
- if hasattr(mapper.class_, 'c'):
- del mapper.class_.c
+ mapper.dispose()
mapper_registry.clear()
# TODO: either dont use ArgSingleton, or
# find a way to clear only ClassKey instances from it
sautil.ArgSingleton.instances.clear()
finally:
mapperlib._COMPILE_MUTEX.release()
-
-def clear_mapper(m):
- """Remove the given mapper from the storage of mappers.
-
- When a new mapper is created for the previous mapper's class, it
- will be used as that class's new primary mapper.
- """
-
- mapperlib._COMPILE_MUTEX.acquire()
- try:
- del mapper_registry[m.class_key]
- attribute_manager.reset_class_managed(m.class_)
- if hasattr(m.class_, 'c'):
- del m.class_.c
- m.class_key.dispose()
- finally:
- mapperlib._COMPILE_MUTEX.release()
def extension(ext):
"""Return a ``MapperOption`` that will insert the given
@@ -166,6 +444,22 @@ def eagerload(name):
return strategies.EagerLazyOption(name, lazy=False)
+def eagerload_all(name):
+ """Return a ``MapperOption`` that will convert all
+ properties along the given dot-separated path into an
+ eager load.
+
+ e.g::
+ query.options(eagerload_all('orders.items.keywords'))...
+
+ will set all of 'orders', 'orders.items', and 'orders.items.keywords'
+ to load in one eager load.
+
+ Used with ``query.options()``.
+ """
+
+ return strategies.EagerLazyOption(name, lazy=False, chained=True)
+
def lazyload(name):
"""Return a ``MapperOption`` that will convert the property of the
given name into a lazy load.
@@ -175,6 +469,9 @@ def lazyload(name):
return strategies.EagerLazyOption(name, lazy=True)
+def fetchmode(name, type):
+ return strategies.FetchModeOption(name, type)
+
def noload(name):
"""Return a ``MapperOption`` that will convert the property of the
given name into a non-load.
@@ -250,64 +547,11 @@ def undefer(name):
return strategies.DeferredOption(name, defer=False)
+def undefer_group(name):
+ """Return a ``MapperOption`` that will convert the given
+ group of deferred column properties into a non-deferred (regular column) load.
-def cascade_mappers(*classes_or_mappers):
- """Attempt to create a series of ``relations()`` between mappers
- automatically, via introspecting the foreign key relationships of
- the underlying tables.
-
- Given a list of classes and/or mappers, identify the foreign key
- relationships between the given mappers or corresponding class
- mappers, and create ``relation()`` objects representing those
- relationships, including a backreference. Attempt to find the
- *secondary* table in a many-to-many relationship as well.
-
- The names of the relations will be a lowercase version of the
- related class. In the case of one-to-many or many-to-many, the
- name will be *pluralized*, which currently is based on the English
- language (i.e. an 's' or 'es' added to it).
-
- NOTE: this method usually works poorly, and its usage is generally
- not advised.
+ Used with ``query.options()``.
"""
-
- table_to_mapper = {}
- for item in classes_or_mappers:
- if isinstance(item, Mapper):
- m = item
- else:
- klass = item
- m = class_mapper(klass)
- table_to_mapper[m.mapped_table] = m
-
- def pluralize(name):
- # oh crap, do we need locale stuff now
- if name[-1] == 's':
- return name + "es"
- else:
- return name + "s"
-
- for table,mapper in table_to_mapper.iteritems():
- for fk in table.foreign_keys:
- if fk.column.table is table:
- continue
- secondary = None
- try:
- m2 = table_to_mapper[fk.column.table]
- except KeyError:
- if len(fk.column.table.primary_key):
- continue
- for sfk in fk.column.table.foreign_keys:
- if sfk.column.table is table:
- continue
- m2 = table_to_mapper.get(sfk.column.table)
- secondary = fk.column.table
- if m2 is None:
- continue
- if secondary:
- propname = pluralize(m2.class_.__name__.lower())
- propname2 = pluralize(mapper.class_.__name__.lower())
- else:
- propname = m2.class_.__name__.lower()
- propname2 = pluralize(mapper.class_.__name__.lower())
- mapper.add_property(propname, relation(m2, secondary=secondary, backref=propname2))
+ return strategies.UndeferGroupOption(name)
+
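# Illustrative sketch (not part of this changeset): undeferring the 'content'
# group from the deferred() sketch above in one query; session is assumed.
docs = session.query(Document).options(undefer_group('content')).all()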
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 9f8a04db8..47ff26085 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -4,37 +4,67 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import util
-from sqlalchemy.orm import util as orm_util
-from sqlalchemy import logging, exceptions
import weakref
-class InstrumentedAttribute(object):
- """A property object that instruments attribute access on object instances.
+from sqlalchemy import util
+from sqlalchemy.orm import util as orm_util, interfaces, collections
+from sqlalchemy.orm.mapper import class_mapper
+from sqlalchemy import logging, exceptions
- All methods correspond to a single attribute on a particular
- class.
- """
- PASSIVE_NORESULT = object()
+PASSIVE_NORESULT = object()
+ATTR_WAS_SET = object()
- def __init__(self, manager, key, uselist, callable_, typecallable, trackparent=False, extension=None, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs):
+class InstrumentedAttribute(interfaces.PropComparator):
+ """attribute access for instrumented classes."""
+
+ def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, compare_function=None, mutable_scalars=False, comparator=None, **kwargs):
+ """Construct an InstrumentedAttribute.
+
+ class_
+ the class to be instrumented.
+
+ manager
+ AttributeManager managing this class
+
+ key
+ string name of the attribute
+
+ callable_
+ optional function which generates a callable based on a parent
+ instance, which produces the "default" values for a scalar or
+ collection attribute when it's first accessed, if not present already.
+
+ trackparent
+ if True, attempt to track if an instance has a parent attached to it
+ via this attribute
+
+ extension
+ an AttributeExtension object which will receive
+ set/delete/append/remove/etc. events
+
+ compare_function
+ a function that compares two values which are normally assignable to this
+ attribute
+
+ mutable_scalars
+ if True, the values which are normally assignable to this attribute can mutate,
+ and need to be compared against a copy of their original contents in order to
+ detect changes on the parent instance
+
+ comparator
+ a sql.Comparator to which class-level compare/math events will be sent
+
+ """
+
+ self.class_ = class_
self.manager = manager
self.key = key
- self.uselist = uselist
self.callable_ = callable_
- self.typecallable= typecallable
self.trackparent = trackparent
self.mutable_scalars = mutable_scalars
- if copy_function is None:
- if uselist:
- self.copy = lambda x:[y for y in x]
- else:
- # scalar values are assumed to be immutable unless a copy function
- # is passed
- self.copy = lambda x:x
- else:
- self.copy = lambda x:copy_function(x)
+ self.comparator = comparator
+ self.copy = None
if compare_function is None:
self.is_equal = lambda x,y: x == y
else:
@@ -42,7 +72,7 @@ class InstrumentedAttribute(object):
self.extensions = util.to_list(extension or [])
def __set__(self, obj, value):
- self.set(None, obj, value)
+ self.set(obj, value, None)
def __delete__(self, obj):
self.delete(None, obj)
@@ -52,17 +82,18 @@ class InstrumentedAttribute(object):
return self
return self.get(obj)
- def check_mutable_modified(self, obj):
- if self.mutable_scalars:
- h = self.get_history(obj, passive=True)
- if h is not None and h.is_modified():
- obj._state['modified'] = True
- return True
- else:
- return False
- else:
- return False
+ def clause_element(self):
+ return self.comparator.clause_element()
+
+ def expression_element(self):
+ return self.comparator.expression_element()
+
+ def operate(self, op, other, **kwargs):
+ return op(self.comparator, other, **kwargs)
+ def reverse_operate(self, op, other, **kwargs):
+ return op(other, self.comparator, **kwargs)
+
def hasparent(self, item, optimistic=False):
"""Return the boolean value of a `hasparent` flag attached to the given item.
@@ -98,8 +129,8 @@ class InstrumentedAttribute(object):
# get the current state. this may trigger a lazy load if
# passive is False.
- current = self.get(obj, passive=passive, raiseerr=False)
- if current is InstrumentedAttribute.PASSIVE_NORESULT:
+ current = self.get(obj, passive=passive)
+ if current is PASSIVE_NORESULT:
return None
return AttributeHistory(self, obj, current, passive=passive)
@@ -123,6 +154,14 @@ class InstrumentedAttribute(object):
else:
obj._state[('callable', self)] = callable_
+ def _get_callable(self, obj):
+ if ('callable', self) in obj._state:
+ return obj._state[('callable', self)]
+ elif self.callable_ is not None:
+ return self.callable_(obj)
+ else:
+ return None
+
def reset(self, obj):
"""Remove any per-instance callable functions corresponding to
this ``InstrumentedAttribute``'s attribute from the given
@@ -148,43 +187,21 @@ class InstrumentedAttribute(object):
except KeyError:
pass
- def _get_callable(self, obj):
- if obj._state.has_key(('callable', self)):
- return obj._state[('callable', self)]
- elif self.callable_ is not None:
- return self.callable_(obj)
- else:
- return None
-
- def _blank_list(self):
- if self.typecallable is not None:
- return self.typecallable()
- else:
- return []
+ def check_mutable_modified(self, obj):
+ return False
def initialize(self, obj):
- """Initialize this attribute on the given object instance.
+ """Initialize this attribute on the given object instance with an empty value."""
- If this is a list-based attribute, a new, blank list will be
- created. if a scalar attribute, the value will be initialized
- to None.
- """
-
- if self.uselist:
- l = InstrumentedList(self, obj, self._blank_list())
- obj.__dict__[self.key] = l
- return l
- else:
- obj.__dict__[self.key] = None
- return None
+ obj.__dict__[self.key] = None
+ return None
- def get(self, obj, passive=False, raiseerr=True):
+ def get(self, obj, passive=False):
"""Retrieve a value from the given object.
If a callable is assembled on this object's attribute, and
passive is False, the callable will be executed and the
- resulting value will be set as the new value for this
- attribute.
+ resulting value will be set as the new value for this attribute.
"""
try:
@@ -193,441 +210,301 @@ class InstrumentedAttribute(object):
state = obj._state
# if an instance-wide "trigger" was set, call that
# and start again
- if state.has_key('trigger'):
+ if 'trigger' in state:
trig = state['trigger']
del state['trigger']
trig()
- return self.get(obj, passive=passive, raiseerr=raiseerr)
-
- if self.uselist:
- callable_ = self._get_callable(obj)
- if callable_ is not None:
- if passive:
- return InstrumentedAttribute.PASSIVE_NORESULT
- self.logger.debug("Executing lazy callable on %s.%s" % (orm_util.instance_str(obj), self.key))
- values = callable_()
- l = InstrumentedList(self, obj, values, init=False)
-
- # if a callable was executed, then its part of the "committed state"
- # if any, so commit the newly loaded data
- orig = state.get('original', None)
- if orig is not None:
- orig.commit_attribute(self, obj, l)
-
+ return self.get(obj, passive=passive)
+
+ callable_ = self._get_callable(obj)
+ if callable_ is not None:
+ if passive:
+ return PASSIVE_NORESULT
+ self.logger.debug("Executing lazy callable on %s.%s" %
+ (orm_util.instance_str(obj), self.key))
+ value = callable_()
+ if value is not ATTR_WAS_SET:
+ return self.set_committed_value(obj, value)
else:
- # note that we arent raising AttributeErrors, just creating a new
- # blank list and setting it.
- # this might be a good thing to be changeable by options.
- l = InstrumentedList(self, obj, self._blank_list(), init=False)
- obj.__dict__[self.key] = l
- return l
- else:
- callable_ = self._get_callable(obj)
- if callable_ is not None:
- if passive:
- return InstrumentedAttribute.PASSIVE_NORESULT
- self.logger.debug("Executing lazy callable on %s.%s" % (orm_util.instance_str(obj), self.key))
- value = callable_()
- obj.__dict__[self.key] = value
-
- # if a callable was executed, then its part of the "committed state"
- # if any, so commit the newly loaded data
- orig = state.get('original', None)
- if orig is not None:
- orig.commit_attribute(self, obj)
- return value
- else:
- # note that we arent raising AttributeErrors, just returning None.
- # this might be a good thing to be changeable by options.
- return None
-
- def set(self, event, obj, value):
- """Set a value on the given object.
-
- `event` is the ``InstrumentedAttribute`` that initiated the
- ``set()` operation and is used to control the depth of a
- circular setter operation.
- """
-
- if event is not self:
- state = obj._state
- # if an instance-wide "trigger" was set, call that
- if state.has_key('trigger'):
- trig = state['trigger']
- del state['trigger']
- trig()
- if self.uselist:
- value = InstrumentedList(self, obj, value)
- old = self.get(obj)
- obj.__dict__[self.key] = value
- state['modified'] = True
- if not self.uselist:
- if self.trackparent:
- if value is not None:
- self.sethasparent(value, True)
- if old is not None:
- self.sethasparent(old, False)
- for ext in self.extensions:
- ext.set(event or self, obj, value, old)
+ return obj.__dict__[self.key]
else:
- # mark all the old elements as detached from the parent
- old.list_replaced()
+ # Return a new, empty value
+ return self.initialize(obj)
- def delete(self, event, obj):
- """Delete a value from the given object.
+ def append(self, obj, value, initiator):
+ self.set(obj, value, initiator)
- `event` is the ``InstrumentedAttribute`` that initiated the
- ``delete()`` operation and is used to control the depth of a
- circular delete operation.
- """
-
- if event is not self:
- try:
- if not self.uselist and (self.trackparent or len(self.extensions)):
- old = self.get(obj)
- del obj.__dict__[self.key]
- except KeyError:
- # TODO: raise this? not consistent with get() ?
- raise AttributeError(self.key)
- obj._state['modified'] = True
- if not self.uselist:
- if self.trackparent:
- if old is not None:
- self.sethasparent(old, False)
- for ext in self.extensions:
- ext.delete(event or self, obj, old)
-
- def append(self, event, obj, value):
- """Append an element to a list based element or sets a scalar
- based element to the given value.
-
- Used by ``GenericBackrefExtension`` to *append* an item
- independent of list/scalar semantics.
-
- `event` is the ``InstrumentedAttribute`` that initiated the
- ``append()`` operation and is used to control the depth of a
- circular append operation.
- """
+ def remove(self, obj, value, initiator):
+ self.set(obj, None, initiator)
- if self.uselist:
- if event is not self:
- self.get(obj).append_with_event(value, event)
- else:
- self.set(event, obj, value)
-
- def remove(self, event, obj, value):
- """Remove an element from a list based element or sets a
- scalar based element to None.
-
- Used by ``GenericBackrefExtension`` to *remove* an item
- independent of list/scalar semantics.
+ def set(self, obj, value, initiator):
+ raise NotImplementedError()
- `event` is the ``InstrumentedAttribute`` that initiated the
- ``remove()`` operation and is used to control the depth of a
- circular remove operation.
+ def set_committed_value(self, obj, value):
+ """set an attribute value on the given instance and 'commit' it.
+
+ this indicates that the given value is the "persisted" value,
+ and history will be logged only if a newly set value is not
+ equal to this value.
+
+ this is typically used by deferred/lazy attribute loaders
+ to set object attributes after the initial load.
"""
- if self.uselist:
- if event is not self:
- self.get(obj).remove_with_event(value, event)
- else:
- self.set(event, obj, None)
+ state = obj._state
+ orig = state.get('original', None)
+ if orig is not None:
+ orig.commit_attribute(self, obj, value)
+ # remove per-instance callable, if any
+ state.pop(('callable', self), None)
+ obj.__dict__[self.key] = value
+ return value
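# Illustrative sketch (not part of this changeset): a deferred or lazy loader
# can record a freshly fetched value as persisted state, so no history event
# or 'modified' flag results. The attribute and instance names are assumed.
attr = User.name                      # class-level InstrumentedAttribute
attr.set_committed_value(some_user, value_from_db)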
- def append_event(self, event, obj, value):
- """Called by ``InstrumentedList`` when an item is appended."""
+ def set_raw_value(self, obj, value):
+ obj.__dict__[self.key] = value
+ return value
+ def fire_append_event(self, obj, value, initiator):
obj._state['modified'] = True
if self.trackparent and value is not None:
self.sethasparent(value, True)
for ext in self.extensions:
- ext.append(event or self, obj, value)
-
- def remove_event(self, event, obj, value):
- """Called by ``InstrumentedList`` when an item is removed."""
+ ext.append(obj, value, initiator or self)
+ def fire_remove_event(self, obj, value, initiator):
obj._state['modified'] = True
if self.trackparent and value is not None:
self.sethasparent(value, False)
for ext in self.extensions:
- ext.delete(event or self, obj, value)
+ ext.remove(obj, value, initiator or self)
+
+ def fire_replace_event(self, obj, value, previous, initiator):
+ obj._state['modified'] = True
+ if self.trackparent:
+ if value is not None:
+ self.sethasparent(value, True)
+ if previous is not None:
+ self.sethasparent(previous, False)
+ for ext in self.extensions:
+ ext.set(obj, value, previous, initiator or self)
+
+ property = property(lambda s: class_mapper(s.class_).get_property(s.key),
+ doc="the MapperProperty object associated with this attribute")
InstrumentedAttribute.logger = logging.class_logger(InstrumentedAttribute)
+
+class InstrumentedScalarAttribute(InstrumentedAttribute):
+ """represents a scalar-holding InstrumentedAttribute."""
-class InstrumentedList(object):
- """Instrument a list-based attribute.
-
- All mutator operations (i.e. append, remove, etc.) will fire off
- events to the ``InstrumentedAttribute`` that manages the object's
- attribute. Those events in turn trigger things like backref
- operations and whatever is implemented by
- ``do_list_value_changed`` on ``InstrumentedAttribute``.
-
- Note that this list does a lot less than earlier versions of SA
- list-based attributes, which used ``HistoryArraySet``. This list
- wrapper does **not** maintain setlike semantics, meaning you can add
- as many duplicates as you want (which can break a lot of SQL), and
- also does not do anything related to history tracking.
-
- Please see ticket #213 for information on the future of this
- class, where it will be broken out into more collection-specific
- subtypes.
- """
+ def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs):
+ super(InstrumentedScalarAttribute, self).__init__(class_, manager, key,
+ callable_, trackparent=trackparent, extension=extension,
+ compare_function=compare_function, **kwargs)
+ self.mutable_scalars = mutable_scalars
- def __init__(self, attr, obj, data, init=True):
- self.attr = attr
- # this weakref is to prevent circular references between the parent object
- # and the list attribute, which interferes with immediate garbage collection.
- self.__obj = weakref.ref(obj)
- self.key = attr.key
-
- # adapt to lists or sets
- # TODO: make three subclasses of InstrumentedList that come off from a
- # metaclass, based on the type of data sent in
- if attr.typecallable is not None:
- self.data = attr.typecallable()
- else:
- self.data = data or attr._blank_list()
-
- if isinstance(self.data, list):
- self._data_appender = self.data.append
- self._clear_data = self._clear_list
- elif isinstance(self.data, util.Set):
- self._data_appender = self.data.add
- self._clear_data = self._clear_set
- elif isinstance(self.data, dict):
- if hasattr(self.data, 'append'):
- self._data_appender = self.data.append
- else:
- raise exceptions.ArgumentError("Dictionary collection class '%s' must implement an append() method" % type(self.data).__name__)
- self._clear_data = self._clear_dict
- else:
- if hasattr(self.data, 'append'):
- self._data_appender = self.data.append
- elif hasattr(self.data, 'add'):
- self._data_appender = self.data.add
- else:
- raise exceptions.ArgumentError("Collection class '%s' is not of type 'list', 'set', or 'dict' and has no append() or add() method" % type(self.data).__name__)
+ if copy_function is None:
+ copy_function = self.__copy
+ self.copy = copy_function
- if hasattr(self.data, 'clear'):
- self._clear_data = self._clear_set
- else:
- raise exceptions.ArgumentError("Collection class '%s' is not of type 'list', 'set', or 'dict' and has no clear() method" % type(self.data).__name__)
-
- if data is not None and data is not self.data:
- for elem in data:
- self._data_appender(elem)
-
+ def __copy(self, item):
+ # scalar values are assumed to be immutable unless a copy function
+ # is passed
+ return item
- if init:
- for x in self.data:
- self.__setrecord(x)
+ def __delete__(self, obj):
+ old = self.get(obj)
+ del obj.__dict__[self.key]
+ self.fire_remove_event(obj, old, self)
- def list_replaced(self):
- """Fire off delete event handlers for each item in the list
- but doesn't affect the original data list.
- """
+ def check_mutable_modified(self, obj):
+ if self.mutable_scalars:
+ h = self.get_history(obj, passive=True)
+ if h is not None and h.is_modified():
+ obj._state['modified'] = True
+ return True
+ else:
+ return False
+ else:
+ return False
- [self.__delrecord(x) for x in self.data]
+ def set(self, obj, value, initiator):
+ """Set a value on the given object.
- def clear(self):
- """Clear all items in this InstrumentedList and fires off
- delete event handlers for each item.
+ `initiator` is the ``InstrumentedAttribute`` that initiated the
+ ``set()`` operation and is used to control the depth of a circular
+ setter operation.
"""
- self._clear_data()
-
- def _clear_dict(self):
- [self.__delrecord(x) for x in self.data.values()]
- self.data.clear()
-
- def _clear_set(self):
- [self.__delrecord(x) for x in self.data]
- self.data.clear()
-
- def _clear_list(self):
- self[:] = []
-
- def __getstate__(self):
- """Implemented to allow pickling, since `__obj` is a weakref,
- also the ``InstrumentedAttribute`` has callables attached to
- it.
- """
+ if initiator is self:
+ return
- return {'key':self.key, 'obj':self.obj, 'data':self.data}
+ state = obj._state
+ # if an instance-wide "trigger" was set, call that
+ if 'trigger' in state:
+ trig = state['trigger']
+ del state['trigger']
+ trig()
- def __setstate__(self, d):
- """Implemented to allow pickling, since `__obj` is a weakref,
- also the ``InstrumentedAttribute`` has callables attached to it.
- """
+ old = self.get(obj)
+ obj.__dict__[self.key] = value
+ self.fire_replace_event(obj, value, old, initiator)
- self.key = d['key']
- self.__obj = weakref.ref(d['obj'])
- self.data = d['data']
- self.attr = getattr(d['obj'].__class__, self.key)
+ type = property(lambda self: self.property.columns[0].type)
- obj = property(lambda s:s.__obj())
+
+class InstrumentedCollectionAttribute(InstrumentedAttribute):
+ """A collection-holding attribute that instruments changes in membership.
- def unchanged_items(self):
- """Deprecated."""
+ InstrumentedCollectionAttribute holds an arbitrary, user-specified
+ container object (defaulting to a list) and brokers access to the
+ CollectionAdapter, a "view" onto that object that presents consistent
+ bag semantics to the ORM layer independent of the user data implementation.
+ """
+
+ def __init__(self, class_, manager, key, callable_, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
+ super(InstrumentedCollectionAttribute, self).__init__(class_, manager,
+ key, callable_, trackparent=trackparent, extension=extension,
+ compare_function=compare_function, **kwargs)
- return self.attr.get_history(self.obj).unchanged_items
+ if copy_function is None:
+ copy_function = self.__copy
+ self.copy = copy_function
- def added_items(self):
- """Deprecated."""
+ if typecallable is None:
+ typecallable = list
+ self.collection_factory = \
+ collections._prepare_instrumentation(typecallable)
+ self.collection_interface = \
+ util.duck_type_collection(self.collection_factory())
- return self.attr.get_history(self.obj).added_items
+ def __copy(self, item):
+ return [y for y in list(collections.collection_adapter(item))]
- def deleted_items(self):
- """Deprecated."""
+ def __set__(self, obj, value):
+ """Replace the current collection with a new one."""
- return self.attr.get_history(self.obj).deleted_items
+ setting_type = util.duck_type_collection(value)
- def __iter__(self):
- return iter(self.data)
+ if value is None or setting_type != self.collection_interface:
+ raise exceptions.ArgumentError(
+ "Incompatible collection type on assignment: %s is not %s-like" %
+ (type(value).__name__, self.collection_interface.__name__))
- def __repr__(self):
- return repr(self.data)
+ if hasattr(value, '_sa_adapter'):
+ self.set(obj, list(getattr(value, '_sa_adapter')), None)
+ elif setting_type == dict:
+ self.set(obj, value.values(), None)
+ else:
+ self.set(obj, value, None)
- def __getattr__(self, attr):
- """Proxy unknown methods and attributes to the underlying
- data array. This allows custom list classes to be used.
- """
+ def __delete__(self, obj):
+ if self.key not in obj.__dict__:
+ return
- return getattr(self.data, attr)
+ obj._state['modified'] = True
- def __setrecord(self, item, event=None):
- self.attr.append_event(event, self.obj, item)
- return True
+ collection = self._get_collection(obj)
+ collection.clear_with_event()
+ del obj.__dict__[self.key]
- def __delrecord(self, item, event=None):
- self.attr.remove_event(event, self.obj, item)
- return True
+ def initialize(self, obj):
+ """Initialize this attribute on the given object instance with an empty collection."""
- def append_with_event(self, item, event):
- self.__setrecord(item, event)
- self._data_appender(item)
+ _, user_data = self._build_collection(obj)
+ obj.__dict__[self.key] = user_data
+ return user_data
- def append_without_event(self, item):
- self._data_appender(item)
+ def append(self, obj, value, initiator):
+ if initiator is self:
+ return
+ collection = self._get_collection(obj)
+ collection.append_with_event(value, initiator)
- def remove_with_event(self, item, event):
- self.__delrecord(item, event)
- self.data.remove(item)
+ def remove(self, obj, value, initiator):
+ if initiator is self:
+ return
+ collection = self._get_collection(obj)
+ collection.remove_with_event(value, initiator)
- def append(self, item, _mapper_nohistory=False):
- """Fire off dependent events, and appends the given item to the underlying list.
+ def set(self, obj, value, initiator):
+ """Set a value on the given object.
- `_mapper_nohistory` is a backwards compatibility hack; call
- ``append_without_event`` instead.
+ `initiator` is the ``InstrumentedAttribute`` that initiated the
+ ``set()`` operation and is used to control the depth of a circular
+ setter operation.
"""
- if _mapper_nohistory:
- self.append_without_event(item)
- else:
- self.__setrecord(item)
- self._data_appender(item)
-
- def __getitem__(self, i):
- return self.data[i]
-
- def __setitem__(self, i, item):
- if isinstance(i, slice):
- self.__setslice__(i.start, i.stop, item)
- else:
- self.__setrecord(item)
- self.data[i] = item
-
- def __delitem__(self, i):
- if isinstance(i, slice):
- self.__delslice__(i.start, i.stop)
- else:
- self.__delrecord(self.data[i], None)
- del self.data[i]
-
- def __lt__(self, other): return self.data < self.__cast(other)
-
- def __le__(self, other): return self.data <= self.__cast(other)
+ if initiator is self:
+ return
- def __eq__(self, other): return self.data == self.__cast(other)
+ state = obj._state
+ # if an instance-wide "trigger" was set, call that
+ if 'trigger' in state:
+ trig = state['trigger']
+ del state['trigger']
+ trig()
- def __ne__(self, other): return self.data != self.__cast(other)
+ old = self.get(obj)
+ old_collection = self._get_collection(obj, old)
- def __gt__(self, other): return self.data > self.__cast(other)
+ new_collection, user_data = self._build_collection(obj)
+ self._load_collection(obj, value or [], emit_events=True,
+ collection=new_collection)
- def __ge__(self, other): return self.data >= self.__cast(other)
+ obj.__dict__[self.key] = user_data
+ state['modified'] = True
- def __cast(self, other):
- if isinstance(other, InstrumentedList): return other.data
- else: return other
+ # mark all the old elements as detached from the parent
+ if old_collection:
+ old_collection.clear_with_event()
+ old_collection.unlink(old)
- def __cmp__(self, other):
- return cmp(self.data, self.__cast(other))
+ def set_committed_value(self, obj, value):
+ """Set an attribute value on the given instance and 'commit' it."""
+
+ state = obj._state
+ orig = state.get('original', None)
- def __contains__(self, item): return item in self.data
+ collection, user_data = self._build_collection(obj)
+ self._load_collection(obj, value or [], emit_events=False,
+ collection=collection)
+ value = user_data
- def __len__(self):
+ if orig is not None:
+ orig.commit_attribute(self, obj, value)
+ # remove per-instance callable, if any
+ state.pop(('callable', self), None)
+ obj.__dict__[self.key] = value
+ return value
+
+ def _build_collection(self, obj):
+ user_data = self.collection_factory()
+ collection = collections.CollectionAdapter(self, obj, user_data)
+ return collection, user_data
+
+ def _load_collection(self, obj, values, emit_events=True, collection=None):
+ collection = collection or self._get_collection(obj)
+ if values is None:
+ return
+ elif emit_events:
+ for item in values:
+ collection.append_with_event(item)
+ else:
+ for item in values:
+ collection.append_without_event(item)
+
+ def _get_collection(self, obj, user_data=None):
+ if user_data is None:
+ user_data = self.get(obj)
try:
- return len(self.data)
- except TypeError:
- return len(list(self.data))
-
- def __setslice__(self, i, j, other):
- [self.__delrecord(x) for x in self.data[i:j]]
- g = [a for a in list(other) if self.__setrecord(a)]
- self.data[i:j] = g
-
- def __delslice__(self, i, j):
- for a in self.data[i:j]:
- self.__delrecord(a)
- del self.data[i:j]
-
- def insert(self, i, item):
- if self.__setrecord(item):
- self.data.insert(i, item)
-
- def pop(self, i=-1):
- item = self.data[i]
- self.__delrecord(item)
- return self.data.pop(i)
-
- def remove(self, item):
- self.__delrecord(item)
- self.data.remove(item)
-
- def discard(self, item):
- if item in self.data:
- self.__delrecord(item)
- self.data.remove(item)
-
- def extend(self, item_list):
- for item in item_list:
- self.append(item)
-
- def __add__(self, other):
- raise NotImplementedError()
-
- def __radd__(self, other):
- raise NotImplementedError()
-
- def __iadd__(self, other):
- raise NotImplementedError()
-
-class AttributeExtension(object):
- """An abstract class which specifies `append`, `delete`, and `set`
- event handlers to be attached to an object property.
- """
-
- def append(self, event, obj, child):
- pass
-
- def delete(self, event, obj, child):
- pass
+ return getattr(user_data, '_sa_adapter')
+ except AttributeError:
+ collections.CollectionAdapter(self, obj, user_data)
+ return getattr(user_data, '_sa_adapter')
- def set(self, event, obj, child, oldchild):
- pass
-class GenericBackrefExtension(AttributeExtension):
+class GenericBackrefExtension(interfaces.AttributeExtension):
"""An extension which synchronizes a two-way relationship.
A typical two-way relationship is a parent object containing a
@@ -639,19 +516,19 @@ class GenericBackrefExtension(AttributeExtension):
def __init__(self, key):
self.key = key
- def set(self, event, obj, child, oldchild):
+ def set(self, obj, child, oldchild, initiator):
if oldchild is child:
return
if oldchild is not None:
- getattr(oldchild.__class__, self.key).remove(event, oldchild, obj)
+ getattr(oldchild.__class__, self.key).remove(oldchild, obj, initiator)
if child is not None:
- getattr(child.__class__, self.key).append(event, child, obj)
+ getattr(child.__class__, self.key).append(child, obj, initiator)
- def append(self, event, obj, child):
- getattr(child.__class__, self.key).append(event, child, obj)
+ def append(self, obj, child, initiator):
+ getattr(child.__class__, self.key).append(child, obj, initiator)
- def delete(self, event, obj, child):
- getattr(child.__class__, self.key).remove(event, child, obj)
+ def remove(self, obj, child, initiator):
+ getattr(child.__class__, self.key).remove(child, obj, initiator)
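# Illustrative sketch (not part of this changeset): the two-way
# synchronization the extension above provides for a relation() configured
# with a backref. Class names are assumed.
u = User()
a = Address()
a.user = u        # fires set(); the extension appends a to u.addresses
assert a in u.addresses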
class CommittedState(object):
"""Store the original state of an object when the ``commit()`
@@ -673,7 +550,7 @@ class CommittedState(object):
"""
if value is CommittedState.NO_VALUE:
- if obj.__dict__.has_key(attr.key):
+ if attr.key in obj.__dict__:
value = obj.__dict__[attr.key]
if value is not CommittedState.NO_VALUE:
self.data[attr.key] = attr.copy(value)
@@ -690,10 +567,13 @@ class CommittedState(object):
def rollback(self, manager, obj):
for attr in manager.managed_attributes(obj.__class__):
if self.data.has_key(attr.key):
- if attr.uselist:
- obj.__dict__[attr.key][:] = self.data[attr.key]
- else:
+ if not isinstance(attr, InstrumentedCollectionAttribute):
obj.__dict__[attr.key] = self.data[attr.key]
+ else:
+ collection = attr._get_collection(obj)
+ collection.clear_without_event()
+ for item in self.data[attr.key]:
+ collection.append_without_event(item)
else:
del obj.__dict__[attr.key]
@@ -718,17 +598,15 @@ class AttributeHistory(object):
else:
original = None
- if attr.uselist:
+ if isinstance(attr, InstrumentedCollectionAttribute):
self._current = current
- else:
- self._current = [current]
- if attr.uselist:
s = util.Set(original or [])
self._added_items = []
self._unchanged_items = []
self._deleted_items = []
if current:
- for a in current:
+ collection = attr._get_collection(obj, current)
+ for a in collection:
if a in s:
self._unchanged_items.append(a)
else:
@@ -737,6 +615,7 @@ class AttributeHistory(object):
if a not in self._unchanged_items:
self._deleted_items.append(a)
else:
+ self._current = [current]
if attr.is_equal(current, original):
self._unchanged_items = [current]
self._added_items = []
@@ -748,7 +627,6 @@ class AttributeHistory(object):
else:
self._deleted_items = []
self._unchanged_items = []
- #print "key", attr.key, "orig", original, "current", current, "added", self._added_items, "unchanged", self._unchanged_items, "deleted", self._deleted_items
def __iter__(self):
return iter(self._current)
@@ -766,24 +644,13 @@ class AttributeHistory(object):
return self._deleted_items
def hasparent(self, obj):
- """Deprecated. This should be called directly from the
- appropriate ``InstrumentedAttribute`` object.
+ """Deprecated. This should be called directly from the appropriate ``InstrumentedAttribute`` object.
"""
return self.attr.hasparent(obj)
class AttributeManager(object):
- """Allow the instrumentation of object attributes.
-
- ``AttributeManager`` is stateless, but can be overridden by
- subclasses to redefine some of its factory operations. Also be
- aware ``AttributeManager`` will cache attributes for a given
- class, allowing not to determine those for each objects (used in
- ``managed_attributes()`` and
- ``noninherited_managed_attributes()``). This cache is cleared for
- a given class while calling ``register_attribute()``, and can be
- cleared using ``clear_attribute_cache()``.
- """
+ """Allow the instrumentation of object attributes."""
def __init__(self):
# will cache attributes, indexed by class objects
@@ -827,7 +694,7 @@ class AttributeManager(object):
o._state['modified'] = False
def managed_attributes(self, class_):
- """Return an iterator of all ``InstrumentedAttribute`` objects
+ """Return a list of all ``InstrumentedAttribute`` objects
associated with the given class.
"""
@@ -878,7 +745,7 @@ class AttributeManager(object):
"""Return an attribute of the given name from the given object.
If the attribute is a scalar, return it as a single-item list,
- otherwise return the list based attribute.
+ otherwise return a collection-based attribute.
If the attribute's value is to be produced by an unexecuted
callable, the callable will only be executed if the given
@@ -887,10 +754,10 @@ class AttributeManager(object):
attr = getattr(obj.__class__, key)
x = attr.get(obj, passive=passive)
- if x is InstrumentedAttribute.PASSIVE_NORESULT:
+ if x is PASSIVE_NORESULT:
return []
- elif attr.uselist:
- return x
+ elif isinstance(attr, InstrumentedCollectionAttribute):
+ return list(attr._get_collection(obj, x))
else:
return [x]
@@ -921,7 +788,7 @@ class AttributeManager(object):
by ``trigger_history()``.
"""
- return obj._state.has_key('trigger')
+ return 'trigger' in obj._state
def reset_instance_attribute(self, obj, key):
"""Remove any per-instance callable functions corresponding to
@@ -946,10 +813,9 @@ class AttributeManager(object):
"""Return True if the given `key` correponds to an
instrumented property on the given class.
"""
-
return hasattr(class_, key) and isinstance(getattr(class_, key), InstrumentedAttribute)
- def init_instance_attribute(self, obj, key, uselist, callable_=None, **kwargs):
+ def init_instance_attribute(self, obj, key, callable_=None):
"""Initialize an attribute on an instance to either a blank
value, cancelling out any class- or instance-level callables
that were present, or if a `callable` is supplied set the
@@ -964,7 +830,24 @@ class AttributeManager(object):
events back to this ``AttributeManager``.
"""
- return InstrumentedAttribute(self, key, uselist, callable_, typecallable, **kwargs)
+ if uselist:
+ return InstrumentedCollectionAttribute(class_, self, key,
+ callable_,
+ typecallable,
+ **kwargs)
+ else:
+ return InstrumentedScalarAttribute(class_, self, key, callable_,
+ **kwargs)
+
+ def get_attribute(self, obj_or_cls, key):
+ """Register an attribute at the class level to be instrumented
+ for all instances of the class.
+ """
+
+ if isinstance(obj_or_cls, type):
+ return getattr(obj_or_cls, key)
+ else:
+ return getattr(obj_or_cls.__class__, key)
def register_attribute(self, class_, key, uselist, callable_=None, **kwargs):
"""Register an attribute at the class level to be instrumented
@@ -973,10 +856,9 @@ class AttributeManager(object):
        # first invalidate the cache for the given class
# (will be reconstituted as needed, while getting managed attributes)
- self._inherited_attribute_cache.pop(class_,None)
- self._noninherited_attribute_cache.pop(class_,None)
+ self._inherited_attribute_cache.pop(class_, None)
+ self._noninherited_attribute_cache.pop(class_, None)
- #print self, "register attribute", key, "for class", class_
if not hasattr(class_, '_state'):
def _get_state(self):
if not hasattr(self, '_sa_attr_state'):
@@ -987,4 +869,12 @@ class AttributeManager(object):
typecallable = kwargs.pop('typecallable', None)
if isinstance(typecallable, InstrumentedAttribute):
typecallable = None
- setattr(class_, key, self.create_prop(class_, key, uselist, callable_, typecallable=typecallable, **kwargs))
+ setattr(class_, key, self.create_prop(class_, key, uselist, callable_,
+ typecallable=typecallable, **kwargs))
+
+ def init_collection(self, instance, key):
+ """Initialize a collection attribute and return the collection adapter."""
+
+ attr = self.get_attribute(instance, key)
+ user_data = attr.initialize(instance)
+ return attr._get_collection(instance, user_data)
diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py
new file mode 100644
index 000000000..7ade882f5
--- /dev/null
+++ b/lib/sqlalchemy/orm/collections.py
@@ -0,0 +1,1182 @@
+"""Support for collections of mapped entities.
+
+The collections package supplies the machinery used to inform the ORM of
+collection membership changes. An instrumentation via decoration approach is
+used, allowing arbitrary types (including built-ins) to be used as entity
+collections without requiring inheritance from a base class.
+
+Instrumentation decoration relays membership change events to the
+``InstrumentedCollectionAttribute`` that is currently managing the collection.
+The decorators observe function call arguments and return values, tracking
+entities entering or leaving the collection. Two decorator approaches are
+provided. One is a bundle of generic decorators that map function arguments
+and return values to events::
+
+ from sqlalchemy.orm.collections import collection
+ class MyClass(object):
+ # ...
+
+ @collection.adds(1)
+ def store(self, item):
+ self.data.append(item)
+
+ @collection.removes_return()
+ def pop(self):
+ return self.data.pop()
+
+
+The second approach is a bundle of targeted decorators that wrap appropriate
+append and remove notifiers around the mutation methods present in the
+standard Python ``list``, ``set`` and ``dict`` interfaces. These could be
+specified in terms of generic decorator recipes, but are instead hand-tooled for
+increased efficiency. The targeted decorators occasionally implement
+adapter-like behavior, such as mapping bulk-set methods (``extend``, ``update``,
+``__setslice__``, etc.) into the series of atomic mutation events that the ORM
+requires.
+
+The targeted decorators are used internally for automatic instrumentation of
+entity collection classes. Every collection class goes through a
+transformation process roughly like so:
+
+1. If the class is a built-in, substitute a trivial sub-class
+2. If the class is already instrumented, skip the remaining steps
+3. Add in generic decorators
+4. Sniff out the collection interface through duck-typing
+5. Add targeted decoration to any undecorated interface method
+
+This process modifies the class at runtime, decorating methods and adding some
+bookkeeping properties. This isn't possible (or desirable) for built-in
+classes like ``list``, so trivial sub-classes are substituted to hold
+decoration::
+
+ class InstrumentedList(list):
+ pass
+
+Collection classes can be specified in ``relation(collection_class=)`` as
+types or a function that returns an instance. Collection classes are
+inspected and instrumented during the mapper compilation phase. The
+collection_class callable will be executed once to produce a specimen
+instance, and the type of that specimen will be instrumented. Functions that
+return built-in types like ``list`` will be adapted to produce instrumented
+instances.
+
+When extending a known type like ``list``, additional decorations are
+generally not needed. Odds are, the extension method will delegate to a
+method that's already instrumented. For example::
+
+ class QueueIsh(list):
+ def push(self, item):
+ self.append(item)
+ def shift(self):
+ return self.pop(0)
+
+There's no need to decorate these methods. ``append`` and ``pop`` are already
+instrumented as part of the ``list`` interface. Decorating them would fire
+duplicate events, which should be avoided.
+
+The targeted decoration tries not to rely on other methods in the underlying
+collection class, but some are unavoidable. Many depend on 'read' methods
+being present to properly instrument a 'write', for example, ``__setitem__``
+needs ``__getitem__``. "Bulk" methods like ``update`` and ``extend`` may also
+be reimplemented in terms of atomic appends and removes, so the ``extend``
+decoration will actually perform many ``append`` operations and not call the
+underlying method at all.
+
+Tight control over bulk operation and the firing of events is also possible by
+implementing the instrumentation internally in your methods. The basic
+instrumentation package works under the general assumption that collection
+mutation will not raise unusual exceptions. If you want to closely
+orchestrate append and remove events with exception management, internal
+instrumentation may be the answer. Within your method,
+``collection_adapter(self)`` will retrieve an object that you can use for
+explicit control over triggering append and remove events.
+
+The owning object and InstrumentedCollectionAttribute are also reachable
+through the adapter, allowing for some very sophisticated behavior.
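+
+For example, a minimal internally instrumented appender might look like the
+following (an illustrative sketch, not part of this module)::
+
+    from sqlalchemy.orm.collections import collection, collection_adapter
+
+    class VetoingList(list):
+        @collection.internally_instrumented
+        def append(self, item, _sa_initiator=None):
+            if item is None:
+                return                  # veto: store nothing, fire no event
+            list.append(self, item)
+            adapter = collection_adapter(self)
+            if adapter is not None:
+                adapter.fire_append_event(item, _sa_initiator)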
+"""
+
+import copy, inspect, sys, weakref
+
+from sqlalchemy import exceptions, schema, util as sautil
+from sqlalchemy.orm import mapper
+
+try:
+ from threading import Lock
+except ImportError:
+ from dummy_threading import Lock
+try:
+ from operator import attrgetter
+except ImportError:
+ def attrgetter(attribute):
+ return lambda value: getattr(value, attribute)
+
+
+__all__ = ['collection', 'collection_adapter',
+ 'mapped_collection', 'column_mapped_collection',
+ 'attribute_mapped_collection']
+
+def column_mapped_collection(mapping_spec):
+ """A dictionary-based collection type with column-based keying.
+
+ Returns a MappedCollection factory with a keying function generated
+ from mapping_spec, which may be a Column or a sequence of Columns.
+
+ The key value must be immutable for the lifetime of the object. You
+ can not, for example, map on foreign key values if those key values will
+ change during the session, i.e. from None to a database-assigned integer
+ after a session flush.
+ """
+
+ if isinstance(mapping_spec, schema.Column):
+ def keyfunc(value):
+ m = mapper.object_mapper(value)
+ return m.get_attr_by_column(value, mapping_spec)
+ else:
+ cols = []
+ for c in mapping_spec:
+ if not isinstance(c, schema.Column):
+ raise exceptions.ArgumentError(
+ "mapping_spec tuple may only contain columns")
+ cols.append(c)
+ mapping_spec = tuple(cols)
+ def keyfunc(value):
+ m = mapper.object_mapper(value)
+ return tuple([m.get_attr_by_column(value, c) for c in mapping_spec])
+ return lambda: MappedCollection(keyfunc)
+
+def attribute_mapped_collection(attr_name):
+ """A dictionary-based collection type with attribute-based keying.
+
+    Returns a MappedCollection factory with keying based on the
+ 'attr_name' attribute of entities in the collection.
+
+ The key value must be immutable for the lifetime of the object. You
+ can not, for example, map on foreign key values if those key values will
+ change during the session, i.e. from None to a database-assigned integer
+ after a session flush.
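+
+    For example (an illustrative sketch; ``Parent``, ``Child`` and the
+    mapper setup are assumed, not part of this module)::
+
+        mapper(Parent, parent_table, properties={
+            'children': relation(Child,
+                collection_class=attribute_mapped_collection('name'))
+        })
+
+        # parent.children is dict-like, keyed on each Child's .name
+        child = parent.children['sprocket']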
+ """
+
+ return lambda: MappedCollection(attrgetter(attr_name))
+
+
+def mapped_collection(keyfunc):
+ """A dictionary-based collection type with arbitrary keying.
+
+ Returns a MappedCollection factory with a keying function generated
+ from keyfunc, a callable that takes an entity and returns a key value.
+
+ The key value must be immutable for the lifetime of the object. You
+ can not, for example, map on foreign key values if those key values will
+ change during the session, i.e. from None to a database-assigned integer
+ after a session flush.
+ """
+
+ return lambda: MappedCollection(keyfunc)
+
+class collection(object):
+ """Decorators for entity collection classes.
+
+ The decorators fall into two groups: annotations and interception recipes.
+
+ The annotating decorators (appender, remover, iterator,
+ internally_instrumented, on_link) indicate the method's purpose and take no
+ arguments. They are not written with parens::
+
+ @collection.appender
+        def append(self, entity): ...
+
+ The recipe decorators all require parens, even those that take no
+ arguments::
+
+        @collection.adds('entity')
+ def insert(self, position, entity): ...
+
+ @collection.removes_return()
+ def popitem(self): ...
+
+ Decorators can be specified in long-hand for Python 2.3, or with
+    the class-level dict attribute '__instrumentation__'; see the source
+ for details.
+ """
+
+ # Bundled as a class solely for ease of use: packaging, doc strings,
+ # importability.
+
+ def appender(cls, fn):
+ """Tag the method as the collection appender.
+
+ The appender method is called with one positional argument: the value
+ to append. The method will be automatically decorated with 'adds(1)'
+ if not already decorated::
+
+ @collection.appender
+          def add(self, entity): ...
+
+ # or, equivalently
+ @collection.appender
+ @collection.adds(1)
+          def add(self, entity): ...
+
+ # for mapping type, an 'append' may kick out a previous value
+          # that occupies that slot.  Consider d['a'] = 'foo'; any previous
+ # value in d['a'] is discarded.
+ @collection.appender
+ @collection.replaces(1)
+ def add(self, entity):
+ key = some_key_func(entity)
+ previous = None
+ if key in self:
+ previous = self[key]
+ self[key] = entity
+ return previous
+
+ If the value to append is not allowed in the collection, you may
+ raise an exception. Something to remember is that the appender
+ will be called for each object mapped by a database query. If the
+ database contains rows that violate your collection semantics, you
+ will need to get creative to fix the problem, as access via the
+ collection will not work.
+
+ If the appender method is internally instrumented, you must also
+ receive the keyword argument '_sa_initiator' and ensure its
+ promulgation to collection events.
+ """
+
+ setattr(fn, '_sa_instrument_role', 'appender')
+ return fn
+ appender = classmethod(appender)
+
+ def remover(cls, fn):
+ """Tag the method as the collection remover.
+
+ The remover method is called with one positional argument: the value
+ to remove. The method will be automatically decorated with
+ 'removes_return()' if not already decorated::
+
+ @collection.remover
+ def zap(self, entity): ...
+
+ # or, equivalently
+ @collection.remover
+ @collection.removes_return()
+          def zap(self, entity): ...
+
+ If the value to remove is not present in the collection, you may
+ raise an exception or return None to ignore the error.
+
+ If the remove method is internally instrumented, you must also
+ receive the keyword argument '_sa_initiator' and ensure its
+ promulgation to collection events.
+ """
+
+ setattr(fn, '_sa_instrument_role', 'remover')
+ return fn
+ remover = classmethod(remover)
+
+ def iterator(cls, fn):
+ """Tag the method as the collection remover.
+
+ The iterator method is called with no arguments. It is expected to
+ return an iterator over all collection members::
+
+ @collection.iterator
+ def __iter__(self): ...
+ """
+
+ setattr(fn, '_sa_instrument_role', 'iterator')
+ return fn
+ iterator = classmethod(iterator)
+
+ def internally_instrumented(cls, fn):
+ """Tag the method as instrumented.
+
+ This tag will prevent any decoration from being applied to the method.
+ Use this if you are orchestrating your own calls to collection_adapter
+ in one of the basic SQLAlchemy interface methods, or to prevent
+ an automatic ABC method decoration from wrapping your implementation::
+
+ # normally an 'extend' method on a list-like class would be
+ # automatically intercepted and re-implemented in terms of
+ # SQLAlchemy events and append(). your implementation will
+ # never be called, unless:
+ @collection.internally_instrumented
+ def extend(self, items): ...
+ """
+
+ setattr(fn, '_sa_instrumented', True)
+ return fn
+ internally_instrumented = classmethod(internally_instrumented)
+
+ def on_link(cls, fn):
+ """Tag the method as a the "linked to attribute" event handler.
+
+ This optional event handler will be called when the collection class
+ is linked to or unlinked from the InstrumentedAttribute. It is
+ invoked immediately after the '_sa_adapter' property is set on
+ the instance. A single argument is passed: the collection adapter
+ that has been linked, or None if unlinking.
+ """
+
+ setattr(fn, '_sa_instrument_role', 'on_link')
+ return fn
+ on_link = classmethod(on_link)
+
+ def adds(cls, arg):
+ """Mark the method as adding an entity to the collection.
+
+ Adds "add to collection" handling to the method. The decorator argument
+ indicates which method argument holds the SQLAlchemy-relevant value.
+ Arguments can be specified positionally (i.e. integer) or by name::
+
+ @collection.adds(1)
+ def push(self, item): ...
+
+ @collection.adds('entity')
+ def do_stuff(self, thing, entity=None): ...
+ """
+
+ def decorator(fn):
+ setattr(fn, '_sa_instrument_before', ('fire_append_event', arg))
+ return fn
+ return decorator
+ adds = classmethod(adds)
+
+ def replaces(cls, arg):
+ """Mark the method as replacing an entity in the collection.
+
+ Adds "add to collection" and "remove from collection" handling to
+ the method. The decorator argument indicates which method argument
+        holds the SQLAlchemy-relevant value to be added; the return value, if
+        any, will be considered the value to remove.
+
+ Arguments can be specified positionally (i.e. integer) or by name::
+
+ @collection.replaces(2)
+ def __setitem__(self, index, item): ...
+ """
+
+ def decorator(fn):
+ setattr(fn, '_sa_instrument_before', ('fire_append_event', arg))
+ setattr(fn, '_sa_instrument_after', 'fire_remove_event')
+ return fn
+ return decorator
+ replaces = classmethod(replaces)
+
+ def removes(cls, arg):
+ """Mark the method as removing an entity in the collection.
+
+ Adds "remove from collection" handling to the method. The decorator
+ argument indicates which method argument holds the SQLAlchemy-relevant
+ value to be removed. Arguments can be specified positionally (i.e.
+ integer) or by name::
+
+ @collection.removes(1)
+ def zap(self, item): ...
+
+ For methods where the value to remove is not known at call-time, use
+ collection.removes_return.
+ """
+
+ def decorator(fn):
+ setattr(fn, '_sa_instrument_before', ('fire_remove_event', arg))
+ return fn
+ return decorator
+ removes = classmethod(removes)
+
+ def removes_return(cls):
+ """Mark the method as removing an entity in the collection.
+
+ Adds "remove from collection" handling to the method. The return value
+ of the method, if any, is considered the value to remove. The method
+ arguments are not inspected::
+
+ @collection.removes_return()
+ def pop(self): ...
+
+ For methods where the value to remove is known at call-time, use
+        collection.removes.
+ """
+
+ def decorator(fn):
+ setattr(fn, '_sa_instrument_after', 'fire_remove_event')
+ return fn
+ return decorator
+ removes_return = classmethod(removes_return)
+
+
+# public instrumentation interface for 'internally instrumented'
+# implementations
+def collection_adapter(collection):
+ """Fetch the CollectionAdapter for a collection."""
+
+ return getattr(collection, '_sa_adapter', None)
+
+class CollectionAdapter(object):
+ """Bridges between the ORM and arbitrary Python collections.
+
+ Proxies base-level collection operations (append, remove, iterate)
+ to the underlying Python collection, and emits add/remove events for
+ entities entering or leaving the collection.
+
+    The ORM uses a CollectionAdapter exclusively for interaction with
+ entity collections.
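+
+    For example (illustrative), the adapter of a loaded collection can be
+    driven directly::
+
+        adapter = collection_adapter(parent.children)
+        adapter.append_with_event(child)    # append plus ORM mutation event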
+ """
+
+ def __init__(self, attr, owner, data):
+ self.attr = attr
+ self._owner = weakref.ref(owner)
+ self._data = weakref.ref(data)
+ self.link_to_self(data)
+
+ owner = property(lambda s: s._owner(),
+ doc="The object that owns the entity collection.")
+ data = property(lambda s: s._data(),
+ doc="The entity collection being adapted.")
+
+ def link_to_self(self, data):
+ """Link a collection to this adapter, and fire a link event."""
+
+ setattr(data, '_sa_adapter', self)
+ if hasattr(data, '_sa_on_link'):
+ getattr(data, '_sa_on_link')(self)
+
+ def unlink(self, data):
+ """Unlink a collection from any adapter, and fire a link event."""
+
+ setattr(data, '_sa_adapter', None)
+ if hasattr(data, '_sa_on_link'):
+ getattr(data, '_sa_on_link')(None)
+
+ def append_with_event(self, item, initiator=None):
+ """Add an entity to the collection, firing mutation events."""
+
+ getattr(self._data(), '_sa_appender')(item, _sa_initiator=initiator)
+
+ def append_without_event(self, item):
+ """Add or restore an entity to the collection, firing no events."""
+
+ getattr(self._data(), '_sa_appender')(item, _sa_initiator=False)
+
+ def remove_with_event(self, item, initiator=None):
+ """Remove an entity from the collection, firing mutation events."""
+
+ getattr(self._data(), '_sa_remover')(item, _sa_initiator=initiator)
+
+ def remove_without_event(self, item):
+ """Remove an entity from the collection, firing no events."""
+
+ getattr(self._data(), '_sa_remover')(item, _sa_initiator=False)
+
+ def clear_with_event(self, initiator=None):
+ """Empty the collection, firing a mutation event for each entity."""
+
+ for item in list(self):
+ self.remove_with_event(item, initiator)
+
+ def clear_without_event(self):
+ """Empty the collection, firing no events."""
+
+ for item in list(self):
+ self.remove_without_event(item)
+
+ def __iter__(self):
+ """Iterate over entities in the collection."""
+
+ return getattr(self._data(), '_sa_iterator')()
+
+ def __len__(self):
+ """Count entities in the collection."""
+
+ return len(list(getattr(self._data(), '_sa_iterator')()))
+
+ def __nonzero__(self):
+ return True
+
+ def fire_append_event(self, item, initiator=None):
+ """Notify that a entity has entered the collection.
+
+ Initiator is the InstrumentedAttribute that initiated the membership
+ mutation, and should be left as None unless you are passing along
+ an initiator value from a chained operation.
+ """
+
+ if initiator is not False and item is not None:
+ self.attr.fire_append_event(self._owner(), item, initiator)
+
+ def fire_remove_event(self, item, initiator=None):
+ """Notify that a entity has entered the collection.
+
+ Initiator is the InstrumentedAttribute that initiated the membership
+ mutation, and should be left as None unless you are passing along
+ an initiator value from a chained operation.
+ """
+
+ if initiator is not False and item is not None:
+ self.attr.fire_remove_event(self._owner(), item, initiator)
+
+ def __getstate__(self):
+ return { 'key': self.attr.key,
+ 'owner': self.owner,
+ 'data': self.data }
+
+ def __setstate__(self, d):
+ self.attr = getattr(d['owner'].__class__, d['key'])
+ self._owner = weakref.ref(d['owner'])
+ self._data = weakref.ref(d['data'])
+
+
+__instrumentation_mutex = Lock()
+def _prepare_instrumentation(factory):
+ """Prepare a callable for future use as a collection class factory.
+
+ Given a collection class factory (either a type or no-arg callable),
+ return another factory that will produce compatible instances when
+ called.
+
+ This function is responsible for converting collection_class=list
+ into the run-time behavior of collection_class=InstrumentedList.
+ """
+
+ # Convert a builtin to 'Instrumented*'
+ if factory in __canned_instrumentation:
+ factory = __canned_instrumentation[factory]
+
+ # Create a specimen
+ cls = type(factory())
+
+ # Did factory callable return a builtin?
+ if cls in __canned_instrumentation:
+ # Wrap it so that it returns our 'Instrumented*'
+ factory = __converting_factory(factory)
+ cls = factory()
+
+ # Instrument the class if needed.
+ if __instrumentation_mutex.acquire():
+ try:
+ if getattr(cls, '_sa_instrumented', None) != id(cls):
+ _instrument_class(cls)
+ finally:
+ __instrumentation_mutex.release()
+
+ return factory
+
+def __converting_factory(original_factory):
+ """Convert the type returned by collection factories on the fly.
+
+ Given a collection factory that returns a builtin type (e.g. a list),
+ return a wrapped function that converts that type to one of our
+ instrumented types.
+ """
+
+ def wrapper():
+ collection = original_factory()
+ type_ = type(collection)
+ if type_ in __canned_instrumentation:
+ # return an instrumented type initialized from the factory's
+ # collection
+ return __canned_instrumentation[type_](collection)
+ else:
+ raise exceptions.InvalidRequestError(
+ "Collection class factories must produce instances of a "
+ "single class.")
+ try:
+ # often flawed but better than nothing
+ wrapper.__name__ = "%sWrapper" % original_factory.__name__
+ wrapper.__doc__ = original_factory.__doc__
+ except:
+ pass
+ return wrapper
+
+def _instrument_class(cls):
+ """Modify methods in a class and install instrumentation."""
+
+ # FIXME: more formally document this as a decoratorless/Python 2.3
+ # option for specifying instrumentation. (likely doc'd here in code only,
+ # not in online docs.)
+ #
+ # __instrumentation__ = {
+ # 'rolename': 'methodname', # ...
+ # 'methods': {
+ # 'methodname': ('fire_{append,remove}_event', argspec,
+ # 'fire_{append,remove}_event'),
+ # 'append': ('fire_append_event', 1, None),
+ # '__setitem__': ('fire_append_event', 1, 'fire_remove_event'),
+ # 'pop': (None, None, 'fire_remove_event'),
+ # }
+ # }
+
+ # In the normal call flow, a request for any of the 3 basic collection
+ # types is transformed into one of our trivial subclasses
+ # (e.g. InstrumentedList). Catch anything else that sneaks in here...
+ if cls.__module__ == '__builtin__':
+ raise exceptions.ArgumentError(
+ "Can not instrument a built-in type. Use a "
+ "subclass, even a trivial one.")
+
+ collection_type = sautil.duck_type_collection(cls)
+ if collection_type in __interfaces:
+ roles = __interfaces[collection_type].copy()
+ decorators = roles.pop('_decorators', {})
+ else:
+ roles, decorators = {}, {}
+
+ if hasattr(cls, '__instrumentation__'):
+ roles.update(copy.deepcopy(getattr(cls, '__instrumentation__')))
+
+ methods = roles.pop('methods', {})
+
+ for name in dir(cls):
+ method = getattr(cls, name)
+ if not callable(method):
+ continue
+
+ # note role declarations
+ if hasattr(method, '_sa_instrument_role'):
+ role = method._sa_instrument_role
+ assert role in ('appender', 'remover', 'iterator', 'on_link')
+ roles[role] = name
+
+ # transfer instrumentation requests from decorated function
+ # to the combined queue
+ before, after = None, None
+ if hasattr(method, '_sa_instrument_before'):
+ op, argument = method._sa_instrument_before
+ assert op in ('fire_append_event', 'fire_remove_event')
+ before = op, argument
+ if hasattr(method, '_sa_instrument_after'):
+ op = method._sa_instrument_after
+ assert op in ('fire_append_event', 'fire_remove_event')
+ after = op
+ if before:
+ methods[name] = before[0], before[1], after
+ elif after:
+ methods[name] = None, None, after
+
+ # apply ABC auto-decoration to methods that need it
+ for method, decorator in decorators.items():
+ fn = getattr(cls, method, None)
+ if fn and method not in methods and not hasattr(fn, '_sa_instrumented'):
+ setattr(cls, method, decorator(fn))
+
+ # ensure all roles are present, and apply implicit instrumentation if
+ # needed
+ if 'appender' not in roles or not hasattr(cls, roles['appender']):
+ raise exceptions.ArgumentError(
+ "Type %s must elect an appender method to be "
+ "a collection class" % cls.__name__)
+ elif (roles['appender'] not in methods and
+ not hasattr(getattr(cls, roles['appender']), '_sa_instrumented')):
+ methods[roles['appender']] = ('fire_append_event', 1, None)
+
+ if 'remover' not in roles or not hasattr(cls, roles['remover']):
+ raise exceptions.ArgumentError(
+ "Type %s must elect a remover method to be "
+ "a collection class" % cls.__name__)
+ elif (roles['remover'] not in methods and
+ not hasattr(getattr(cls, roles['remover']), '_sa_instrumented')):
+ methods[roles['remover']] = ('fire_remove_event', 1, None)
+
+ if 'iterator' not in roles or not hasattr(cls, roles['iterator']):
+ raise exceptions.ArgumentError(
+ "Type %s must elect an iterator method to be "
+ "a collection class" % cls.__name__)
+
+ # apply ad-hoc instrumentation from decorators, class-level defaults
+ # and implicit role declarations
+ for method, (before, argument, after) in methods.items():
+ setattr(cls, method,
+ _instrument_membership_mutator(getattr(cls, method),
+ before, argument, after))
+ # intern the role map
+ for role, method in roles.items():
+ setattr(cls, '_sa_%s' % role, getattr(cls, method))
+
+ setattr(cls, '_sa_instrumented', id(cls))
+
+def _instrument_membership_mutator(method, before, argument, after):
+ """Route method args and/or return value through the collection adapter."""
+
+ if type(argument) is int:
+ def wrapper(*args, **kw):
+ if before and len(args) < argument:
+ raise exceptions.ArgumentError(
+ 'Missing argument %i' % argument)
+ initiator = kw.pop('_sa_initiator', None)
+ if initiator is False:
+ executor = None
+ else:
+ executor = getattr(args[0], '_sa_adapter', None)
+
+ if before and executor:
+ getattr(executor, before)(args[argument], initiator)
+
+ if not after or not executor:
+ return method(*args, **kw)
+ else:
+ res = method(*args, **kw)
+ if res is not None:
+ getattr(executor, after)(res, initiator)
+ return res
+ else:
+ def wrapper(*args, **kw):
+ if before:
+                if argument in kw:
+                    value = kw[argument]
+                else:
+                    positional = inspect.getargspec(method)[0]
+                    # list.index raises ValueError on a missing name,
+                    # rather than returning -1
+                    try:
+                        pos = positional.index(argument)
+                    except ValueError:
+                        raise exceptions.ArgumentError(
+                            'Missing argument %s' % argument)
+                    value = args[pos]
+
+ initiator = kw.pop('_sa_initiator', None)
+ if initiator is False:
+ executor = None
+ else:
+ executor = getattr(args[0], '_sa_adapter', None)
+
+ if before and executor:
+ getattr(executor, before)(value, initiator)
+
+ if not after or not executor:
+ return method(*args, **kw)
+ else:
+ res = method(*args, **kw)
+ if res is not None:
+ getattr(executor, after)(res, initiator)
+ return res
+ try:
+ wrapper._sa_instrumented = True
+ wrapper.__name__ = method.__name__
+ wrapper.__doc__ = method.__doc__
+ except:
+ pass
+ return wrapper
+
+def __set(collection, item, _sa_initiator=None):
+ """Run set events, may eventually be inlined into decorators."""
+
+ if _sa_initiator is not False and item is not None:
+ executor = getattr(collection, '_sa_adapter', None)
+ if executor:
+ getattr(executor, 'fire_append_event')(item, _sa_initiator)
+
+def __del(collection, item, _sa_initiator=None):
+ """Run del events, may eventually be inlined into decorators."""
+
+ if _sa_initiator is not False and item is not None:
+ executor = getattr(collection, '_sa_adapter', None)
+ if executor:
+ getattr(executor, 'fire_remove_event')(item, _sa_initiator)
+
+def _list_decorators():
+ """Hand-turned instrumentation wrappers that can decorate any list-like
+ class."""
+
+ def _tidy(fn):
+ setattr(fn, '_sa_instrumented', True)
+ fn.__doc__ = getattr(getattr(list, fn.__name__), '__doc__')
+
+ def append(fn):
+ def append(self, item, _sa_initiator=None):
+ # FIXME: example of fully inlining __set and adapter.fire
+ # for critical path
+ if _sa_initiator is not False and item is not None:
+ executor = getattr(self, '_sa_adapter', None)
+ if executor:
+ executor.attr.fire_append_event(executor._owner(),
+ item, _sa_initiator)
+ fn(self, item)
+ _tidy(append)
+ return append
+
+ def remove(fn):
+ def remove(self, value, _sa_initiator=None):
+ fn(self, value)
+ __del(self, value, _sa_initiator)
+ _tidy(remove)
+ return remove
+
+ def insert(fn):
+ def insert(self, index, value):
+ __set(self, value)
+ fn(self, index, value)
+ _tidy(insert)
+ return insert
+
+ def __setitem__(fn):
+ def __setitem__(self, index, value):
+ if not isinstance(index, slice):
+ existing = self[index]
+ if existing is not None:
+ __del(self, existing)
+ __set(self, value)
+ fn(self, index, value)
+ else:
+ # slice assignment requires __delitem__, insert, __len__
+ if index.stop is None:
+ stop = 0
+ elif index.stop < 0:
+ stop = len(self) + index.stop
+ else:
+ stop = index.stop
+ step = index.step or 1
+ rng = range(index.start or 0, stop, step)
+ if step == 1:
+ for i in rng:
+ del self[index.start]
+ i = index.start
+ for item in value:
+ self.insert(i, item)
+ i += 1
+ else:
+ if len(value) != len(rng):
+ raise ValueError(
+ "attempt to assign sequence of size %s to "
+ "extended slice of size %s" % (len(value),
+ len(rng)))
+ for i, item in zip(rng, value):
+ self.__setitem__(i, item)
+ _tidy(__setitem__)
+ return __setitem__
+
+ def __delitem__(fn):
+ def __delitem__(self, index):
+ if not isinstance(index, slice):
+ item = self[index]
+ __del(self, item)
+ fn(self, index)
+ else:
+ # slice deletion requires __getslice__ and a slice-groking
+ # __getitem__ for stepped deletion
+ # note: not breaking this into atomic dels
+ for item in self[index]:
+ __del(self, item)
+ fn(self, index)
+ _tidy(__delitem__)
+ return __delitem__
+
+ def __setslice__(fn):
+ def __setslice__(self, start, end, values):
+ for value in self[start:end]:
+ __del(self, value)
+ for value in values:
+ __set(self, value)
+ fn(self, start, end, values)
+ _tidy(__setslice__)
+ return __setslice__
+
+ def __delslice__(fn):
+ def __delslice__(self, start, end):
+ for value in self[start:end]:
+ __del(self, value)
+ fn(self, start, end)
+ _tidy(__delslice__)
+ return __delslice__
+
+ def extend(fn):
+ def extend(self, iterable):
+ for value in iterable:
+ self.append(value)
+ _tidy(extend)
+ return extend
+
+ def pop(fn):
+ def pop(self, index=-1):
+ item = fn(self, index)
+ __del(self, item)
+ return item
+ _tidy(pop)
+ return pop
+
+ l = locals().copy()
+ l.pop('_tidy')
+ return l
+
+def _dict_decorators():
+ """Hand-turned instrumentation wrappers that can decorate any dict-like
+ mapping class."""
+
+ def _tidy(fn):
+ setattr(fn, '_sa_instrumented', True)
+ fn.__doc__ = getattr(getattr(dict, fn.__name__), '__doc__')
+
+ Unspecified=object()
+
+ def __setitem__(fn):
+ def __setitem__(self, key, value, _sa_initiator=None):
+ if key in self:
+ __del(self, self[key], _sa_initiator)
+ __set(self, value, _sa_initiator)
+ fn(self, key, value)
+ _tidy(__setitem__)
+ return __setitem__
+
+ def __delitem__(fn):
+ def __delitem__(self, key, _sa_initiator=None):
+ if key in self:
+ __del(self, self[key], _sa_initiator)
+ fn(self, key)
+ _tidy(__delitem__)
+ return __delitem__
+
+ def clear(fn):
+ def clear(self):
+ for key in self:
+ __del(self, self[key])
+ fn(self)
+ _tidy(clear)
+ return clear
+
+ def pop(fn):
+ def pop(self, key, default=Unspecified):
+ if key in self:
+ __del(self, self[key])
+ if default is Unspecified:
+ return fn(self, key)
+ else:
+ return fn(self, key, default)
+ _tidy(pop)
+ return pop
+
+ def popitem(fn):
+ def popitem(self):
+ item = fn(self)
+ __del(self, item[1])
+ return item
+ _tidy(popitem)
+ return popitem
+
+ def setdefault(fn):
+ def setdefault(self, key, default=None):
+ if key not in self:
+ self.__setitem__(key, default)
+ return default
+ else:
+ return self.__getitem__(key)
+ _tidy(setdefault)
+ return setdefault
+
+ if sys.version_info < (2, 4):
+ def update(fn):
+ def update(self, other):
+ for key in other.keys():
+ if not self.has_key(key) or self[key] is not other[key]:
+ self[key] = other[key]
+ _tidy(update)
+ return update
+ else:
+ def update(fn):
+ def update(self, __other=Unspecified, **kw):
+ if __other is not Unspecified:
+ if hasattr(__other, 'keys'):
+ for key in __other.keys():
+ if key not in self or self[key] is not __other[key]:
+ self[key] = __other[key]
+ else:
+ for key, value in __other:
+ if key not in self or self[key] is not value:
+ self[key] = value
+ for key in kw:
+ if key not in self or self[key] is not kw[key]:
+ self[key] = kw[key]
+ _tidy(update)
+ return update
+
+ l = locals().copy()
+ l.pop('_tidy')
+ l.pop('Unspecified')
+ return l
+
+def _set_decorators():
+ """Hand-turned instrumentation wrappers that can decorate any set-like
+ sequence class."""
+
+ def _tidy(fn):
+ setattr(fn, '_sa_instrumented', True)
+ fn.__doc__ = getattr(getattr(set, fn.__name__), '__doc__')
+
+ Unspecified=object()
+
+ def add(fn):
+ def add(self, value, _sa_initiator=None):
+ __set(self, value, _sa_initiator)
+ fn(self, value)
+ _tidy(add)
+ return add
+
+ def discard(fn):
+ def discard(self, value, _sa_initiator=None):
+ if value in self:
+ __del(self, value, _sa_initiator)
+ fn(self, value)
+ _tidy(discard)
+ return discard
+
+ def remove(fn):
+ def remove(self, value, _sa_initiator=None):
+ if value in self:
+ __del(self, value, _sa_initiator)
+ fn(self, value)
+ _tidy(remove)
+ return remove
+
+ def pop(fn):
+ def pop(self):
+ item = fn(self)
+ __del(self, item)
+ return item
+ _tidy(pop)
+ return pop
+
+ def clear(fn):
+ def clear(self):
+ for item in list(self):
+ self.remove(item)
+ _tidy(clear)
+ return clear
+
+ def update(fn):
+ def update(self, value):
+ for item in value:
+ if item not in self:
+ self.add(item)
+ _tidy(update)
+ return update
+ __ior__ = update
+
+ def difference_update(fn):
+ def difference_update(self, value):
+ for item in value:
+ self.discard(item)
+ _tidy(difference_update)
+ return difference_update
+ __isub__ = difference_update
+
+ def intersection_update(fn):
+ def intersection_update(self, other):
+ want, have = self.intersection(other), sautil.Set(self)
+ remove, add = have - want, want - have
+
+ for item in remove:
+ self.remove(item)
+ for item in add:
+ self.add(item)
+ _tidy(intersection_update)
+ return intersection_update
+ __iand__ = intersection_update
+
+ def symmetric_difference_update(fn):
+ def symmetric_difference_update(self, other):
+ want, have = self.symmetric_difference(other), sautil.Set(self)
+ remove, add = have - want, want - have
+
+ for item in remove:
+ self.remove(item)
+ for item in add:
+ self.add(item)
+ _tidy(symmetric_difference_update)
+ return symmetric_difference_update
+ __ixor__ = symmetric_difference_update
+
+ l = locals().copy()
+ l.pop('_tidy')
+ l.pop('Unspecified')
+ return l
+
+
+class InstrumentedList(list):
+ """An instrumented version of the built-in list."""
+
+ __instrumentation__ = {
+ 'appender': 'append',
+ 'remover': 'remove',
+ 'iterator': '__iter__', }
+
+class InstrumentedSet(sautil.Set):
+ """An instrumented version of the built-in set (or Set)."""
+
+ __instrumentation__ = {
+ 'appender': 'add',
+ 'remover': 'remove',
+ 'iterator': '__iter__', }
+
+class InstrumentedDict(dict):
+ """An instrumented version of the built-in dict."""
+
+ __instrumentation__ = {
+ 'iterator': 'itervalues', }
+
+__canned_instrumentation = {
+ list: InstrumentedList,
+ sautil.Set: InstrumentedSet,
+ dict: InstrumentedDict,
+ }
+
+__interfaces = {
+ list: { 'appender': 'append',
+ 'remover': 'remove',
+ 'iterator': '__iter__',
+ '_decorators': _list_decorators(), },
+ sautil.Set: { 'appender': 'add',
+ 'remover': 'remove',
+ 'iterator': '__iter__',
+ '_decorators': _set_decorators(), },
+ # decorators are required for dicts and object collections.
+ dict: { 'iterator': 'itervalues',
+ '_decorators': _dict_decorators(), },
+ # < 0.4 compatible naming, deprecated- use decorators instead.
+ None: { }
+ }
+
+
+class MappedCollection(dict):
+ """A basic dictionary-based collection class.
+
+ Extends dict with the minimal bag semantics that collection classes require.
+ ``set`` and ``remove`` are implemented in terms of a keying function: any
+ callable that takes an object and returns an object for use as a dictionary
+ key.
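+
+    For example (illustrative)::
+
+        emails = MappedCollection(keyfunc=lambda address: address.email)
+        emails.set(some_address)    # stored under some_address.email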
+ """
+
+ def __init__(self, keyfunc):
+ """Create a new collection with keying provided by keyfunc.
+
+        keyfunc may be any callable that takes an object and
+ returns an object for use as a dictionary key.
+
+ The keyfunc will be called every time the ORM needs to add a member by
+ value-only (such as when loading instances from the database) or remove
+        a member.  The usual cautions about dictionary keying apply:
+ ``keyfunc(object)`` should return the same output for the life of the
+ collection. Keying based on mutable properties can result in
+ unreachable instances "lost" in the collection.
+ """
+ self.keyfunc = keyfunc
+
+ def set(self, value, _sa_initiator=None):
+ """Add an item to the collection, with a key provided by this instance's keyfunc."""
+
+ key = self.keyfunc(value)
+ self.__setitem__(key, value, _sa_initiator)
+ set = collection.internally_instrumented(set)
+ set = collection.appender(set)
+
+ def remove(self, value, _sa_initiator=None):
+ """Remove an item from the collection by value, consulting this instance's keyfunc for the key."""
+
+ key = self.keyfunc(value)
+ # Let self[key] raise if key is not in this collection
+ if self[key] != value:
+ raise exceptions.InvalidRequestError(
+ "Can not remove '%s': collection holds '%s' for key '%s'. "
+ "Possible cause: is the MappedCollection key function "
+ "based on mutable properties or properties that only obtain "
+ "values after flush?" %
+ (value, self[key], key))
+ self.__delitem__(key, _sa_initiator)
+ remove = collection.internally_instrumented(remove)
+ remove = collection.remover(remove)
diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py
index 54b043b32..c06db6963 100644
--- a/lib/sqlalchemy/orm/dependency.py
+++ b/lib/sqlalchemy/orm/dependency.py
@@ -6,8 +6,8 @@
"""Bridge the ``PropertyLoader`` (i.e. a ``relation()``) and the
-``UOWTransaction`` together to allow processing of scalar- and
-list-based dependencies at flush time.
+``UOWTransaction`` together to allow processing of relation()-based
+dependencies at flush time.
"""
from sqlalchemy.orm import sync
@@ -366,7 +366,7 @@ class ManyToManyDP(DependencyProcessor):
if len(secondary_delete):
secondary_delete.sort()
# TODO: precompile the delete/insert queries?
- statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key, type=c.type) for c in self.secondary.c if c.key in associationrow]))
+ statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow]))
result = connection.execute(statement, secondary_delete)
if result.supports_sane_rowcount() and result.rowcount != len(secondary_delete):
raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (result.rowcount, len(secondary_delete)))
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index a9a26b57f..aeb8a23fa 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -5,7 +5,205 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import util, logging
+from sqlalchemy import util, logging, sql
+
+# returned by a MapperExtension method to indicate a "do nothing" response
+EXT_PASS = object()
+
+class MapperExtension(object):
+ """Base implementation for an object that provides overriding
+ behavior to various Mapper functions. For each method in
+ MapperExtension, a result of EXT_PASS indicates the functionality
+ is not overridden.
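+
+    A subclass overrides only the hooks it needs.  For example (an
+    illustrative sketch; ``AuditExtension`` and ``created_at`` are
+    hypothetical, not part of the library)::
+
+        class AuditExtension(MapperExtension):
+            def before_insert(self, mapper, connection, instance):
+                # assumes the datetime module is imported
+                instance.created_at = datetime.datetime.now()
+                return EXT_PASS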
+ """
+
+
+ def init_instance(self, mapper, class_, instance, args, kwargs):
+ return EXT_PASS
+
+ def init_failed(self, mapper, class_, instance, args, kwargs):
+ return EXT_PASS
+
+ def get_session(self):
+ """Retrieve a contextual Session instance with which to
+ register a new object.
+
+ Note: this is not called if a session is provided with the
+ `__init__` params (i.e. `_sa_session`).
+ """
+
+ return EXT_PASS
+
+ def load(self, query, *args, **kwargs):
+ """Override the `load` method of the Query object.
+
+ The return value of this method is used as the result of
+ ``query.load()`` if the value is anything other than EXT_PASS.
+ """
+
+ return EXT_PASS
+
+ def get(self, query, *args, **kwargs):
+ """Override the `get` method of the Query object.
+
+ The return value of this method is used as the result of
+ ``query.get()`` if the value is anything other than EXT_PASS.
+ """
+
+ return EXT_PASS
+
+ def get_by(self, query, *args, **kwargs):
+ """Override the `get_by` method of the Query object.
+
+ The return value of this method is used as the result of
+ ``query.get_by()`` if the value is anything other than
+ EXT_PASS.
+
+ DEPRECATED.
+ """
+
+ return EXT_PASS
+
+ def select_by(self, query, *args, **kwargs):
+ """Override the `select_by` method of the Query object.
+
+ The return value of this method is used as the result of
+ ``query.select_by()`` if the value is anything other than
+ EXT_PASS.
+
+ DEPRECATED.
+ """
+
+ return EXT_PASS
+
+ def select(self, query, *args, **kwargs):
+ """Override the `select` method of the Query object.
+
+ The return value of this method is used as the result of
+ ``query.select()`` if the value is anything other than
+ EXT_PASS.
+
+ DEPRECATED.
+ """
+
+ return EXT_PASS
+
+
+ def translate_row(self, mapper, context, row):
+ """Perform pre-processing on the given result row and return a
+ new row instance.
+
+ This is called as the very first step in the ``_instance()``
+ method.
+ """
+
+ return EXT_PASS
+
+ def create_instance(self, mapper, selectcontext, row, class_):
+ """Receive a row when a new object instance is about to be
+ created from that row.
+
+ The method can choose to create the instance itself, or it can
+ return None to indicate normal object creation should take
+ place.
+
+ mapper
+ The mapper doing the operation
+
+ selectcontext
+ SelectionContext corresponding to the instances() call
+
+ row
+ The result row from the database
+
+ class\_
+ The class we are mapping.
+ """
+
+ return EXT_PASS
+
+ def append_result(self, mapper, selectcontext, row, instance, result, **flags):
+ """Receive an object instance before that instance is appended
+ to a result list.
+
+ If this method returns EXT_PASS, result appending will proceed
+        normally.  If this method returns any other value or None,
+ result appending will not proceed for this instance, giving
+ this extension an opportunity to do the appending itself, if
+ desired.
+
+ mapper
+ The mapper doing the operation.
+
+ selectcontext
+ SelectionContext corresponding to the instances() call.
+
+ row
+ The result row from the database.
+
+ instance
+ The object instance to be appended to the result.
+
+ result
+ List to which results are being appended.
+
+ \**flags
+ extra information about the row, same as criterion in
+ `create_row_processor()` method of [sqlalchemy.orm.interfaces#MapperProperty]
+ """
+
+ return EXT_PASS
+
+ def populate_instance(self, mapper, selectcontext, row, instance, **flags):
+ """Receive a newly-created instance before that instance has
+ its attributes populated.
+
+ The normal population of attributes is according to each
+ attribute's corresponding MapperProperty (which includes
+ column-based attributes as well as relationships to other
+ classes). If this method returns EXT_PASS, instance
+ population will proceed normally. If any other value or None
+ is returned, instance population will not proceed, giving this
+ extension an opportunity to populate the instance itself, if
+ desired.
+ """
+
+ return EXT_PASS
+
+ def before_insert(self, mapper, connection, instance):
+ """Receive an object instance before that instance is INSERTed
+ into its table.
+
+ This is a good place to set up primary key values and such
+ that aren't handled otherwise.
+ """
+
+ return EXT_PASS
+
+ def before_update(self, mapper, connection, instance):
+ """Receive an object instance before that instance is UPDATEed."""
+
+ return EXT_PASS
+
+ def after_update(self, mapper, connection, instance):
+ """Receive an object instance after that instance is UPDATEed."""
+
+ return EXT_PASS
+
+ def after_insert(self, mapper, connection, instance):
+ """Receive an object instance after that instance is INSERTed."""
+
+ return EXT_PASS
+
+ def before_delete(self, mapper, connection, instance):
+ """Receive an object instance before that instance is DELETEed."""
+
+ return EXT_PASS
+
+ def after_delete(self, mapper, connection, instance):
+ """Receive an object instance after that instance is DELETEed."""
+
+ return EXT_PASS
class MapperProperty(object):
"""Manage the relationship of a ``Mapper`` to a single class
@@ -15,22 +213,61 @@ class MapperProperty(object):
"""
def setup(self, querycontext, **kwargs):
- """Called when a statement is being constructed."""
+ """Called by Query for the purposes of constructing a SQL statement.
+
+ Each MapperProperty associated with the target mapper processes the
+ statement referenced by the query context, adding columns and/or
+ criterion as appropriate.
+ """
pass
- def execute(self, selectcontext, instance, row, identitykey, isnew):
- """Called when the mapper receives a row.
-
- `instance` is the parent instance corresponding to the `row`.
+ def create_row_processor(self, selectcontext, mapper, row):
+ """return a 2-tuple consiting of a row processing function and an instance post-processing function.
+
+        Input arguments are the query.SelectionContext and the *first*
+        applicable row of a result set obtained within query.Query.instances().
+        This method is called only the first time a particular
+        mapper.populate_instance() is invoked for the overall result.
+
+ The settings contained within the SelectionContext as well as the columns present
+ in the row (which will be the same columns present in all rows) are used to determine
+ the behavior of the returned callables. The callables will then be used to process
+ all rows and to post-process all instances, respectively.
+
+        The callables are of the following form::
+
+ def execute(instance, row, **flags):
+ # process incoming instance and given row.
+ # flags is a dictionary containing at least the following attributes:
+ # isnew - indicates if the instance was newly created as a result of reading this row
+ # instancekey - identity key of the instance
+ # optional attribute:
+ # ispostselect - indicates if this row resulted from a 'post' select of additional tables/columns
+
+ def post_execute(instance, **flags):
+ # process instance after all result rows have been processed. this
+ # function should be used to issue additional selections in order to
+ # eagerly load additional properties.
+
+ return (execute, post_execute)
+
+        Either tuple value can also be ``None``, in which case no function is called.
+
"""
-
+
raise NotImplementedError()
-
+
def cascade_iterator(self, type, object, recursive=None, halt_on=None):
+ """return an iterator of objects which are child objects of the given object,
+ as attached to the attribute corresponding to this MapperProperty."""
+
return []
def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None):
+ """run the given callable across all objects which are child objects of
+ the given object, as attached to the attribute corresponding to this MapperProperty."""
+
return []
def get_criterion(self, query, key, value):
@@ -60,7 +297,11 @@ class MapperProperty(object):
self.do_init()
def do_init(self):
- """Template method for subclasses."""
+ """Perform subclass-specific initialization steps.
+
+ This is a *template* method called by the
+ ``MapperProperty`` object's init() method."""
+
pass
def register_dependencies(self, *args, **kwargs):
@@ -90,59 +331,81 @@ class MapperProperty(object):
raise NotImplementedError()
- def compare(self, value):
+ def compare(self, operator, value):
"""Return a compare operation for the columns represented by
this ``MapperProperty`` to the given value, which may be a
- column value or an instance.
+ column value or an instance. 'operator' is an operator from
+ the operators module, or from sql.Comparator.
+
+ By default uses the PropComparator attached to this MapperProperty
+ under the attribute name "comparator".
"""
- raise NotImplementedError()
+ return operator(self.comparator, value)
-class SynonymProperty(MapperProperty):
- def __init__(self, name, proxy=False):
- self.name = name
- self.proxy = proxy
-
- def setup(self, querycontext, **kwargs):
- pass
-
- def execute(self, selectcontext, instance, row, identitykey, isnew):
- pass
-
- def do_init(self):
- if not self.proxy:
- return
- class SynonymProp(object):
- def __set__(s, obj, value):
- setattr(obj, self.name, value)
- def __delete__(s, obj):
- delattr(obj, self.name)
- def __get__(s, obj, owner):
- if obj is None:
- return s
- return getattr(obj, self.name)
- setattr(self.parent.class_, self.key, SynonymProp())
+class PropComparator(sql.ColumnOperators):
+ """defines comparison operations for MapperProperty objects"""
+
+ def expression_element(self):
+ return self.clause_element()
+
+ def contains_op(a, b):
+ return a.contains(b)
+ contains_op = staticmethod(contains_op)
+
+ def any_op(a, b, **kwargs):
+ return a.any(b, **kwargs)
+ any_op = staticmethod(any_op)
+
+ def has_op(a, b, **kwargs):
+ return a.has(b, **kwargs)
+ has_op = staticmethod(has_op)
+
+ def __init__(self, prop):
+ self.prop = prop
+
+ def contains(self, other):
+ """return true if this collection contains other"""
+ return self.operate(PropComparator.contains_op, other)
+
+ def any(self, criterion=None, **kwargs):
+ """return true if this collection contains any member that meets the given criterion.
+
+ criterion
+ an optional ClauseElement formulated against the member class' table or attributes.
+
+ \**kwargs
+ key/value pairs corresponding to member class attribute names which will be compared
+ via equality to the corresponding values.
+ """
- def merge(self, session, source, dest, _recursive):
- pass
+ return self.operate(PropComparator.any_op, criterion, **kwargs)
+
+ def has(self, criterion=None, **kwargs):
+ """return true if this element references a member which meets the given criterion.
+
+
+ criterion
+ an optional ClauseElement formulated against the member class' table or attributes.
+
+ \**kwargs
+ key/value pairs corresponding to member class attribute names which will be compared
+ via equality to the corresponding values.
+ """
+ return self.operate(PropComparator.has_op, criterion, **kwargs)
+
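+# Usage sketch for the operators above (illustrative; assumes mapped
+# Parent/Child classes with a 'children' relation and a scalar 'parent'
+# relation):
+#
+#     session.query(Parent).filter(Parent.children.contains(some_child))
+#     session.query(Parent).filter(Parent.children.any(Child.name == 'foo'))
+#     session.query(Child).filter(Child.parent.has(Parent.id == 7))
+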
class StrategizedProperty(MapperProperty):
"""A MapperProperty which uses selectable strategies to affect
loading behavior.
- There is a single default strategy selected, and alternate
- strategies can be selected at selection time through the usage of
- ``StrategizedOption`` objects.
+ There is a single default strategy selected by default. Alternate
+ strategies can be selected at Query time through the usage of
+ ``StrategizedOption`` objects via the Query.options() method.
"""
def _get_context_strategy(self, context):
- try:
- return context.attributes[id(self)]
- except KeyError:
- # cache the located strategy per StrategizedProperty in the given context for faster re-lookup
- ctx_strategy = self._get_strategy(context.attributes.get((LoaderStrategy, self), self.strategy.__class__))
- context.attributes[id(self)] = ctx_strategy
- return ctx_strategy
+ return self._get_strategy(context.attributes.get(("loaderstrategy", self), self.strategy.__class__))
def _get_strategy(self, cls):
try:
@@ -156,11 +419,10 @@ class StrategizedProperty(MapperProperty):
return strategy
def setup(self, querycontext, **kwargs):
-
self._get_context_strategy(querycontext).setup_query(querycontext, **kwargs)
- def execute(self, selectcontext, instance, row, identitykey, isnew):
- self._get_context_strategy(selectcontext).process_row(selectcontext, instance, row, identitykey, isnew)
+ def create_row_processor(self, selectcontext, mapper, row):
+ return self._get_context_strategy(selectcontext).create_row_processor(selectcontext, mapper, row)
def do_init(self):
self._all_strategies = {}
@@ -170,6 +432,31 @@ class StrategizedProperty(MapperProperty):
if self.is_primary():
self.strategy.init_class_attribute()
+class LoaderStack(object):
+ """a stack object used during load operations to track the
+ current position among a chain of mappers to eager loaders."""
+
+ def __init__(self):
+ self.__stack = []
+
+ def push_property(self, key):
+ self.__stack.append(key)
+
+ def push_mapper(self, mapper):
+ self.__stack.append(mapper.base_mapper())
+
+ def pop(self):
+ self.__stack.pop()
+
+ def snapshot(self):
+ """return an 'snapshot' of this stack.
+
+ this is a tuple form of the stack which can be used as a hash key."""
+ return tuple(self.__stack)
+
+ def __str__(self):
+ return "->".join([str(s) for s in self.__stack])
+
class OperationContext(object):
"""Serve as a context during a query construction or instance
loading operation.
@@ -200,6 +487,44 @@ class MapperOption(object):
def process_query(self, query):
pass
+class ExtensionOption(MapperOption):
+ """a MapperOption that applies a MapperExtension to a query operation."""
+
+ def __init__(self, ext):
+ self.ext = ext
+
+ def process_query(self, query):
+ query._extension = query._extension.copy()
+ query._extension.append(self.ext)
+
+class SynonymProperty(MapperProperty):
+ def __init__(self, name, proxy=False):
+ self.name = name
+ self.proxy = proxy
+
+ def setup(self, querycontext, **kwargs):
+ pass
+
+ def create_row_processor(self, selectcontext, mapper, row):
+ return (None, None)
+
+ def do_init(self):
+ if not self.proxy:
+ return
+ class SynonymProp(object):
+ def __set__(s, obj, value):
+ setattr(obj, self.name, value)
+ def __delete__(s, obj):
+ delattr(obj, self.name)
+ def __get__(s, obj, owner):
+ if obj is None:
+ return s
+ return getattr(obj, self.name)
+ setattr(self.parent.class_, self.key, SynonymProp())
+
+ def merge(self, session, source, dest, _recursive):
+ pass
+
class PropertyOption(MapperOption):
"""A MapperOption that is applied to a property off the mapper or
one of its child mappers, identified by a dot-separated key.
@@ -208,45 +533,72 @@ class PropertyOption(MapperOption):
def __init__(self, key):
self.key = key
- def process_query_property(self, context, property):
+ def process_query_property(self, context, properties):
pass
- def process_selection_property(self, context, property):
+ def process_selection_property(self, context, properties):
pass
def process_query_context(self, context):
- self.process_query_property(context, self._get_property(context))
+ self.process_query_property(context, self._get_properties(context))
def process_selection_context(self, context):
- self.process_selection_property(context, self._get_property(context))
+ self.process_selection_property(context, self._get_properties(context))
- def _get_property(self, context):
+ def _get_properties(self, context):
try:
- prop = self.__prop
+ l = self.__prop
except AttributeError:
+ l = []
mapper = context.mapper
for token in self.key.split('.'):
- prop = mapper.props[token]
- if isinstance(prop, SynonymProperty):
- prop = mapper.props[prop.name]
+ prop = mapper.get_property(token, resolve_synonyms=True)
+ l.append(prop)
mapper = getattr(prop, 'mapper', None)
- self.__prop = prop
- return prop
+ self.__prop = l
+ return l
PropertyOption.logger = logging.class_logger(PropertyOption)
+
+class AttributeExtension(object):
+ """An abstract class which specifies `append`, `delete`, and `set`
+ event handlers to be attached to an object property.
+ """
+
+ def append(self, obj, child, initiator):
+ pass
+
+ def remove(self, obj, child, initiator):
+ pass
+
+ def set(self, obj, child, oldchild, initiator):
+ pass
+
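+# Illustrative sketch (the 'child_count' attribute is hypothetical): an
+# AttributeExtension that maintains a running count on the parent object:
+#
+#     class CountingExtension(AttributeExtension):
+#         def append(self, obj, child, initiator):
+#             obj.child_count += 1
+#
+#         def remove(self, obj, child, initiator):
+#             obj.child_count -= 1
+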
+
class StrategizedOption(PropertyOption):
"""A MapperOption that affects which LoaderStrategy will be used
for an operation by a StrategizedProperty.
"""
- def process_query_property(self, context, property):
+ def is_chained(self):
+ return False
+
+ def process_query_property(self, context, properties):
self.logger.debug("applying option to QueryContext, property key '%s'" % self.key)
- context.attributes[(LoaderStrategy, property)] = self.get_strategy_class()
+ if self.is_chained():
+ for prop in properties:
+ context.attributes[("loaderstrategy", prop)] = self.get_strategy_class()
+ else:
+ context.attributes[("loaderstrategy", properties[-1])] = self.get_strategy_class()
- def process_selection_property(self, context, property):
+ def process_selection_property(self, context, properties):
self.logger.debug("applying option to SelectionContext, property key '%s'" % self.key)
- context.attributes[(LoaderStrategy, property)] = self.get_strategy_class()
+ if self.is_chained():
+ for prop in properties:
+ context.attributes[("loaderstrategy", prop)] = self.get_strategy_class()
+ else:
+ context.attributes[("loaderstrategy", properties[-1])] = self.get_strategy_class()
def get_strategy_class(self):
raise NotImplementedError()
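# Sketch of the chained/non-chained behavior above using plain dicts: a
# chained option applies its strategy to every property along a dotted
# path, a non-chained option only to the terminal property.  The string
# properties stand in for MapperProperty objects.
def apply_option(attributes, properties, strategy_cls, chained):
    if chained:
        targets = properties
    else:
        targets = properties[-1:]
    for prop in targets:
        attributes[("loaderstrategy", prop)] = strategy_cls

attributes = {}
apply_option(attributes, ['orders', 'items'], 'EagerLoader', chained=False)
assert ("loaderstrategy", 'orders') not in attributes
assert attributes[("loaderstrategy", 'items')] == 'EagerLoader'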
@@ -291,5 +643,13 @@ class LoaderStrategy(object):
def setup_query(self, context, **kwargs):
pass
- def process_row(self, selectcontext, instance, row, identitykey, isnew):
- pass
+ def create_row_processor(self, selectcontext, mapper, row):
+ """return row processing functions which fulfill the contract specified
+ by MapperProperty.create_row_processor.
+
+
+ StrategizedProperty delegates its create_row_processor method
+ directly to this method.
+ """
+
+ raise NotImplementedError()
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 375408926..76cc41289 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -4,14 +4,15 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import sql, schema, util, exceptions, logging
+from sqlalchemy import sql, util, exceptions, logging
from sqlalchemy import sql_util as sqlutil
from sqlalchemy.orm import util as mapperutil
+from sqlalchemy.orm.util import ExtensionCarrier
from sqlalchemy.orm import sync
-from sqlalchemy.orm.interfaces import MapperProperty, MapperOption, OperationContext, SynonymProperty
-import weakref, warnings
+from sqlalchemy.orm.interfaces import MapperProperty, EXT_PASS, MapperExtension, SynonymProperty
+import weakref, warnings, operator
-__all__ = ['Mapper', 'MapperExtension', 'class_mapper', 'object_mapper', 'EXT_PASS', 'mapper_registry', 'ExtensionOption']
+__all__ = ['Mapper', 'class_mapper', 'object_mapper', 'mapper_registry']
# a dictionary mapping classes to their primary mappers
mapper_registry = weakref.WeakKeyDictionary()
@@ -24,12 +25,13 @@ global_extensions = []
# column
NO_ATTRIBUTE = object()
-# returned by a MapperExtension method to indicate a "do nothing" response
-EXT_PASS = object()
-
# lock used to synchronize the "mapper compile" step
_COMPILE_MUTEX = util.threading.Lock()
+# initialize these two lazily
+attribute_manager = None
+ColumnProperty = None
+
class Mapper(object):
"""Define the correlation of class attributes to database table
columns.
@@ -55,6 +57,7 @@ class Mapper(object):
polymorphic_on=None,
_polymorphic_map=None,
polymorphic_identity=None,
+ polymorphic_fetch=None,
concrete=False,
select_table=None,
allow_null_pks=False,
@@ -62,126 +65,8 @@ class Mapper(object):
column_prefix=None):
"""Construct a new mapper.
- All arguments may be sent to the ``sqlalchemy.orm.mapper()``
- function where they are passed through to here.
-
- class\_
- The class to be mapped.
-
- local_table
- The table to which the class is mapped, or None if this
- mapper inherits from another mapper using concrete table
- inheritance.
-
- properties
- A dictionary mapping the string names of object attributes
- to ``MapperProperty`` instances, which define the
- persistence behavior of that attribute. Note that the
- columns in the mapped table are automatically converted into
- ``ColumnProperty`` instances based on the `key` property of
- each ``Column`` (although they can be overridden using this
- dictionary).
-
- primary_key
- A list of ``Column`` objects which define the *primary key*
- to be used against this mapper's selectable unit. This is
- normally simply the primary key of the `local_table`, but
- can be overridden here.
-
- non_primary
- Construct a ``Mapper`` that will define only the selection
- of instances, not their persistence.
-
- inherits
- Another ``Mapper`` for which this ``Mapper`` will have an
- inheritance relationship with.
-
- inherit_condition
- For joined table inheritance, a SQL expression (constructed
- ``ClauseElement``) which will define how the two tables are
- joined; defaults to a natural join between the two tables.
-
- extension
- A ``MapperExtension`` instance or list of
- ``MapperExtension`` instances which will be applied to all
- operations by this ``Mapper``.
-
- order_by
- A single ``Column`` or list of ``Columns`` for which
- selection operations should use as the default ordering for
- entities. Defaults to the OID/ROWID of the table if any, or
- the first primary key column of the table.
-
- allow_column_override
- If True, allows the usage of a ``relation()`` which has the
- same name as a column in the mapped table. The table column
- will no longer be mapped.
-
- entity_name
- A name to be associated with the `class`, to allow alternate
- mappings for a single class.
-
- always_refresh
- If True, all query operations for this mapped class will
- overwrite all data within object instances that already
- exist within the session, erasing any in-memory changes with
- whatever information was loaded from the database.
-
- version_id_col
- A ``Column`` which must have an integer type that will be
- used to keep a running *version id* of mapped entities in
- the database. this is used during save operations to ensure
- that no other thread or process has updated the instance
- during the lifetime of the entity, else a
- ``ConcurrentModificationError`` exception is thrown.
-
- polymorphic_on
- Used with mappers in an inheritance relationship, a ``Column``
- which will identify the class/mapper combination to be used
- with a particular row. requires the polymorphic_identity
- value to be set for all mappers in the inheritance
- hierarchy.
-
- _polymorphic_map
- Used internally to propigate the full map of polymorphic
- identifiers to surrogate mappers.
-
- polymorphic_identity
- A value which will be stored in the Column denoted by
- polymorphic_on, corresponding to the *class identity* of
- this mapper.
-
- concrete
- If True, indicates this mapper should use concrete table
- inheritance with its parent mapper.
-
- select_table
- A ``Table`` or (more commonly) ``Selectable`` which will be
- used to select instances of this mapper's class. usually
- used to provide polymorphic loading among several classes in
- an inheritance hierarchy.
-
- allow_null_pks
- Indicates that composite primary keys where one or more (but
- not all) columns contain NULL is a valid primary key.
- Primary keys which contain NULL values usually indicate that
- a result row does not contain an entity and should be
- skipped.
-
- batch
- Indicates that save operations of multiple entities can be
- batched together for efficiency. setting to False indicates
- that an instance will be fully saved before saving the next
- instance, which includes inserting/updating all table rows
- corresponding to the entity as well as calling all
- ``MapperExtension`` methods corresponding to the save
- operation.
-
- column_prefix
- A string which will be prepended to the `key` name of all
- Columns when creating column-based properties from the given
- Table. Does not affect explicitly specified column-based
- properties
+ Mappers are normally constructed via the [sqlalchemy.orm#mapper()]
+ function; see that function's documentation for details.
"""
if not issubclass(class_, object):
@@ -227,6 +112,13 @@ class Mapper(object):
# indicates this Mapper should be used to construct the object instance for that row.
self.polymorphic_identity = polymorphic_identity
+ if polymorphic_fetch not in (None, 'union', 'select', 'deferred'):
+ raise exceptions.ArgumentError("Invalid option for 'polymorphic_fetch': '%s'" % polymorphic_fetch)
+ if polymorphic_fetch is None:
+ self.polymorphic_fetch = (self.select_table is None) and 'select' or 'union'
+ else:
+ self.polymorphic_fetch = polymorphic_fetch
+
# a dictionary of 'polymorphic identity' names, associating those names with
# Mappers that will be used to construct object instances upon a select operation.
if _polymorphic_map is None:
@@ -297,20 +189,8 @@ class Mapper(object):
else:
return False
- def _get_props(self):
- self.compile()
- return self.__props
-
- props = property(_get_props, doc="compiles this mapper if needed, and returns the "
- "dictionary of MapperProperty objects associated with this mapper."
- "(Deprecated; use get_property() and iterate_properties)")
-
def get_property(self, key, resolve_synonyms=False, raiseerr=True):
- """return MapperProperty with the given key.
-
- forwards compatible with 0.4.
- """
-
+ """return MapperProperty with the given key."""
self.compile()
prop = self.__props.get(key, None)
if resolve_synonyms:
@@ -319,10 +199,22 @@ class Mapper(object):
if prop is None and raiseerr:
raise exceptions.InvalidRequestError("Mapper '%s' has no property '%s'" % (str(self), key))
return prop
-
- iterate_properties = property(lambda self: self._get_props().itervalues(), doc="returns an iterator of all MapperProperty objects."
- " Forwards compatible with 0.4")
-
+
+ def iterate_properties(self):
+ self.compile()
+ return self.__props.itervalues()
+ iterate_properties = property(iterate_properties, doc="returns an iterator of all MapperProperty objects.")
+
+ def dispose(self):
+ attribute_manager.reset_class_managed(self.class_)
+ if hasattr(self.class_, 'c'):
+ del self.class_.c
+ if hasattr(self.class_, '__init__') and hasattr(self.class_.__init__, '_oldinit'):
+ if self.class_.__init__._oldinit is not None:
+ self.class_.__init__ = self.class_.__init__._oldinit
+ else:
+ delattr(self.class_, '__init__')
+
def compile(self):
"""Compile this mapper into its final internal format.
@@ -403,7 +295,7 @@ class Mapper(object):
for ext_obj in util.to_list(extension):
extlist.add(ext_obj)
- self.extension = _ExtensionCarrier()
+ self.extension = ExtensionCarrier()
for ext in extlist:
self.extension.append(ext)
@@ -452,8 +344,17 @@ class Mapper(object):
self.mapped_table = self.local_table
if self.polymorphic_identity is not None:
self.inherits._add_polymorphic_mapping(self.polymorphic_identity, self)
- if self.polymorphic_on is None and self.inherits.polymorphic_on is not None:
- self.polymorphic_on = self.mapped_table.corresponding_column(self.inherits.polymorphic_on, keys_ok=True, raiseerr=False)
+ if self.polymorphic_on is None:
+ if self.inherits.polymorphic_on is not None:
+ self.polymorphic_on = self.mapped_table.corresponding_column(self.inherits.polymorphic_on, keys_ok=True, raiseerr=False)
+ else:
+ raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity))
+
+ if self.polymorphic_identity is not None and not self.concrete:
+ self._identity_class = self.inherits._identity_class
+ else:
+ self._identity_class = self.class_
+
if self.order_by is False:
self.order_by = self.inherits.order_by
self.polymorphic_map = self.inherits.polymorphic_map
@@ -463,8 +364,11 @@ class Mapper(object):
self._synchronizer = None
self.mapped_table = self.local_table
if self.polymorphic_identity is not None:
+ if self.polymorphic_on is None:
+ raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity))
self._add_polymorphic_mapping(self.polymorphic_identity, self)
-
+ self._identity_class = self.class_
+
if self.mapped_table is None:
raise exceptions.ArgumentError("Mapper '%s' does not have a mapped_table specified. (Are you using the return value of table.create()? It no longer has a return value.)" % str(self))
@@ -503,39 +407,134 @@ class Mapper(object):
# may be a join or other construct
self.tables = sqlutil.TableFinder(self.mapped_table)
- # determine primary key columns, either passed in, or get them from our set of tables
+ # determine primary key columns
self.pks_by_table = {}
+
+ # go through all of our represented tables
+ # and assemble primary key columns
+ for t in self.tables + [self.mapped_table]:
+ try:
+ l = self.pks_by_table[t]
+ except KeyError:
+ l = self.pks_by_table.setdefault(t, util.OrderedSet())
+ for k in t.primary_key:
+ l.add(k)
+
if self.primary_key_argument is not None:
- # determine primary keys using user-given list of primary key columns as a guide
- #
- # TODO: this might not work very well for joined-table and/or polymorphic
- # inheritance mappers since local_table isnt taken into account nor is select_table
- # need to test custom primary key columns used with inheriting mappers
for k in self.primary_key_argument:
self.pks_by_table.setdefault(k.table, util.OrderedSet()).add(k)
- if k.table != self.mapped_table:
- # associate pk cols from subtables to the "main" table
- corr = self.mapped_table.corresponding_column(k, raiseerr=False)
- if corr is not None:
- self.pks_by_table.setdefault(self.mapped_table, util.OrderedSet()).add(corr)
- else:
- # no user-defined primary key columns - go through all of our represented tables
- # and assemble primary key columns
- for t in self.tables + [self.mapped_table]:
- try:
- l = self.pks_by_table[t]
- except KeyError:
- l = self.pks_by_table.setdefault(t, util.OrderedSet())
- for k in t.primary_key:
- #if k.key not in t.c and k._label not in t.c:
- # this is a condition that was occurring when table reflection was doubling up primary keys
- # that were overridden in the Table constructor
- # raise exceptions.AssertionError("Column " + str(k) + " not located in the column set of table " + str(t))
- l.add(k)
-
+
if len(self.pks_by_table[self.mapped_table]) == 0:
raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name))
- self.primary_key = self.pks_by_table[self.mapped_table]
+
+ if self.inherits is not None and not self.concrete and not self.primary_key_argument:
+ self.primary_key = self.inherits.primary_key
+ self._get_clause = self.inherits._get_clause
+ else:
+ # create the "primary_key" for this mapper. this will flatten "equivalent" primary key columns
+ # into one column, where "equivalent" means that one column references the other via foreign key, or
+ # multiple columns that all reference a common parent column. it will also resolve the column
+ # against the "mapped_table" of this mapper.
+ equivalent_columns = self._get_equivalent_columns()
+
+ primary_key = sql.ColumnSet()
+
+ for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]):
+ c = self.mapped_table.corresponding_column(col, raiseerr=False)
+ if c is None:
+ for cc in equivalent_columns[col]:
+ c = self.mapped_table.corresponding_column(cc, raiseerr=False)
+ if c is not None:
+ break
+ else:
+ raise exceptions.ArgumentError("Cant resolve column " + str(col))
+
+ # this step attempts to resolve the column to an equivalent which is not
+ # a foreign key elsewhere. this helps with joined table inheritance
+ # so that PKs are expressed in terms of the base table which is always
+ # present in the initial select
+ # TODO: this is a little hacky right now, the "tried" list is to prevent
+ # endless loops between cyclical FKs, try to make this cleaner/work better/etc.,
+ # perhaps via topological sort (pick the leftmost item)
+ tried = util.Set()
+ while True:
+ if not len(c.foreign_keys) or c in tried:
+ break
+ for cc in c.foreign_keys:
+ cc = cc.column
+ c2 = self.mapped_table.corresponding_column(cc, raiseerr=False)
+ if c2 is not None:
+ c = c2
+ tried.add(c)
+ break
+ else:
+ break
+ primary_key.add(c)
+
+ if len(primary_key) == 0:
+ raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name))
+
+ self.primary_key = primary_key
+ self.__log("Identified primary key columns: " + str(primary_key))
+
+ _get_clause = sql.and_()
+ for primary_key in self.primary_key:
+ _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type_=primary_key.type, unique=True))
+ self._get_clause = _get_clause
+
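# Standalone sketch of the FK-chase loop above: follow a column's foreign
# key references toward a base column, with a visited set guarding against
# cyclical foreign keys.  Columns are modeled as a simple child->parent
# mapping for illustration.
def resolve_base(col, references):
    tried = set()
    while col in references and col not in tried:
        tried.add(col)
        col = references[col]
    return col

references = {'engineers.id': 'people.id'}  # engineers.id REFERENCES people.id
assert resolve_base('engineers.id', references) == 'people.id'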
+ def _get_equivalent_columns(self):
+ """Create a map of all *equivalent* columns, based on
+ the determination of column pairs that are equated to
+ one another either by an established foreign key relationship
+ or by a joined-table inheritance join.
+
+ This is used to determine the minimal set of primary key
+ columns for the mapper, as well as when relating
+ columns to those of a polymorphic selectable (i.e. a UNION of
+ several mapped tables), as that selectable usually only contains
+ one column in its columns clause out of a group of several which
+ are equated to each other.
+
+ The resulting structure is a dictionary of columns mapped
+ to sets of equivalent columns, i.e.
+
+ {
+ tablea.col1:
+ set([tableb.col1, tablec.col1]),
+ tablea.col2:
+ set([tabled.col2])
+ }
+
+ This method is called repeatedly during the compilation process, since
+ the resulting dictionary gains more equivalents as more inheriting
+ mappers are compiled.  The repetition may be open to some optimization.
+ """
+
+ result = {}
+ def visit_binary(binary):
+ if binary.operator == operator.eq:
+ if binary.left in result:
+ result[binary.left].add(binary.right)
+ else:
+ result[binary.left] = util.Set([binary.right])
+ if binary.right in result:
+ result[binary.right].add(binary.left)
+ else:
+ result[binary.right] = util.Set([binary.left])
+ vis = mapperutil.BinaryVisitor(visit_binary)
+
+ for mapper in self.base_mapper().polymorphic_iterator():
+ if mapper.inherit_condition is not None:
+ vis.traverse(mapper.inherit_condition)
+
+ for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]):
+ if not len(col.foreign_keys):
+ result.setdefault(col, util.Set()).add(col)
+ else:
+ for fk in col.foreign_keys:
+ result.setdefault(fk.column, util.Set()).add(col)
+
+ return result
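# Minimal sketch of the equivalence-map structure documented above, built
# from explicit (left, right) column pairs rather than by traversing
# inherit conditions.  Illustrative only.
def build_equivalents(pairs):
    result = {}
    for left, right in pairs:
        result.setdefault(left, set()).add(right)
        result.setdefault(right, set()).add(left)
    return result

equiv = build_equivalents([('people.id', 'engineers.id'),
                           ('people.id', 'managers.id')])
assert equiv['people.id'] == set(['engineers.id', 'managers.id'])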
def _compile_properties(self):
"""Inspect the properties dictionary sent to the Mapper's
@@ -548,7 +547,7 @@ class Mapper(object):
"""
# object attribute names mapped to MapperProperty objects
- self.__props = {}
+ self.__props = util.OrderedDict()
# table columns mapped to lists of MapperProperty objects
# using a list allows a single column to be defined as
@@ -574,7 +573,7 @@ class Mapper(object):
self.columns[column.key] = self.select_table.corresponding_column(column, keys_ok=True, raiseerr=True)
column_key = (self.column_prefix or '') + column.key
- prop = self.__props.get(column_key, None)
+ prop = self.__props.get(column.key, None)
if prop is None:
prop = ColumnProperty(column)
self.__props[column_key] = prop
@@ -582,7 +581,7 @@ class Mapper(object):
self.__log("adding ColumnProperty %s" % (column_key))
elif isinstance(prop, ColumnProperty):
if prop.parent is not self:
- prop = ColumnProperty(deferred=prop.deferred, group=prop.group, *prop.columns)
+ prop = prop.copy()
prop.set_parent(self)
self.__props[column_key] = prop
if column in self.primary_key and prop.columns[-1] in self.primary_key:
@@ -597,8 +596,7 @@ class Mapper(object):
# its a ColumnProperty - match the ultimate table columns
# back to the property
- proplist = self.columntoproperty.setdefault(column, [])
- proplist.append(prop)
+ self.columntoproperty.setdefault(column, []).append(prop)
def _initialize_properties(self):
@@ -660,66 +658,43 @@ class Mapper(object):
attribute_manager.reset_class_managed(self.class_)
oldinit = self.class_.__init__
- def init(self, *args, **kwargs):
- entity_name = kwargs.pop('_sa_entity_name', None)
- mapper = mapper_registry.get(ClassKey(self.__class__, entity_name))
- if mapper is not None:
- mapper = mapper.compile()
-
- # this gets the AttributeManager to do some pre-initialization,
- # in order to save on KeyErrors later on
- attribute_manager.init_attr(self)
-
- if kwargs.has_key('_sa_session'):
- session = kwargs.pop('_sa_session')
- else:
- # works for whatever mapper the class is associated with
- if mapper is not None:
- session = mapper.extension.get_session()
- if session is EXT_PASS:
- session = None
- else:
- session = None
- # if a session was found, either via _sa_session or via mapper extension,
- # and we have found a mapper, save() this instance to the session, and give it an associated entity_name.
- # otherwise, this instance will not have a session or mapper association until it is
- # save()d to some session.
- if session is not None and mapper is not None:
- self._entity_name = entity_name
- session._register_pending(self)
+ def init(instance, *args, **kwargs):
+ self.compile()
+ self.extension.init_instance(self, self.class_, instance, args, kwargs)
if oldinit is not None:
try:
- oldinit(self, *args, **kwargs)
+ oldinit(instance, *args, **kwargs)
except:
- def go():
- if session is not None:
- session.expunge(self)
- # convert expunge() exceptions to warnings
- util.warn_exception(go)
+ # call init_failed but suppress exceptions into warnings so that original __init__
+ # exception is raised
+ util.warn_exception(self.extension.init_failed, self, self.class_, instance, args, kwargs)
raise
-
- # override oldinit, insuring that its not already a Mapper-decorated init method
- if oldinit is None or not hasattr(oldinit, '_sa_mapper_init'):
- init._sa_mapper_init = True
+
+ # override oldinit, ensuring that its not already a Mapper-decorated init method
+ if oldinit is None or not hasattr(oldinit, '_oldinit'):
try:
init.__name__ = oldinit.__name__
init.__doc__ = oldinit.__doc__
except:
# can't set __name__ in py 2.3!
pass
+ init._oldinit = oldinit
self.class_.__init__ = init
+
_COMPILE_MUTEX.acquire()
try:
mapper_registry[self.class_key] = self
finally:
_COMPILE_MUTEX.release()
+
if self.entity_name is None:
self.class_.c = self.c
def base_mapper(self):
"""Return the ultimate base mapper in an inheritance chain."""
+ # TODO: calculate this at mapper setup time
if self.inherits is not None:
return self.inherits.base_mapper()
else:
@@ -759,43 +734,6 @@ class Mapper(object):
for m in mapper.polymorphic_iterator():
yield m
- def _get_inherited_column_equivalents(self):
- """Return a map of all *equivalent* columns, based on
- traversing the full set of inherit_conditions across all
- inheriting mappers and determining column pairs that are
- equated to one another.
-
- This is used when relating columns to those of a polymorphic
- selectable, as the selectable usually only contains one of two (or more)
- columns that are equated to one another.
-
- The resulting structure is a dictionary of columns mapped
- to lists of equivalent columns, i.e.
-
- {
- tablea.col1:
- [tableb.col1, tablec.col1],
- tablea.col2:
- [tabled.col2]
- }
- """
-
- result = {}
- def visit_binary(binary):
- if binary.operator == '=':
- if binary.left in result:
- result[binary.left].append(binary.right)
- else:
- result[binary.left] = [binary.right]
- if binary.right in result:
- result[binary.right].append(binary.left)
- else:
- result[binary.right] = [binary.left]
- vis = mapperutil.BinaryVisitor(visit_binary)
- for mapper in self.base_mapper().polymorphic_iterator():
- if mapper.inherit_condition is not None:
- vis.traverse(mapper.inherit_condition)
- return result
def add_properties(self, dict_of_properties):
"""Add the given dictionary of properties to this mapper,
@@ -947,7 +885,7 @@ class Mapper(object):
dictionary corresponding result-set ``ColumnElement``
instances to their values within a row.
"""
- return (self.class_, tuple([row[column] for column in self.pks_by_table[self.mapped_table]]), self.entity_name)
+ return (self._identity_class, tuple([row[column] for column in self.primary_key]), self.entity_name)
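# The identity key produced above is a plain hashable triple.  A sketch of
# its shape (the helper below is illustrative, not the mapper's API):
def identity_key(cls, pk_values, entity_name=None):
    return (cls, tuple(pk_values), entity_name)

class User(object):
    pass

assert identity_key(User, [5]) == (User, (5,), None)
# using _identity_class rather than class_ means all mappers in a
# joined-table hierarchy share one identity-map namespace per base row.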
def identity_key_from_primary_key(self, primary_key):
"""Return an identity-map key for use in storing/retrieving an
@@ -956,7 +894,7 @@ class Mapper(object):
primary_key
A list of values indicating the identifier.
"""
- return (self.class_, tuple(util.to_list(primary_key)), self.entity_name)
+ return (self._identity_class, tuple(util.to_list(primary_key)), self.entity_name)
def identity_key_from_instance(self, instance):
"""Return the identity key for the given instance, based on
@@ -972,7 +910,7 @@ class Mapper(object):
instance.
"""
- return [self.get_attr_by_column(instance, column) for column in self.pks_by_table[self.mapped_table]]
+ return [self.get_attr_by_column(instance, column) for column in self.primary_key]
def canload(self, instance):
"""return true if this mapper is capable of loading the given instance"""
@@ -981,21 +919,6 @@ class Mapper(object):
else:
return instance.__class__ is self.class_
- def instance_key(self, instance):
- """Deprecated. A synonym for `identity_key_from_instance`."""
-
- return self.identity_key_from_instance(instance)
-
- def identity_key(self, primary_key):
- """Deprecated. A synonym for `identity_key_from_primary_key`."""
-
- return self.identity_key_from_primary_key(primary_key)
-
- def identity(self, instance):
- """Deprecated. A synoynm for `primary_key_from_instance`."""
-
- return self.primary_key_from_instance(instance)
-
def _getpropbycolumn(self, column, raiseerror=True):
try:
prop = self.columntoproperty[column]
@@ -1017,13 +940,12 @@ class Mapper(object):
prop = self._getpropbycolumn(column, raiseerror)
if prop is None:
return NO_ATTRIBUTE
- #print "get column attribute '%s' from instance %s" % (column.key, mapperutil.instance_str(obj))
- return prop.getattr(obj)
+ return prop.getattr(obj, column)
def set_attr_by_column(self, obj, column, value):
"""Set the value of an instance attribute using a Column as the key."""
- self.columntoproperty[column][0].setattr(obj, value)
+ self.columntoproperty[column][0].setattr(obj, value, column)
def save_obj(self, objects, uowtransaction, postupdate=False, post_update_cols=None, single=False):
"""Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects.
@@ -1048,10 +970,15 @@ class Mapper(object):
self.save_obj([obj], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True)
return
- connection = uowtransaction.transaction.connection(self)
-
+ if 'connection_callable' in uowtransaction.mapper_flush_opts:
+ connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
+ tups = [(obj, connection_callable(self, obj)) for obj in objects]
+ else:
+ connection = uowtransaction.transaction.connection(self)
+ tups = [(obj, connection) for obj in objects]
+
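# A hypothetical 'connection_callable', sketching how the per-object
# connection routing above might be used for horizontal sharding.  The
# 'engines' dict and the 'shard_id' attribute are assumptions for
# illustration.
engines = {}  # assumed to map shard ids to Engine instances

def shard_connection(mapper, instance):
    # route each instance to the engine for its shard
    shard_id = getattr(instance, 'shard_id', 'default')
    return engines[shard_id].connect()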
if not postupdate:
- for obj in objects:
+ for obj, connection in tups:
if not has_identity(obj):
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.before_insert(mapper, connection, obj)
@@ -1059,12 +986,12 @@ class Mapper(object):
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.before_update(mapper, connection, obj)
- for obj in objects:
+ for obj, connection in tups:
# detect if we have a "pending" instance (i.e. has no instance_key attached to it),
# and another instance with the same identity key already exists as persistent. convert to an
# UPDATE if so.
mapper = object_mapper(obj)
- instance_key = mapper.instance_key(obj)
+ instance_key = mapper.identity_key_from_instance(obj)
is_row_switch = not postupdate and not has_identity(obj) and instance_key in uowtransaction.uow.identity_map
if is_row_switch:
existing = uowtransaction.uow.identity_map[instance_key]
@@ -1090,11 +1017,11 @@ class Mapper(object):
insert = []
update = []
- for obj in objects:
+ for obj, connection in tups:
mapper = object_mapper(obj)
if table not in mapper.tables or not mapper._has_pks(table):
continue
- instance_key = mapper.instance_key(obj)
+ instance_key = mapper.identity_key_from_instance(obj)
if self.__should_log_debug:
self.__log_debug("save_obj() table '%s' instance %s identity %s" % (table.name, mapperutil.instance_str(obj), str(instance_key)))
@@ -1149,7 +1076,7 @@ class Mapper(object):
if history:
a = history.added_items()
if len(a):
- params[col.key] = a[0]
+ params[col.key] = prop.get_col_value(col, a[0])
hasdata = True
else:
# doing an INSERT, non primary key col ?
@@ -1168,17 +1095,17 @@ class Mapper(object):
if hasdata:
# if none of the attributes changed, dont even
# add the row to be updated.
- update.append((obj, params, mapper))
+ update.append((obj, params, mapper, connection))
else:
- insert.append((obj, params, mapper))
+ insert.append((obj, params, mapper, connection))
if len(update):
mapper = table_to_mapper[table]
clause = sql.and_()
for col in mapper.pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col._label, type=col.type, unique=True))
+ clause.clauses.append(col == sql.bindparam(col._label, type_=col.type, unique=True))
if mapper.version_id_col is not None:
- clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type=col.type, unique=True))
+ clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type, unique=True))
statement = table.update(clause)
rows = 0
supports_sane_rowcount = True
@@ -1190,11 +1117,11 @@ class Mapper(object):
return 0
update.sort(comparator)
for rec in update:
- (obj, params, mapper) = rec
+ (obj, params, mapper, connection) = rec
c = connection.execute(statement, params)
mapper._postfetch(connection, table, obj, c, c.last_updated_params())
- updated_objects.add(obj)
+ updated_objects.add((obj, connection))
rows += c.rowcount
if c.supports_sane_rowcount() and rows != len(update):
@@ -1206,7 +1133,7 @@ class Mapper(object):
return cmp(a[0]._sa_insert_order, b[0]._sa_insert_order)
insert.sort(comparator)
for rec in insert:
- (obj, params, mapper) = rec
+ (obj, params, mapper, connection) = rec
c = connection.execute(statement, params)
primary_key = c.last_inserted_ids()
if primary_key is not None:
@@ -1228,12 +1155,12 @@ class Mapper(object):
mapper._synchronizer.execute(obj, obj)
sync(mapper)
- inserted_objects.add(obj)
+ inserted_objects.add((obj, connection))
if not postupdate:
- for obj in inserted_objects:
+ for obj, connection in inserted_objects:
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.after_insert(mapper, connection, obj)
- for obj in updated_objects:
+ for obj, connection in updated_objects:
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.after_update(mapper, connection, obj)
@@ -1273,9 +1200,14 @@ class Mapper(object):
if self.__should_log_debug:
self.__log_debug("delete_obj() start")
- connection = uowtransaction.transaction.connection(self)
+ if 'connection_callable' in uowtransaction.mapper_flush_opts:
+ connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
+ tups = [(obj, connection_callable(self, obj)) for obj in objects]
+ else:
+ connection = uowtransaction.transaction.connection(self)
+ tups = [(obj, connection) for obj in objects]
- for obj in objects:
+ for (obj, connection) in tups:
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.before_delete(mapper, connection, obj)
@@ -1286,8 +1218,8 @@ class Mapper(object):
table_to_mapper.setdefault(t, mapper)
for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=True):
- delete = []
- for obj in objects:
+ delete = {}
+ for (obj, connection) in tups:
mapper = object_mapper(obj)
if table not in mapper.tables or not mapper._has_pks(table):
continue
@@ -1296,13 +1228,13 @@ class Mapper(object):
if not hasattr(obj, '_instance_key'):
continue
else:
- delete.append(params)
+ delete.setdefault(connection, []).append(params)
for col in mapper.pks_by_table[table]:
params[col.key] = mapper.get_attr_by_column(obj, col)
if mapper.version_id_col is not None:
params[mapper.version_id_col.key] = mapper.get_attr_by_column(obj, mapper.version_id_col)
- deleted_objects.add(obj)
- if len(delete):
+ deleted_objects.add((obj, connection))
+ for connection, del_objects in delete.iteritems():
mapper = table_to_mapper[table]
def comparator(a, b):
for col in mapper.pks_by_table[table]:
@@ -1310,18 +1242,18 @@ class Mapper(object):
if x != 0:
return x
return 0
- delete.sort(comparator)
+ del_objects.sort(comparator)
clause = sql.and_()
for col in mapper.pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col.key, type=col.type, unique=True))
+ clause.clauses.append(col == sql.bindparam(col.key, type_=col.type, unique=True))
if mapper.version_id_col is not None:
- clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type=mapper.version_id_col.type, unique=True))
+ clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type, unique=True))
statement = table.delete(clause)
- c = connection.execute(statement, delete)
- if c.supports_sane_rowcount() and c.rowcount != len(delete):
+ c = connection.execute(statement, del_objects)
+ if c.supports_sane_rowcount() and c.rowcount != len(del_objects):
raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (c.rowcount, len(delete)))
- for obj in deleted_objects:
+ for obj, connection in deleted_objects:
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.after_delete(mapper, connection, obj)
@@ -1429,15 +1361,17 @@ class Mapper(object):
if discriminator is not None:
mapper = self.polymorphic_map[discriminator]
if mapper is not self:
+ if ('polymorphic_fetch', mapper) not in context.attributes:
+ context.attributes[('polymorphic_fetch', mapper)] = (self, [t for t in mapper.tables if t not in self.tables])
row = self.translate_row(mapper, row)
return mapper._instance(context, row, result=result, skip_polymorphic=True)
-
+
# look in main identity map. if it's there, we don't do anything to it,
# including modifying any of its related items lists, as it's already
# been exposed to being modified by the application.
- populate_existing = context.populate_existing or self.always_refresh
identitykey = self.identity_key_from_row(row)
+ populate_existing = context.populate_existing or self.always_refresh
if context.session.has_key(identitykey):
instance = context.session._get(identitykey)
if self.__should_log_debug:
@@ -1450,32 +1384,31 @@ class Mapper(object):
if not context.identity_map.has_key(identitykey):
context.identity_map[identitykey] = instance
isnew = True
- if extension.populate_instance(self, context, row, instance, identitykey, isnew) is EXT_PASS:
- self.populate_instance(context, instance, row, identitykey, isnew)
- if extension.append_result(self, context, row, instance, identitykey, result, isnew) is EXT_PASS:
+ if extension.populate_instance(self, context, row, instance, **{'instancekey':identitykey, 'isnew':isnew}) is EXT_PASS:
+ self.populate_instance(context, instance, row, **{'instancekey':identitykey, 'isnew':isnew})
+ if extension.append_result(self, context, row, instance, result, **{'instancekey':identitykey, 'isnew':isnew}) is EXT_PASS:
if result is not None:
result.append(instance)
return instance
else:
if self.__should_log_debug:
- self.__log_debug("_instance(): identity key %s not in session" % str(identitykey) + repr([mapperutil.instance_str(x) for x in context.session]))
+ self.__log_debug("_instance(): identity key %s not in session" % str(identitykey))
# look in result-local identitymap for it.
- exists = context.identity_map.has_key(identitykey)
+ exists = identitykey in context.identity_map
if not exists:
if self.allow_null_pks:
# check if *all* primary key cols in the result are None - this indicates
# an instance of the object is not present in the row.
- for col in self.pks_by_table[self.mapped_table]:
- if row[col] is not None:
+ for x in identitykey[1]:
+ if x is not None:
break
else:
return None
else:
# otherwise, check if *any* primary key cols in the result are None - this indicates
# an instance of the object is not present in the row.
- for col in self.pks_by_table[self.mapped_table]:
- if row[col] is None:
- return None
+ if None in identitykey[1]:
+ return None
# plugin point
instance = extension.create_instance(self, context, row, self.class_)
@@ -1493,9 +1426,10 @@ class Mapper(object):
# call further mapper properties on the row, to pull further
# instances from the row and possibly populate this item.
- if extension.populate_instance(self, context, row, instance, identitykey, isnew) is EXT_PASS:
- self.populate_instance(context, instance, row, identitykey, isnew)
- if extension.append_result(self, context, row, instance, identitykey, result, isnew) is EXT_PASS:
+ flags = {'instancekey':identitykey, 'isnew':isnew}
+ if extension.populate_instance(self, context, row, instance, **flags) is EXT_PASS:
+ self.populate_instance(context, instance, row, **flags)
+ if extension.append_result(self, context, row, instance, result, **flags) is EXT_PASS:
if result is not None:
result.append(instance)
return instance
@@ -1510,6 +1444,24 @@ class Mapper(object):
return obj
+ def _deferred_inheritance_condition(self, needs_tables):
+ cond = self.inherit_condition
+
+ param_names = []
+ def visit_binary(binary):
+ leftcol = binary.left
+ rightcol = binary.right
+ if leftcol is None or rightcol is None:
+ return
+ if leftcol.table not in needs_tables:
+ binary.left = sql.bindparam(leftcol.name, None, type_=binary.right.type, unique=True)
+ param_names.append(leftcol)
+ elif rightcol.table not in needs_tables:
+ binary.right = sql.bindparam(rightcol.name, None, type_=binary.right.type, unique=True)
+ param_names.append(rightcol)
+ cond = mapperutil.BinaryVisitor(visit_binary).traverse(cond, clone=True)
+ return cond, param_names
+
def translate_row(self, tomapper, row):
"""Translate the column keys of a row into a new or proxied
row that can be understood by another mapper.
@@ -1520,288 +1472,71 @@ class Mapper(object):
newrow = util.DictDecorator(row)
for c in tomapper.mapped_table.c:
- c2 = self.mapped_table.corresponding_column(c, keys_ok=True, raiseerr=True)
- if row.has_key(c2):
+ c2 = self.mapped_table.corresponding_column(c, keys_ok=True, raiseerr=False)
+ if c2 and row.has_key(c2):
newrow[c] = row[c2]
return newrow
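# Standalone sketch of translate_row(): re-key a result row so that one
# mapper's columns line up with another's selectable.  Rows are plain
# dicts keyed by column objects here, for illustration.
def translate(row, column_pairs):
    newrow = dict(row)
    for target_col, source_col in column_pairs:
        if source_col in row:
            newrow[target_col] = row[source_col]
    return newrow

row = {'people_id': 5}
assert translate(row, [('engineers_id', 'people_id')])['engineers_id'] == 5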
- def populate_instance(self, selectcontext, instance, row, identitykey, isnew):
- """populate an instance from a result row.
-
- This method iterates through the list of MapperProperty objects attached to this Mapper
- and calls each properties execute() method."""
- for prop in self.__props.values():
- prop.execute(selectcontext, instance, row, identitykey, isnew)
-
-Mapper.logger = logging.class_logger(Mapper)
-
-
-class MapperExtension(object):
- """Base implementation for an object that provides overriding
- behavior to various Mapper functions. For each method in
- MapperExtension, a result of EXT_PASS indicates the functionality
- is not overridden.
- """
-
- def get_session(self):
- """Retrieve a contextual Session instance with which to
- register a new object.
-
- Note: this is not called if a session is provided with the
- `__init__` params (i.e. `_sa_session`).
- """
-
- return EXT_PASS
-
- def load(self, query, *args, **kwargs):
- """Override the `load` method of the Query object.
-
- The return value of this method is used as the result of
- ``query.load()`` if the value is anything other than EXT_PASS.
- """
-
- return EXT_PASS
-
- def get(self, query, *args, **kwargs):
- """Override the `get` method of the Query object.
-
- The return value of this method is used as the result of
- ``query.get()`` if the value is anything other than EXT_PASS.
- """
-
- return EXT_PASS
-
- def get_by(self, query, *args, **kwargs):
- """Override the `get_by` method of the Query object.
-
- The return value of this method is used as the result of
- ``query.get_by()`` if the value is anything other than
- EXT_PASS.
- """
-
- return EXT_PASS
-
- def select_by(self, query, *args, **kwargs):
- """Override the `select_by` method of the Query object.
-
- The return value of this method is used as the result of
- ``query.select_by()`` if the value is anything other than
- EXT_PASS.
- """
-
- return EXT_PASS
-
- def select(self, query, *args, **kwargs):
- """Override the `select` method of the Query object.
-
- The return value of this method is used as the result of
- ``query.select()`` if the value is anything other than
- EXT_PASS.
- """
-
- return EXT_PASS
-
-
- def translate_row(self, mapper, context, row):
- """Perform pre-processing on the given result row and return a
- new row instance.
-
- This is called as the very first step in the ``_instance()``
- method.
- """
-
- return EXT_PASS
-
- def create_instance(self, mapper, selectcontext, row, class_):
- """Receive a row when a new object instance is about to be
- created from that row.
-
- The method can choose to create the instance itself, or it can
- return None to indicate normal object creation should take
- place.
-
- mapper
- The mapper doing the operation
-
- selectcontext
- SelectionContext corresponding to the instances() call
-
- row
- The result row from the database
-
- class\_
- The class we are mapping.
- """
-
- return EXT_PASS
-
- def append_result(self, mapper, selectcontext, row, instance, identitykey, result, isnew):
- """Receive an object instance before that instance is appended
- to a result list.
-
- If this method returns EXT_PASS, result appending will proceed
- normally. if this method returns any other value or None,
- result appending will not proceed for this instance, giving
- this extension an opportunity to do the appending itself, if
- desired.
-
- mapper
- The mapper doing the operation.
-
- selectcontext
- SelectionContext corresponding to the instances() call.
-
- row
- The result row from the database.
-
- instance
- The object instance to be appended to the result.
-
- identitykey
- The identity key of the instance.
-
- result
- List to which results are being appended.
-
- isnew
- Indicates if this is the first time we have seen this object
- instance in the current result set. if you are selecting
- from a join, such as an eager load, you might see the same
- object instance many times in the same result set.
- """
-
- return EXT_PASS
-
- def populate_instance(self, mapper, selectcontext, row, instance, identitykey, isnew):
- """Receive a newly-created instance before that instance has
- its attributes populated.
-
- The normal population of attributes is according to each
- attribute's corresponding MapperProperty (which includes
- column-based attributes as well as relationships to other
- classes). If this method returns EXT_PASS, instance
- population will proceed normally. If any other value or None
- is returned, instance population will not proceed, giving this
- extension an opportunity to populate the instance itself, if
- desired.
- """
-
- return EXT_PASS
-
- def before_insert(self, mapper, connection, instance):
- """Receive an object instance before that instance is INSERTed
- into its table.
-
- This is a good place to set up primary key values and such
- that aren't handled otherwise.
- """
-
- return EXT_PASS
-
- def before_update(self, mapper, connection, instance):
- """Receive an object instance before that instance is UPDATEed."""
-
- return EXT_PASS
-
- def after_update(self, mapper, connection, instance):
- """Receive an object instance after that instance is UPDATEed."""
-
- return EXT_PASS
-
- def after_insert(self, mapper, connection, instance):
- """Receive an object instance after that instance is INSERTed."""
-
- return EXT_PASS
-
- def before_delete(self, mapper, connection, instance):
- """Receive an object instance before that instance is DELETEed."""
-
- return EXT_PASS
-
- def after_delete(self, mapper, connection, instance):
- """Receive an object instance after that instance is DELETEed."""
-
- return EXT_PASS
-
-class _ExtensionCarrier(MapperExtension):
- def __init__(self):
- self.__elements = []
-
- def __iter__(self):
- return iter(self.__elements)
+ def populate_instance(self, selectcontext, instance, row, ispostselect=None, **flags):
+ """populate an instance from a result row."""
+
+ selectcontext.stack.push_mapper(self)
+ populators = selectcontext.attributes.get(('instance_populators', self, selectcontext.stack.snapshot(), ispostselect), None)
+ if populators is None:
+ populators = []
+ post_processors = []
+ for prop in self.__props.values():
+ (pop, post_proc) = prop.create_row_processor(selectcontext, self, row)
+ if pop is not None:
+ populators.append(pop)
+ if post_proc is not None:
+ post_processors.append(post_proc)
+
+ poly_select_loader = self._get_poly_select_loader(selectcontext, row)
+ if poly_select_loader is not None:
+ post_processors.append(poly_select_loader)
+
+ selectcontext.attributes[('instance_populators', self, selectcontext.stack.snapshot(), ispostselect)] = populators
+ selectcontext.attributes[('post_processors', self, ispostselect)] = post_processors
+
+ for p in populators:
+ p(instance, row, ispostselect=ispostselect, **flags)
- def insert(self, extension):
- """Insert a MapperExtension at the beginning of this ExtensionCarrier's list."""
-
- self.__elements.insert(0, extension)
-
- def append(self, extension):
- """Append a MapperExtension at the end of this ExtensionCarrier's list."""
-
- self.__elements.append(extension)
-
- def get_session(self, *args, **kwargs):
- return self._do('get_session', *args, **kwargs)
-
- def load(self, *args, **kwargs):
- return self._do('load', *args, **kwargs)
-
- def get(self, *args, **kwargs):
- return self._do('get', *args, **kwargs)
-
- def get_by(self, *args, **kwargs):
- return self._do('get_by', *args, **kwargs)
-
- def select_by(self, *args, **kwargs):
- return self._do('select_by', *args, **kwargs)
-
- def select(self, *args, **kwargs):
- return self._do('select', *args, **kwargs)
-
- def translate_row(self, *args, **kwargs):
- return self._do('translate_row', *args, **kwargs)
-
- def create_instance(self, *args, **kwargs):
- return self._do('create_instance', *args, **kwargs)
-
- def append_result(self, *args, **kwargs):
- return self._do('append_result', *args, **kwargs)
-
- def populate_instance(self, *args, **kwargs):
- return self._do('populate_instance', *args, **kwargs)
-
- def before_insert(self, *args, **kwargs):
- return self._do('before_insert', *args, **kwargs)
-
- def before_update(self, *args, **kwargs):
- return self._do('before_update', *args, **kwargs)
-
- def after_update(self, *args, **kwargs):
- return self._do('after_update', *args, **kwargs)
+ selectcontext.stack.pop()
+
+ if self.non_primary:
+ selectcontext.attributes[('populating_mapper', instance)] = self
+
+ def _post_instance(self, selectcontext, instance):
+ post_processors = selectcontext.attributes[('post_processors', self, None)]
+ for p in post_processors:
+ p(instance)
+
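# Sketch of the populator-cache pattern used by populate_instance() above:
# the per-row processing callables are computed once per cache key and
# reused for every subsequent row.  Illustrative only.
def get_populators(cache, key, build):
    populators = cache.get(key)
    if populators is None:
        populators = cache[key] = build()
    return populators

cache = {}
calls = []
def build():
    calls.append(1)
    return ['populator']
get_populators(cache, ('mapper', 'path'), build)
get_populators(cache, ('mapper', 'path'), build)
assert len(calls) == 1  # built only once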
+ def _get_poly_select_loader(self, selectcontext, row):
+ # applies for polymorphic_fetch of 'select', or 'union' where the needed columns were not present in the row
+ (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', self), (None, None))
+ if hosted_mapper is None or len(needs_tables)==0 or hosted_mapper.polymorphic_fetch == 'deferred':
+ return
+
+ cond, param_names = self._deferred_inheritance_condition(needs_tables)
+ statement = sql.select(needs_tables, cond, use_labels=True)
+ def post_execute(instance, **flags):
+ self.__log_debug("Post query loading instance " + mapperutil.instance_str(instance))
- def after_insert(self, *args, **kwargs):
- return self._do('after_insert', *args, **kwargs)
+ identitykey = self.identity_key_from_instance(instance)
- def before_delete(self, *args, **kwargs):
- return self._do('before_delete', *args, **kwargs)
+ params = {}
+ for c in param_names:
+ params[c.name] = self.get_attr_by_column(instance, c)
+ row = selectcontext.session.connection(self).execute(statement, **params).fetchone()
+ self.populate_instance(selectcontext, instance, row, **{'isnew':False, 'instancekey':identitykey, 'ispostselect':True})
- def after_delete(self, *args, **kwargs):
- return self._do('after_delete', *args, **kwargs)
+ return post_execute
+
+Mapper.logger = logging.class_logger(Mapper)
- def _do(self, funcname, *args, **kwargs):
- for elem in self.__elements:
- ret = getattr(elem, funcname)(*args, **kwargs)
- if ret is not EXT_PASS:
- return ret
- else:
- return EXT_PASS
-class ExtensionOption(MapperOption):
- def __init__(self, ext):
- self.ext = ext
- def process_query(self, query):
- query.extension.append(self.ext)
class ClassKey(object):
"""Key a class and an entity name to a mapper, via the mapper_registry."""
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index a00a35ab6..6ce9fd706 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -15,8 +15,11 @@ from sqlalchemy import sql, schema, util, exceptions, sql_util, logging
from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency
from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
-import sets, random
-from sqlalchemy.orm.interfaces import *
+import operator
+from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator
+from sqlalchemy.exceptions import ArgumentError
+
+__all__ = ['ColumnProperty', 'CompositeProperty', 'PropertyLoader', 'BackRef']
class ColumnProperty(StrategizedProperty):
"""Describes an object attribute that corresponds to a table column."""
@@ -31,17 +34,27 @@ class ColumnProperty(StrategizedProperty):
self.columns = list(columns)
self.group = kwargs.pop('group', None)
self.deferred = kwargs.pop('deferred', False)
-
+ self.comparator = ColumnProperty.ColumnComparator(self)
+ # sanity check
+ for col in columns:
+ if not hasattr(col, 'name'):
+ if hasattr(col, 'label'):
+ raise ArgumentError('ColumnProperties must be named for the mapper to work with them. Try .label() to fix this')
+ raise ArgumentError('%r is not a valid candidate for ColumnProperty' % col)
+
def create_strategy(self):
if self.deferred:
return strategies.DeferredColumnLoader(self)
else:
return strategies.ColumnLoader(self)
-
- def getattr(self, object):
+
+ def copy(self):
+ return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns)
+
+ def getattr(self, object, column):
return getattr(object, self.key)
- def setattr(self, object, value):
+ def setattr(self, object, value, column):
setattr(object, self.key, value)
def get_history(self, obj, passive=False):
@@ -50,19 +63,69 @@ class ColumnProperty(StrategizedProperty):
def merge(self, session, source, dest, _recursive):
setattr(dest, self.key, getattr(source, self.key, None))
- def compare(self, value):
- return self.columns[0] == value
+ def get_col_value(self, column, value):
+ return value
+
+ class ColumnComparator(PropComparator):
+ def clause_element(self):
+ return self.prop.columns[0]
+
+ def operate(self, op, other):
+ return op(self.prop.columns[0], other)
+
+ def reverse_operate(self, op, other):
+ col = self.prop.columns[0]
+ return op(col._bind_param(other), col)
+
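# Sketch of the operate/reverse_operate delegation above: a comparator
# funnels Python operators to an underlying element via the operator
# module.  The integer element stands in for a Column.
import operator

class Comparator(object):
    def __init__(self, element):
        self.element = element
    def __eq__(self, other):
        return self.operate(operator.eq, other)
    def operate(self, op, other):
        return op(self.element, other)

assert Comparator(5) == 5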
ColumnProperty.logger = logging.class_logger(ColumnProperty)
mapper.ColumnProperty = ColumnProperty
+class CompositeProperty(ColumnProperty):
+ """subclasses ColumnProperty to provide composite type support."""
+
+ def __init__(self, class_, *columns, **kwargs):
+ super(CompositeProperty, self).__init__(*columns, **kwargs)
+ self.composite_class = class_
+ self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator(self))
+
+ def copy(self):
+ return CompositeProperty(self.composite_class, deferred=self.deferred, group=self.group, *self.columns)
+
+ def getattr(self, object, column):
+ obj = getattr(object, self.key)
+ return self.get_col_value(column, obj)
+
+ def setattr(self, object, value, column):
+ obj = getattr(object, self.key, None)
+ if obj is None:
+ obj = self.composite_class(*[None for c in self.columns])
+ for a, b in zip(self.columns, value.__colset__()):
+ if a is column:
+ setattr(obj, b, value)
+
+ def get_col_value(self, column, value):
+ for a, b in zip(self.columns, value.__colset__()):
+ if a is column:
+ return b
+
+ class Comparator(PropComparator):
+ def __eq__(self, other):
+ if other is None:
+ return sql.and_(*[a==None for a in self.prop.columns])
+ else:
+ return sql.and_(*[a==b for a, b in zip(self.prop.columns, other.__colset__())])
+
+ def __ne__(self, other):
+ return sql.or_(*[a!=b for a, b in zip(self.prop.columns, other.__colset__())])
+
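# A hypothetical composite value class for use with CompositeProperty
# above: the __colset__() protocol returns the instance's values in
# column order.  The class name and fields are illustrative, e.g. mapped
# as CompositeProperty(Point, vertices.c.x, vertices.c.y).
class Point(object):
    def __init__(self, x, y):
        self.x = x
        self.y = y
    def __colset__(self):
        return (self.x, self.y)

assert Point(3, 4).__colset__() == (3, 4)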
class PropertyLoader(StrategizedProperty):
"""Describes an object property that holds a single item or list
of items that correspond to a related database table.
"""
- def __init__(self, argument, secondary, primaryjoin, secondaryjoin, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True):
+ def __init__(self, argument, secondary=None, primaryjoin=None, secondaryjoin=None, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True, join_depth=None):
self.uselist = uselist
self.argument = argument
self.entity_name = entity_name
@@ -80,7 +143,9 @@ class PropertyLoader(StrategizedProperty):
self.remote_side = util.to_set(remote_side)
self.enable_typechecks = enable_typechecks
self._parent_join_cache = {}
-
+ self.comparator = PropertyLoader.Comparator(self)
+ self.join_depth = join_depth
+
if cascade is not None:
self.cascade = mapperutil.CascadeOptions(cascade)
else:
@@ -91,7 +156,7 @@ class PropertyLoader(StrategizedProperty):
self.association = association
self.order_by = order_by
self.attributeext = attributeext
if isinstance(backref, str):
# propagate explicitly sent primary/secondary join conditions to the BackRef object if
# just a string was sent
@@ -104,9 +169,96 @@ class PropertyLoader(StrategizedProperty):
self.backref = backref
self.is_backref = is_backref
- def compare(self, value):
- return sql.and_(*[x==y for (x, y) in zip(self.mapper.primary_key, self.mapper.primary_key_from_instance(value))])
-
+ class Comparator(PropComparator):
+ def __eq__(self, other):
+ if other is None:
+ return ~sql.exists([1], self.prop.primaryjoin)
+ elif self.prop.uselist:
+ if not hasattr(other, '__iter__'):
+ raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object.")
+ else:
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
+ clauses = []
+ for o in other:
+ clauses.append(
+ sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(o))]))
+ )
+ return sql.and_(*clauses)
+ else:
+ return self.prop._optimized_compare(other)
+
+ def any(self, criterion=None, **kwargs):
+ if not self.prop.uselist:
+ raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
+ for k in kwargs:
+ crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
+ if criterion is None:
+ criterion = crit
+ else:
+ criterion = criterion & crit
+ return sql.exists([1], j & criterion)
+
+ def has(self, criterion=None, **kwargs):
+ if self.prop.uselist:
+ raise exceptions.InvalidRequestError("'has()' not implemented for collections. Use any().")
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
+ for k in kwargs:
+ crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
+ if criterion is None:
+ criterion = crit
+ else:
+ criterion = criterion & crit
+ return sql.exists([1], j & criterion)
+
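# Sketch of the kwargs-to-criterion folding shared by any() and has()
# above, with plain strings standing in for SQL expressions.  Usage such
# as query.filter(User.addresses.any(email='ed@foo.com')) (names
# illustrative) folds each keyword into the EXISTS criterion this way.
def fold_criterion(criterion, **kwargs):
    for k in sorted(kwargs):
        crit = "%s == %r" % (k, kwargs[k])
        if criterion is None:
            criterion = crit
        else:
            criterion = "(%s) AND (%s)" % (criterion, crit)
    return criterion

assert fold_criterion(None, email='ed@foo.com') == "email == 'ed@foo.com'"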
+ def contains(self, other):
+ if not self.prop.uselist:
+ raise exceptions.InvalidRequestError("'contains' not implemented for scalar attributes. Use ==")
+ clause = self.prop._optimized_compare(other)
+
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
+
+ clause.negation_clause = ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]))
+ return clause
+
+ def __ne__(self, other):
+ if self.prop.uselist and not hasattr(other, '__iter__'):
+ raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object")
+
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
+ return ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]))
+
+ def compare(self, op, value, value_is_parent=False):
+ if op == operator.eq:
+ if value is None:
+ return ~sql.exists([1], self.prop.mapper.mapped_table, self.prop.primaryjoin)
+ else:
+ return self._optimized_compare(value, value_is_parent=value_is_parent)
+ else:
+ return op(self.comparator, value)
+
+ def _optimized_compare(self, value, value_is_parent=False):
+ # optimized operation for ==, uses a lazy clause.
+ (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(self, reverse_direction=not value_is_parent)
+ bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
+
+ class Visitor(sql.ClauseVisitor):
+ def visit_bindparam(s, bindparam):
+ mapper = value_is_parent and self.parent or self.mapper
+ bindparam.value = mapper.get_attr_by_column(value, bind_to_col[bindparam.key])
+ Visitor().traverse(criterion)
+ return criterion
+
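For orientation, a rough sketch of what this Comparator enables at query time; the User/Address classes, their relation() mapping, and the sess Session are hypothetical illustrations, not part of this commit:

    # hypothetical 0.4-style mapping: User.addresses -> relation(Address),
    # Address.user -> many-to-one backref; sess is an assumed Session
    q = sess.query(User)

    # any()/has() emit EXISTS subqueries built from the relation's join condition
    q.filter(User.addresses.any(email_address='ed@example.com')).all()
    sess.query(Address).filter(Address.user.has(name='ed')).all()

    # contains() compares against a single collection member via the optimized lazy clause
    q.filter(User.addresses.contains(someaddress)).all()

    # comparing a scalar relation to None becomes NOT EXISTS on the primary join
    sess.query(Address).filter(Address.user == None).all()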
private = property(lambda s:s.cascade.delete_orphan)
def create_strategy(self):
@@ -127,12 +279,13 @@ class PropertyLoader(StrategizedProperty):
if childlist is None:
return
if self.uselist:
- # sets a blank list according to the correct list class
- dest_list = getattr(self.parent.class_, self.key).initialize(dest)
+ # sets a blank collection according to the correct collection class
+ dest_list = sessionlib.attribute_manager.init_collection(dest, self.key)
for current in list(childlist):
obj = session.merge(current, entity_name=self.mapper.entity_name, _recursive=_recursive)
if obj is not None:
- dest_list.append(obj)
+ #dest_list.append_without_event(obj)
+ dest_list.append_with_event(obj)
else:
current = list(childlist)[0]
if current is not None:
@@ -267,7 +420,7 @@ class PropertyLoader(StrategizedProperty):
if len(self.foreign_keys):
self._opposite_side = util.Set()
def visit_binary(binary):
- if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+ if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
if binary.left in self.foreign_keys:
self._opposite_side.add(binary.right)
@@ -280,7 +433,7 @@ class PropertyLoader(StrategizedProperty):
self.foreign_keys = util.Set()
self._opposite_side = util.Set()
def visit_binary(binary):
- if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+ if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
# this check is for when the user put the "view_only" flag on and has tables that have nothing
@@ -362,16 +515,13 @@ class PropertyLoader(StrategizedProperty):
"argument." % (str(self)))
def _determine_remote_side(self):
- if len(self.remote_side):
- return
- self.remote_side = util.Set()
+ if not len(self.remote_side):
+ if self.direction is sync.MANYTOONE:
+ self.remote_side = util.Set(self._opposite_side)
+ elif self.direction is sync.ONETOMANY or self.direction is sync.MANYTOMANY:
+ self.remote_side = util.Set(self.foreign_keys)
- if self.direction is sync.MANYTOONE:
- for c in self._opposite_side:
- self.remote_side.add(c)
- elif self.direction is sync.ONETOMANY or self.direction is sync.MANYTOMANY:
- for c in self.foreign_keys:
- self.remote_side.add(c)
+ self.local_side = util.Set(self._opposite_side).union(util.Set(self.foreign_keys)).difference(self.remote_side)
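The remote_side/local_side split matters chiefly for self-referential relations, where the same table sits on both sides of the join and the direction cannot be inferred from foreign keys alone. A hedged sketch of the usual adjacency-list usage (the Node class, table, and imports are assumed, not part of this commit):

    # assumes the usual Table/Column/mapper/relation imports and a Node class
    nodes = Table('nodes', metadata,
        Column('id', Integer, primary_key=True),
        Column('parent_id', Integer, ForeignKey('nodes.id')))

    mapper(Node, nodes, properties={
        # remote_side pins nodes.c.id as the "remote" column, forcing the
        # many-to-one direction; _determine_remote_side then derives local_side
        'parent': relation(Node, remote_side=[nodes.c.id])
    })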
def _create_polymorphic_joins(self):
# get ready to create "polymorphic" primary/secondary join clauses.
@@ -383,27 +533,26 @@ class PropertyLoader(StrategizedProperty):
# as we will be using the polymorphic selectables (i.e. select_table argument to Mapper) to figure this out,
# first create maps of all the "equivalent" columns, since polymorphic selectables will often munge
# several "equivalent" columns (such as parent/child fk cols) into just one column.
- target_equivalents = self.mapper._get_inherited_column_equivalents()
+ target_equivalents = self.mapper._get_equivalent_columns()
+
# if the target mapper loads polymorphically, adapt the clauses to the target's selectable
if self.loads_polymorphic:
if self.secondaryjoin:
- self.polymorphic_secondaryjoin = self.secondaryjoin.copy_container()
- sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.polymorphic_secondaryjoin)
- self.polymorphic_primaryjoin = self.primaryjoin.copy_container()
+ self.polymorphic_secondaryjoin = sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.secondaryjoin, clone=True)
+ self.polymorphic_primaryjoin = self.primaryjoin
else:
- self.polymorphic_primaryjoin = self.primaryjoin.copy_container()
if self.direction is sync.ONETOMANY:
- sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin)
+ self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True)
elif self.direction is sync.MANYTOONE:
- sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin)
+ self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True)
self.polymorphic_secondaryjoin = None
# load "polymorphic" versions of the columns present in "remote_side" - this is
# important for lazy-clause generation which goes off the polymorphic target selectable
for c in list(self.remote_side):
- if self.secondary and c in self.secondary.columns:
+ if self.secondary and self.secondary.columns.contains_column(c):
continue
- for equiv in [c] + (c in target_equivalents and target_equivalents[c] or []):
+ for equiv in [c] + (c in target_equivalents and list(target_equivalents[c]) or []):
corr = self.mapper.select_table.corresponding_column(equiv, raiseerr=False)
if corr:
self.remote_side.add(corr)
@@ -411,8 +560,8 @@ class PropertyLoader(StrategizedProperty):
else:
raise exceptions.AssertionError(str(self) + ": Could not find corresponding column for " + str(c) + " in selectable " + str(self.mapper.select_table))
else:
- self.polymorphic_primaryjoin = self.primaryjoin.copy_container()
- self.polymorphic_secondaryjoin = self.secondaryjoin and self.secondaryjoin.copy_container() or None
+ self.polymorphic_primaryjoin = self.primaryjoin
+ self.polymorphic_secondaryjoin = self.secondaryjoin
def _post_init(self):
if logging.is_info_enabled(self.logger):
@@ -450,22 +599,20 @@ class PropertyLoader(StrategizedProperty):
def _is_self_referential(self):
return self.parent.mapped_table is self.target or self.parent.select_table is self.target
- def get_join(self, parent, primary=True, secondary=True):
+ def get_join(self, parent, primary=True, secondary=True, polymorphic_parent=True):
try:
- return self._parent_join_cache[(parent, primary, secondary)]
+ return self._parent_join_cache[(parent, primary, secondary, polymorphic_parent)]
except KeyError:
- parent_equivalents = parent._get_inherited_column_equivalents()
- primaryjoin = self.polymorphic_primaryjoin.copy_container()
- if self.secondaryjoin is not None:
- secondaryjoin = self.polymorphic_secondaryjoin.copy_container()
- else:
- secondaryjoin = None
- if self.direction is sync.ONETOMANY:
- sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
- elif self.direction is sync.MANYTOONE:
- sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
- elif self.secondaryjoin:
- sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
+ parent_equivalents = parent._get_equivalent_columns()
+ secondaryjoin = self.polymorphic_secondaryjoin
+ if polymorphic_parent:
+ # adapt the "parent" side of our join condition to the "polymorphic" select of the parent
+ if self.direction is sync.ONETOMANY:
+ primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
+ elif self.direction is sync.MANYTOONE:
+ primaryjoin = sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
+ elif self.secondaryjoin:
+ primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
if secondaryjoin is not None:
if secondary and not primary:
@@ -476,7 +623,7 @@ class PropertyLoader(StrategizedProperty):
j = primaryjoin
else:
j = primaryjoin
- self._parent_join_cache[(parent, primary, secondary)] = j
+ self._parent_join_cache[(parent, primary, secondary, polymorphic_parent)] = j
return j
def register_dependencies(self, uowcommit):
@@ -501,7 +648,7 @@ class BackRef(object):
# try to set a LazyLoader on our mapper referencing the parent mapper
mapper = prop.mapper.primary_mapper()
- if not mapper.props.has_key(self.key):
+ if mapper.get_property(self.key, raiseerr=False) is None:
pj = self.kwargs.pop('primaryjoin', None)
sj = self.kwargs.pop('secondaryjoin', None)
# the backref property is set on the primary mapper
@@ -512,26 +659,26 @@ class BackRef(object):
backref=prop.key, is_backref=True,
**self.kwargs)
mapper._compile_property(self.key, relation);
- elif not isinstance(mapper.props[self.key], PropertyLoader):
+ elif not isinstance(mapper.get_property(self.key), PropertyLoader):
raise exceptions.ArgumentError(
"Can't create backref '%s' on mapper '%s'; an incompatible "
"property of that name already exists" % (self.key, str(mapper)))
else:
# else set one of us as the "backreference"
parent = prop.parent.primary_mapper()
- if parent.class_ is not mapper.props[self.key]._get_target_class():
+ if parent.class_ is not mapper.get_property(self.key)._get_target_class():
raise exceptions.ArgumentError(
"Backrefs do not match: backref '%s' expects to connect to %s, "
"but found a backref already connected to %s" %
- (self.key, str(parent.class_), str(mapper.props[self.key].mapper.class_)))
- if not mapper.props[self.key].is_backref:
+ (self.key, str(parent.class_), str(mapper.get_property(self.key).mapper.class_)))
+ if not mapper.get_property(self.key).is_backref:
prop.is_backref=True
if not prop.viewonly:
prop._dependency_processor.is_backref=True
# reverse_property used by dependencies.ManyToManyDP to check
# association table operations
- prop.reverse_property = mapper.props[self.key]
- mapper.props[self.key].reverse_property = prop
+ prop.reverse_property = mapper.get_property(self.key)
+ mapper.get_property(self.key).reverse_property = prop
def get_extension(self):
"""Return an attribute extension to use with this backreference."""
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index d51fd75c3..284653b5c 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -4,89 +4,49 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import sql, util, exceptions, sql_util, logging, schema
-from sqlalchemy.orm import mapper, class_mapper, object_mapper
-from sqlalchemy.orm.interfaces import OperationContext, SynonymProperty
-import random
+from sqlalchemy import sql, util, exceptions, sql_util, logging
+from sqlalchemy.orm import mapper, object_mapper
+from sqlalchemy.orm import util as mapperutil
+from sqlalchemy.orm.interfaces import OperationContext, LoaderStack
+import operator
__all__ = ['Query', 'QueryContext', 'SelectionContext']
class Query(object):
- """Encapsulates the object-fetching operations provided by Mappers.
+ """Encapsulates the object-fetching operations provided by Mappers."""
- Note that this particular version of Query contains the 0.3 API as well as most of the
- 0.4 API for forwards compatibility. A large part of the API here is deprecated (but still present)
- in the 0.4 series.
- """
-
- def __init__(self, class_or_mapper, session=None, entity_name=None, lockmode=None, with_options=None, extension=None, **kwargs):
+ def __init__(self, class_or_mapper, session=None, entity_name=None):
if isinstance(class_or_mapper, type):
self.mapper = mapper.class_mapper(class_or_mapper, entity_name=entity_name)
else:
self.mapper = class_or_mapper.compile()
- self.with_options = with_options or []
self.select_mapper = self.mapper.get_select_mapper().compile()
- self.always_refresh = kwargs.pop('always_refresh', self.mapper.always_refresh)
- self.lockmode = lockmode
- self.extension = mapper._ExtensionCarrier()
- if extension is not None:
- self.extension.append(extension)
- self.extension.append(self.mapper.extension)
- self.is_polymorphic = self.mapper is not self.select_mapper
+
self._session = session
- if not hasattr(self.mapper, '_get_clause'):
- _get_clause = sql.and_()
- for primary_key in self.primary_key_columns:
- _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type, unique=True))
- self.mapper._get_clause = _get_clause
+ self._with_options = []
+ self._lockmode = None
+ self._extension = self.mapper.extension.copy()
self._entities = []
- self._get_clause = self.mapper._get_clause
-
- self._order_by = kwargs.pop('order_by', False)
- self._group_by = kwargs.pop('group_by', False)
- self._distinct = kwargs.pop('distinct', False)
- self._offset = kwargs.pop('offset', None)
- self._limit = kwargs.pop('limit', None)
- self._criterion = None
+ self._order_by = False
+ self._group_by = False
+ self._distinct = False
+ self._offset = None
+ self._limit = None
+ self._statement = None
self._params = {}
- self._col = None
- self._func = None
+ self._criterion = None
+ self._column_aggregate = None
self._joinpoint = self.mapper
+ self._aliases = None
+ self._alias_ids = {}
self._from_obj = [self.table]
- self._statement = None
-
- for opt in util.flatten_iterator(self.with_options):
- opt.process_query(self)
+ self._populate_existing = False
+ self._version_check = False
def _clone(self):
- # yes, a little embarrassing here.
- # go look at 0.4 for the simple version.
q = Query.__new__(Query)
- q.mapper = self.mapper
- q.select_mapper = self.select_mapper
- q._order_by = self._order_by
- q._distinct = self._distinct
- q._entities = list(self._entities)
- q.always_refresh = self.always_refresh
- q.with_options = list(self.with_options)
- q._session = self.session
- q.is_polymorphic = self.is_polymorphic
- q.lockmode = self.lockmode
- q.extension = mapper._ExtensionCarrier()
- for ext in self.extension:
- q.extension.append(ext)
- q._offset = self._offset
- q._limit = self._limit
- q._params = self._params
- q._group_by = self._group_by
- q._get_clause = self._get_clause
- q._from_obj = list(self._from_obj)
- q._joinpoint = self._joinpoint
- q._criterion = self._criterion
- q._statement = self._statement
- q._col = self._col
- q._func = self._func
+ q.__dict__ = self.__dict__.copy()
return q
def _get_session(self):
@@ -96,7 +56,7 @@ class Query(object):
return self._session
table = property(lambda s:s.select_mapper.mapped_table)
- primary_key_columns = property(lambda s:s.select_mapper.pks_by_table[s.select_mapper.mapped_table])
+ primary_key_columns = property(lambda s:s.select_mapper.primary_key)
session = property(_get_session)
def get(self, ident, **kwargs):
@@ -108,13 +68,20 @@ class Query(object):
columns.
"""
- ret = self.extension.get(self, ident, **kwargs)
+ ret = self._extension.get(self, ident, **kwargs)
if ret is not mapper.EXT_PASS:
return ret
- key = self.mapper.identity_key(ident)
+
+ # convert composite types to individual args
+ # TODO: account for the order of columns in the
+ # ColumnProperty it corresponds to
+ if hasattr(ident, '__colset__'):
+ ident = ident.__colset__()
+
+ key = self.mapper.identity_key_from_primary_key(ident)
return self._get(key, ident, **kwargs)
- def load(self, ident, **kwargs):
+ def load(self, ident, raiseerr=True, **kwargs):
"""Return an instance of the object based on the given
identifier.
@@ -125,304 +92,14 @@ class Query(object):
columns.
"""
- ret = self.extension.load(self, ident, **kwargs)
+ ret = self._extension.load(self, ident, **kwargs)
if ret is not mapper.EXT_PASS:
return ret
- key = self.mapper.identity_key(ident)
+ key = self.mapper.identity_key_from_primary_key(ident)
instance = self._get(key, ident, reload=True, **kwargs)
- if instance is None:
+ if instance is None and raiseerr:
raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident))
return instance
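A quick sketch of how the two identity-map accessors differ in practice (User and sess are assumed, not part of this change):

    q = sess.query(User)
    u1 = q.get(5)                    # identity map hit if present, else SELECT by primary key
    u2 = q.load(5)                   # always SELECTs, refreshing the in-session instance
    u3 = q.load(99, raiseerr=False)  # with the new flag, returns None instead of raising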
-
- def get_by(self, *args, **params):
- """Like ``select_by()``, but only return the first
- as a scalar, or None if no object found.
- Synonymous with ``selectfirst_by()``.
-
- The criterion is constructed in the same way as the
- ``select_by()`` method.
-
- this method is deprecated in 0.4.
- """
-
- ret = self.extension.get_by(self, *args, **params)
- if ret is not mapper.EXT_PASS:
- return ret
- x = self.select_whereclause(self.join_by(*args, **params), limit=1)
- if x:
- return x[0]
- else:
- return None
-
- def select_by(self, *args, **params):
- """Return an array of object instances based on the given
- clauses and key/value criterion.
-
- \*args
- a list of zero or more ``ClauseElements`` which will be
- connected by ``AND`` operators.
-
- \**params
- a set of zero or more key/value parameters which
- are converted into ``ClauseElements``. the keys are mapped to
- property or column names mapped by this mapper's Table, and
- the values are coerced into a ``WHERE`` clause separated by
- ``AND`` operators. If the local property/column names don't
- contain the key, a search will be performed against this
- mapper's immediate list of relations as well, forming the
- appropriate join conditions if a matching property is located.
-
- if the located property is a column-based property, the comparison
- value should be a scalar with an appropriate type. If the
- property is a relationship-bound property, the comparison value
- should be an instance of the related class.
-
- E.g.::
-
- result = usermapper.select_by(user_name = 'fred')
-
- this method is deprecated in 0.4.
- """
-
- ret = self.extension.select_by(self, *args, **params)
- if ret is not mapper.EXT_PASS:
- return ret
- return self.select_whereclause(self.join_by(*args, **params))
-
- def join_by(self, *args, **params):
- """Return a ``ClauseElement`` representing the ``WHERE``
- clause that would normally be sent to ``select_whereclause()``
- by ``select_by()``.
-
- The criterion is constructed in the same way as the
- ``select_by()`` method.
-
- this method is deprecated in 0.4.
- """
-
- return self._join_by(args, params)
-
-
- def join_to(self, key):
- """Given the key name of a property, will recursively descend
- through all child properties from this Query's mapper to
- locate the property, and will return a ClauseElement
- representing a join from this Query's mapper to the endmost
- mapper.
-
- this method is deprecated in 0.4.
- """
-
- [keys, p] = self._locate_prop(key)
- return self.join_via(keys)
-
- def join_via(self, keys):
- """Given a list of keys that represents a path from this
- Query's mapper to a related mapper based on names of relations
- from one mapper to the next, return a ClauseElement
- representing a join from this Query's mapper to the endmost
- mapper.
-
- this method is deprecated in 0.4.
- """
-
- mapper = self.mapper
- clause = None
- for key in keys:
- prop = mapper.get_property(key, resolve_synonyms=True)
- if clause is None:
- clause = prop.get_join(mapper)
- else:
- clause &= prop.get_join(mapper)
- mapper = prop.mapper
-
- return clause
-
- def selectfirst_by(self, *args, **params):
- """Like ``select_by()``, but only return the first
- as a scalar, or None if no object found.
- Synonymous with ``get_by()``.
-
- The criterion is constructed in the same way as the
- ``select_by()`` method.
-
- this method is deprecated in 0.4.
- """
-
- return self.get_by(*args, **params)
-
- def selectone_by(self, *args, **params):
- """Like ``selectfirst_by()``, but throws an error if not
- exactly one result was returned.
-
- The criterion is constructed in the same way as the
- ``select_by()`` method.
-
- this method is deprecated in 0.4.
- """
-
- ret = self.select_whereclause(self.join_by(*args, **params), limit=2)
- if len(ret) == 1:
- return ret[0]
- elif len(ret) == 0:
- raise exceptions.InvalidRequestError('No rows returned for selectone_by')
- else:
- raise exceptions.InvalidRequestError('Multiple rows returned for selectone_by')
-
- def count_by(self, *args, **params):
- """Return the count of instances based on the given clauses
- and key/value criterion.
-
- The criterion is constructed in the same way as the
- ``select_by()`` method.
-
- this method is deprecated in 0.4.
- """
-
- return self.count(self.join_by(*args, **params))
-
- def selectfirst(self, arg=None, **kwargs):
- """Query for a single instance using the given criterion.
-
- Arguments are the same as ``select()``. In the case that
- the given criterion represents ``WHERE`` criterion only,
- LIMIT 1 is applied to the fully generated statement.
-
- this method is deprecated in 0.4.
- """
-
- if isinstance(arg, sql.FromClause) and arg.supports_execution():
- ret = self.select_statement(arg, **kwargs)
- else:
- kwargs['limit'] = 1
- ret = self.select_whereclause(whereclause=arg, **kwargs)
- if ret:
- return ret[0]
- else:
- return None
-
- def selectone(self, arg=None, **kwargs):
- """Query for a single instance using the given criterion.
-
- Unlike ``selectfirst``, this method asserts that only one
- row exists. In the case that the given criterion represents
- ``WHERE`` criterion only, LIMIT 2 is applied to the fully
- generated statement.
-
- this method is deprecated in 0.4.
- """
-
- if isinstance(arg, sql.FromClause) and arg.supports_execution():
- ret = self.select_statement(arg, **kwargs)
- else:
- kwargs['limit'] = 2
- ret = self.select_whereclause(whereclause=arg, **kwargs)
- if len(ret) == 1:
- return ret[0]
- elif len(ret) == 0:
- raise exceptions.InvalidRequestError('No rows returned for selectone_by')
- else:
- raise exceptions.InvalidRequestError('Multiple rows returned for selectone')
-
- def select(self, arg=None, **kwargs):
- """Select instances of the object from the database.
-
- `arg` can be any ClauseElement, which will form the criterion
- with which to load the objects.
-
- For more advanced usage, arg can also be a Select statement
- object, which will be executed and its resulting rowset used
- to build new object instances.
-
- In this case, the developer must ensure that an adequate set
- of columns exists in the rowset with which to build new object
- instances.
-
- this method is deprecated in 0.4.
- """
-
- ret = self.extension.select(self, arg=arg, **kwargs)
- if ret is not mapper.EXT_PASS:
- return ret
- if isinstance(arg, sql.FromClause) and arg.supports_execution():
- return self.select_statement(arg, **kwargs)
- else:
- return self.select_whereclause(whereclause=arg, **kwargs)
-
- def select_whereclause(self, whereclause=None, params=None, **kwargs):
- """Given a ``WHERE`` criterion, create a ``SELECT`` statement,
- execute and return the resulting instances.
-
- this method is deprecated in 0.4.
- """
- statement = self.compile(whereclause, **kwargs)
- return self._select_statement(statement, params=params)
-
- def count(self, whereclause=None, params=None, **kwargs):
- """Given a ``WHERE`` criterion, create a ``SELECT COUNT``
- statement, execute and return the resulting count value.
-
- the additional arguments to this method are deprecated in 0.4.
-
- """
- if self._criterion:
- if whereclause is not None:
- whereclause = sql.and_(self._criterion, whereclause)
- else:
- whereclause = self._criterion
- from_obj = kwargs.pop('from_obj', self._from_obj)
- kwargs.setdefault('distinct', self._distinct)
-
- alltables = []
- for l in [sql_util.TableFinder(x) for x in from_obj]:
- alltables += l
-
- if self.table not in alltables:
- from_obj.append(self.table)
- if self._nestable(**kwargs):
- s = sql.select([self.table], whereclause, from_obj=from_obj, **kwargs).alias('getcount').count()
- else:
- primary_key = self.primary_key_columns
- s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **kwargs)
- return self.session.scalar(self.mapper, s, params=params)
-
- def select_statement(self, statement, **params):
- """Given a ``ClauseElement``-based statement, execute and
- return the resulting instances.
-
- this method is deprecated in 0.4.
- """
-
- return self._select_statement(statement, params=params)
-
- def select_text(self, text, **params):
- """Given a literal string-based statement, execute and return
- the resulting instances.
-
- this method is deprecated in 0.4. use from_statement() instead.
- """
-
- t = sql.text(text)
- return self.execute(t, params=params)
-
- def _with_lazy_criterion(cls, instance, prop, reverse=False):
- """extract query criterion from a LazyLoader strategy given a Mapper,
- source persisted/detached instance and PropertyLoader.
-
- """
-
- from sqlalchemy.orm import strategies
- (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(prop, reverse_direction=reverse)
- bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
-
- class Visitor(sql.ClauseVisitor):
- def visit_bindparam(self, bindparam):
- mapper = reverse and prop.mapper or prop.parent
- bindparam.value = mapper.get_attr_by_column(instance, bind_to_col[bindparam.key])
- Visitor().traverse(criterion)
- return criterion
- _with_lazy_criterion = classmethod(_with_lazy_criterion)
-
def query_from_parent(cls, instance, property, **kwargs):
"""return a newly constructed Query object, with criterion corresponding to
@@ -445,9 +122,25 @@ class Query(object):
mapper = object_mapper(instance)
prop = mapper.get_property(property, resolve_synonyms=True)
target = prop.mapper
- criterion = cls._with_lazy_criterion(instance, prop)
+ criterion = prop.compare(operator.eq, instance, value_is_parent=True)
return Query(target, **kwargs).filter(criterion)
query_from_parent = classmethod(query_from_parent)
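A sketch of the classmethod in use, assuming a typical User/Address mapping with an 'addresses' relation and a persistent someuser instance:

    from sqlalchemy.orm.query import Query

    # builds a Query for Address rows whose criterion binds someuser's column
    # values, via prop.compare(operator.eq, ..., value_is_parent=True)
    q = Query.query_from_parent(someuser, 'addresses', session=sess)
    addresses = q.all()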
+
+ def populate_existing(self):
+ """return a Query that will refresh all instances loaded.
+
+ this includes all entities accessed from the database, including
+ secondary entities and eagerly-loaded collection items.
+
+ All changes present on entities which are already present in the session will
+ be reset and the entities will all be marked "clean".
+
+ This is essentially the en-masse version of load().
+ """
+
+ q = self._clone()
+ q._populate_existing = True
+ return q
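So, as a hedged example, a bulk refresh of whatever rows a query touches might look like:

    # discards pending in-session changes on every instance the rows produce
    users = sess.query(User).populate_existing().all()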
def with_parent(self, instance, property=None):
"""add a join criterion corresponding to a relationship to the given parent instance.
@@ -474,9 +167,9 @@ class Query(object):
raise exceptions.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self.mapper.class_.__name__, instance.__class__.__name__))
else:
prop = mapper.get_property(property, resolve_synonyms=True)
- return self.filter(Query._with_lazy_criterion(instance, prop))
+ return self.filter(prop.compare(operator.eq, instance, value_is_parent=True))
- def add_entity(self, entity):
+ def add_entity(self, entity, alias=None, id=None):
"""add a mapped entity to the list of result columns to be returned.
This will have the effect of all result-returning methods returning a tuple
@@ -492,12 +185,25 @@ class Query(object):
entity
a class or mapper which will be added to the results.
+ alias
+ a sqlalchemy.sql.Alias object which will be used to select rows. This
+ will match the usage of the given Alias in filter(), order_by(), and other expressions.
+
+ id
+ a string ID matching that given to query.join() or query.outerjoin(); rows will be
+ selected from the aliased join created via those methods.
"""
q = self._clone()
- q._entities.append(entity)
+
+ if isinstance(entity, type):
+ entity = mapper.class_mapper(entity)
+ if alias is not None:
+ alias = mapperutil.AliasedClauses(entity.mapped_table, alias=alias)
+
+ q._entities = q._entities + [(entity, alias, id)]
return q
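A sketch of the new alias argument, assuming hypothetical users/addresses tables and mapped User/Address classes:

    adalias = addresses_table.alias('adalias')
    q = sess.query(User).add_entity(Address, alias=adalias)
    q = q.filter(users_table.c.user_id == adalias.c.user_id)
    for user, address in q.all():    # rows are now (User, Address) tuples
        print user, address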
- def add_column(self, column):
+ def add_column(self, column, id=None):
"""add a SQL ColumnElement to the list of result columns to be returned.
This will have the effect of all result-returning methods returning a tuple
@@ -517,51 +223,56 @@ class Query(object):
"""
q = self._clone()
-
+
# alias non-labeled column elements.
- # TODO: make the generation deterministic
if isinstance(column, sql.ColumnElement) and not hasattr(column, '_label'):
- column = column.label("anon_" + hex(random.randint(0, 65535))[2:])
-
- q._entities.append(column)
+ column = column.label(None)
+
+ q._entities = q._entities + [(column, None, id)]
return q
- def options(self, *args, **kwargs):
+ def options(self, *args):
"""Return a new Query object, applying the given list of
MapperOptions.
"""
+
q = self._clone()
- for opt in util.flatten_iterator(args):
- q.with_options.append(opt)
+ opts = [o for o in util.flatten_iterator(args)]
+ q._with_options = q._with_options + opts
+ for opt in opts:
opt.process_query(q)
return q
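For example (eagerload is the stock MapperOption; the mapping and session are assumed):

    from sqlalchemy.orm import eagerload

    # options() clones the query and lets each option adjust the copy
    q = sess.query(User).options(eagerload('addresses'))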
def with_lockmode(self, mode):
"""Return a new Query object with the specified locking mode."""
q = self._clone()
- q.lockmode = mode
+ q._lockmode = mode
return q
def params(self, **kwargs):
"""add values for bind parameters which may have been specified in filter()."""
-
+
q = self._clone()
q._params = q._params.copy()
q._params.update(kwargs)
return q
-
+
def filter(self, criterion):
"""apply the given filtering criterion to the query and return the newly resulting ``Query``
the criterion is any sql.ClauseElement applicable to the WHERE clause of a select.
"""
-
+
if isinstance(criterion, basestring):
criterion = sql.text(criterion)
-
+
if criterion is not None and not isinstance(criterion, sql.ClauseElement):
raise exceptions.ArgumentError("filter() argument must be of type sqlalchemy.sql.ClauseElement or string")
-
+
+
+ if self._aliases is not None:
+ criterion = self._aliases.adapt_clause(criterion)
+
q = self._clone()
if q._criterion is not None:
q._criterion = q._criterion & criterion
@@ -569,23 +280,36 @@ class Query(object):
q._criterion = criterion
return q
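A short sketch of filter() and params() working together (table and mapping names hypothetical):

    # a string criterion is coerced through sql.text(); binds arrive via params()
    q = sess.query(User).filter("users.user_name = :name").params(name='ed')

    # ClauseElement criteria AND together across successive filter() calls
    q = q.filter(users_table.c.user_id > 5)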
- def filter_by(self, *args, **kwargs):
- """apply the given filtering criterion to the query and return the newly resulting ``Query``
+ def filter_by(self, **kwargs):
+ """apply the given filtering criterion to the query and return the newly resulting ``Query``."""
- The criterion is constructed in the same way as the
- ``select_by()`` method.
- """
- return self.filter(self._join_by(args, kwargs, start=self._joinpoint))
+ #import properties
+
+ alias = None
+ join = None
+ clause = None
+ joinpoint = self._joinpoint
- def _join_to(self, prop, outerjoin=False, start=None):
- if start is None:
- start = self._joinpoint
+ for key, value in kwargs.iteritems():
+ prop = joinpoint.get_property(key, resolve_synonyms=True)
+ c = prop.compare(operator.eq, value)
- if isinstance(prop, list):
- keys = prop
+ if alias is not None:
+ sql_util.ClauseAdapter(alias).traverse(c)
+ if clause is None:
+ clause = c
+ else:
+ clause &= c
+
+ if join is not None:
+ return self.select_from(join).filter(clause)
else:
- [keys,p] = self._locate_prop(prop, start=start)
+ return self.filter(clause)
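For illustration, assuming the same hypothetical User/Address mapping as above:

    # column-valued keywords compare by value at the current joinpoint
    q = sess.query(User).filter_by(user_name='ed')

    # relation-valued keywords go through prop.compare(), i.e. the optimized
    # lazy clause, so a related instance can be passed directly
    q2 = sess.query(Address).filter_by(user=someuser)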
+ def _join_to(self, keys, outerjoin=False, start=None, create_aliases=True):
+ if start is None:
+ start = self._joinpoint
+
clause = self._from_obj[-1]
currenttables = [clause]
@@ -594,101 +318,56 @@ class Query(object):
currenttables.append(join.left)
currenttables.append(join.right)
FindJoinedTables().traverse(clause)
-
+
mapper = start
- for key in keys:
+ alias = self._aliases
+ for key in util.to_list(keys):
prop = mapper.get_property(key, resolve_synonyms=True)
- if prop._is_self_referential():
- raise exceptions.InvalidRequestError("Self-referential query on '%s' property must be constructed manually using an Alias object for the related table." % str(prop))
- # dont re-join to a table already in our from objects
- if prop.select_table not in currenttables:
- if outerjoin:
- if prop.secondary:
- clause = clause.outerjoin(prop.secondary, prop.get_join(mapper, primary=True, secondary=False))
- clause = clause.outerjoin(prop.select_table, prop.get_join(mapper, primary=False))
+ if prop._is_self_referential() and not create_aliases:
+ raise exceptions.InvalidRequestError("Self-referential query on '%s' property requires create_aliases=True argument." % str(prop))
+
+ if prop.select_table not in currenttables or create_aliases:
+ if prop.secondary:
+ if create_aliases:
+ alias = mapperutil.PropertyAliasedClauses(prop,
+ prop.get_join(mapper, primary=True, secondary=False),
+ prop.get_join(mapper, primary=False, secondary=True),
+ alias
+ )
+ clause = clause.join(alias.secondary, alias.primaryjoin, isouter=outerjoin).join(alias.alias, alias.secondaryjoin, isouter=outerjoin)
else:
- clause = clause.outerjoin(prop.select_table, prop.get_join(mapper))
+ clause = clause.join(prop.secondary, prop.get_join(mapper, primary=True, secondary=False), isouter=outerjoin)
+ clause = clause.join(prop.select_table, prop.get_join(mapper, primary=False), isouter=outerjoin)
else:
- if prop.secondary:
- clause = clause.join(prop.secondary, prop.get_join(mapper, primary=True, secondary=False))
- clause = clause.join(prop.select_table, prop.get_join(mapper, primary=False))
+ if create_aliases:
+ alias = mapperutil.PropertyAliasedClauses(prop,
+ prop.get_join(mapper, primary=True, secondary=False),
+ None,
+ alias
+ )
+ clause = clause.join(alias.alias, alias.primaryjoin, isouter=outerjoin)
else:
- clause = clause.join(prop.select_table, prop.get_join(mapper))
- elif prop.secondary is not None and prop.secondary not in currenttables:
+ clause = clause.join(prop.select_table, prop.get_join(mapper), isouter=outerjoin)
+ elif not create_aliases and prop.secondary is not None and prop.secondary not in currenttables:
# TODO: this check is not strong enough for different paths to the same endpoint which
# do not use secondary tables
- raise exceptions.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists. Use explicit `Alias` objects." % prop.key)
+ raise exceptions.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists. Use the `alias=True` argument to `join()`." % prop.key)
mapper = prop.mapper
- return (clause, mapper)
-
- def _join_by(self, args, params, start=None):
- """Return a ``ClauseElement`` representing the ``WHERE``
- clause that would normally be sent to ``select_whereclause()``
- by ``select_by()``.
-
- The criterion is constructed in the same way as the
- ``select_by()`` method.
- """
- import properties
-
- clause = None
- for arg in args:
- if clause is None:
- clause = arg
- else:
- clause &= arg
-
- for key, value in params.iteritems():
- (keys, prop) = self._locate_prop(key, start=start)
- if isinstance(prop, properties.PropertyLoader):
- c = self._with_lazy_criterion(value, prop, True) & self.join_via(keys[:-1])
- else:
- c = prop.compare(value) & self.join_via(keys)
- if clause is None:
- clause = c
- else:
- clause &= c
- return clause
-
- def _locate_prop(self, key, start=None):
- import properties
- keys = []
- seen = util.Set()
- def search_for_prop(mapper_):
- if mapper_ in seen:
- return None
- seen.add(mapper_)
- if mapper_.props.has_key(key):
- prop = mapper_.get_property(key, resolve_synonyms=True)
- if isinstance(prop, properties.PropertyLoader):
- keys.insert(0, prop.key)
- return prop
- else:
- for prop in mapper_.iterate_properties:
- if not isinstance(prop, properties.PropertyLoader):
- continue
- x = search_for_prop(prop.mapper)
- if x:
- keys.insert(0, prop.key)
- return x
- else:
- return None
- p = search_for_prop(start or self.mapper)
- if p is None:
- raise exceptions.InvalidRequestError("Can't locate property named '%s'" % key)
- return [keys, p]
+ if create_aliases:
+ return (clause, mapper, alias)
+ else:
+ return (clause, mapper, None)
def _generative_col_aggregate(self, col, func):
"""apply the given aggregate function to the query and return the newly
resulting ``Query``.
"""
- if self._col is not None or self._func is not None:
+ if self._column_aggregate is not None:
raise exceptions.InvalidRequestError("Query already contains an aggregate column or function")
q = self._clone()
- q._col = col
- q._func = func
+ q._column_aggregate = (col, func)
return q
def apply_min(self, col):
@@ -721,13 +400,13 @@ class Query(object):
For performance, only use subselect if `order_by` attribute is set.
"""
- ops = {'distinct':self._distinct, 'order_by':self._order_by, 'from_obj':self._from_obj}
+ ops = {'distinct':self._distinct, 'order_by':self._order_by or None, 'from_obj':self._from_obj}
if self._order_by is not False:
s1 = sql.select([col], self._criterion, **ops).alias('u')
- return sql.select([func(s1.corresponding_column(col))]).scalar()
+ return self.session.execute(sql.select([func(s1.corresponding_column(col))]), mapper=self.mapper).scalar()
else:
- return sql.select([func(col)], self._criterion, **ops).scalar()
+ return self.session.execute(sql.select([func(col)], self._criterion, **ops), mapper=self.mapper).scalar()
def min(self, col):
"""Execute the SQL ``min()`` function against the given column."""
@@ -756,7 +435,7 @@ class Query(object):
if q._order_by is False:
q._order_by = util.to_list(criterion)
else:
- q._order_by.extend(util.to_list(criterion))
+ q._order_by = q._order_by + util.to_list(criterion)
return q
def group_by(self, criterion):
@@ -766,52 +445,62 @@ class Query(object):
if q._group_by is False:
q._group_by = util.to_list(criterion)
else:
- q._group_by.extend(util.to_list(criterion))
+ q._group_by = q._group_by + util.to_list(criterion)
return q
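Note the switch from list.extend() to concatenation: the criterion lists are no longer mutated in place, so clones stay independent. A sketch (names hypothetical):

    base = sess.query(User)
    by_name = base.order_by(users_table.c.user_name)       # new Query, new list
    by_name_id = by_name.order_by(users_table.c.user_id)   # appends on another copy
    # base itself still carries no ORDER BY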
-
- def join(self, prop):
+
+ def join(self, prop, id=None, aliased=False, from_joinpoint=False):
"""create a join of this ``Query`` object's criterion
to a relationship and return the newly resulting ``Query``.
-
- 'prop' may be a string property name in which it is located
- in the same manner as keyword arguments in ``select_by``, or
- it may be a list of strings in which case the property is located
- by direct traversal of each keyname (i.e. like join_via()).
+
+ 'prop' may be a string property name or a list of string
+ property names.
"""
-
- q = self._clone()
- (clause, mapper) = self._join_to(prop, outerjoin=False)
- q._from_obj = [clause]
- q._joinpoint = mapper
- return q
- def outerjoin(self, prop):
+ return self._join(prop, id=id, outerjoin=False, aliased=aliased, from_joinpoint=from_joinpoint)
+
+ def outerjoin(self, prop, id=None, aliased=False, from_joinpoint=False):
"""create a left outer join of this ``Query`` object's criterion
to a relationship and return the newly resulting ``Query``.
- 'prop' may be a string property name in which it is located
- in the same manner as keyword arguments in ``select_by``, or
- it may be a list of strings in which case the property is located
- by direct traversal of each keyname (i.e. like join_via()).
+ 'prop' may be a string property name or a list of string
+ property names.
"""
+
+ return self._join(prop, id=id, outerjoin=True, aliased=aliased, from_joinpoint=from_joinpoint)
+
+ def _join(self, prop, id, outerjoin, aliased, from_joinpoint):
+ (clause, mapper, aliases) = self._join_to(prop, outerjoin=outerjoin, start=from_joinpoint and self._joinpoint or self.mapper, create_aliases=aliased)
q = self._clone()
- (clause, mapper) = self._join_to(prop, outerjoin=True)
q._from_obj = [clause]
q._joinpoint = mapper
+ q._aliases = aliases
+
+ a = aliases
+ while a is not None:
+ q._alias_ids.setdefault(a.mapper, []).append(a)
+ q._alias_ids.setdefault(a.table, []).append(a)
+ q._alias_ids.setdefault(a.alias, []).append(a)
+ a = a.parentclauses
+
+ if id:
+ q._alias_ids[id] = aliases
return q
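A sketch of the reworked join() (all mapped classes and tables here are hypothetical):

    # join along a path of relation names; filter()/filter_by() now target the joinpoint
    q = sess.query(User).join(['orders', 'items']).filter(
        items_table.c.description == 'item 1')

    # aliased=True wraps each joined table in an anonymous alias, which makes
    # self-referential joins possible without hand-built Alias objects
    q2 = sess.query(Node).join('children', aliased=True).filter_by(name='n2')

    # id='o' names the aliased join so add_entity()/add_column() can find it
    q3 = sess.query(User).join('orders', aliased=True, id='o').add_entity(Order, id='o')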
def reset_joinpoint(self):
"""return a new Query reset the 'joinpoint' of this Query reset
back to the starting mapper. Subsequent generative calls will
be constructed from the new joinpoint.
-
- This is an interim method which will not be needed with new behavior
- to be released in 0.4."""
-
+
+ Note that each call to join() or outerjoin() also starts from
+ the root.
+ """
+
q = self._clone()
q._joinpoint = q.mapper
+ q._aliases = None
return q
+
def select_from(self, from_obj):
"""Set the `from_obj` parameter of the query and return the newly
resulting ``Query``.
@@ -823,20 +512,6 @@ class Query(object):
new._from_obj = list(new._from_obj) + util.to_list(from_obj)
return new
- def __getattr__(self, key):
- if (key.startswith('select_by_')):
- key = key[10:]
- def foo(arg):
- return self.select_by(**{key:arg})
- return foo
- elif (key.startswith('get_by_')):
- key = key[7:]
- def foo(arg):
- return self.get_by(**{key:arg})
- return foo
- else:
- raise AttributeError(key)
-
def __getitem__(self, item):
if isinstance(item, slice):
start = item.start
@@ -884,99 +559,65 @@ class Query(object):
new._distinct = True
return new
- def list(self):
+ def all(self):
"""Return the results represented by this ``Query`` as a list.
This results in an execution of the underlying query.
-
- this method is deprecated in 0.4. use all() instead.
"""
-
return list(self)
-
- def one(self):
- """Return the first result of this ``Query``, raising an exception if more than one row exists.
-
- This results in an execution of the underlying query.
- this method is for forwards-compatibility with 0.4.
- """
-
- if self._col is None or self._func is None:
- ret = list(self[0:2])
-
- if len(ret) == 1:
- return ret[0]
- elif len(ret) == 0:
- raise exceptions.InvalidRequestError('No rows returned for one()')
- else:
- raise exceptions.InvalidRequestError('Multiple rows returned for one()')
- else:
- return self._col_aggregate(self._col, self._func)
-
+
+ def from_statement(self, statement):
+ if isinstance(statement, basestring):
+ statement = sql.text(statement)
+ q = self._clone()
+ q._statement = statement
+ return q
+
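For example, under the usual hypothetical mapping:

    # a textual SELECT drives instance loading directly; binds come from params()
    q = sess.query(User).from_statement(
        "SELECT * FROM users WHERE user_name = :name").params(name='ed')
    users = q.all()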
def first(self):
"""Return the first result of this ``Query``.
This results in an execution of the underlying query.
-
- this method is for forwards-compatibility with 0.4.
"""
- if self._col is None or self._func is None:
- ret = list(self[0:1])
- if len(ret) > 0:
- return ret[0]
- else:
- return None
+ if self._column_aggregate is not None:
+ return self._col_aggregate(*self._column_aggregate)
+
+ ret = list(self[0:1])
+ if len(ret) > 0:
+ return ret[0]
else:
- return self._col_aggregate(self._col, self._func)
+ return None
- def all(self):
- """Return the results represented by this ``Query`` as a list.
+ def one(self):
+ """Return the first result of this ``Query``, raising an exception if more than one row exists.
This results in an execution of the underlying query.
"""
- return self.list()
-
- def from_statement(self, statement):
- """execute a full select() statement, or literal textual string as a SELECT statement.
-
- this method is for forwards compatibility with 0.4.
- """
- if isinstance(statement, basestring):
- statement = sql.text(statement)
- q = self._clone()
- q._statement = statement
- return q
- def scalar(self):
- """Return the first result of this ``Query``.
+ if self._column_aggregate is not None:
+ return self._col_aggregate(*self._column_aggregate)
- This results in an execution of the underlying query.
-
- this method will be deprecated in 0.4; first() is added for
- forwards-compatibility.
- """
+ ret = list(self[0:2])
- return self.first()
+ if len(ret) == 1:
+ return ret[0]
+ elif len(ret) == 0:
+ raise exceptions.InvalidRequestError('No rows returned for one()')
+ else:
+ raise exceptions.InvalidRequestError('Multiple rows returned for one()')
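The two accessors differ only in their LIMIT and their tolerance for row counts; a sketch:

    from sqlalchemy import exceptions

    q = sess.query(User).filter_by(user_name='ed')
    u = q.first()    # LIMIT 1; returns None when nothing matches
    try:
        u = q.one()  # LIMIT 2 under the hood; exactly one row or it raises
    except exceptions.InvalidRequestError:
        pass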
def __iter__(self):
- return iter(self.select_whereclause())
-
- def execute(self, clauseelement, params=None, *args, **kwargs):
- """Execute the given ClauseElement-based statement against
- this Query's session/mapper, return the resulting list of
- instances.
-
- this method is deprecated in 0.4. Use from_statement() instead.
- """
-
- p = self._params
- if params is not None:
- p.update(params)
- result = self.session.execute(self.mapper, clauseelement, params=p)
+ statement = self.compile()
+ statement.use_labels = True
+ if self.session.autoflush:
+ self.session.flush()
+ return self._execute_and_instances(statement)
+
+ def _execute_and_instances(self, statement):
+ result = self.session.execute(statement, params=self._params, mapper=self.mapper)
try:
- return self.instances(result, **kwargs)
+ return iter(self.instances(result))
finally:
result.close()
@@ -984,50 +625,47 @@ class Query(object):
"""Return a list of mapped instances corresponding to the rows
in a given *cursor* (i.e. ``ResultProxy``).
- \*mappers_or_columns is an optional list containing one or more of
- classes, mappers, strings or sql.ColumnElements which will be
- applied to each row and added horizontally to the result set,
- which becomes a list of tuples. The first element in each tuple
- is the usual result based on the mapper represented by this
- ``Query``. Each additional element in the tuple corresponds to an
- entry in the \*mappers_or_columns list.
-
- For each element in \*mappers_or_columns, if the element is
- a mapper or mapped class, an additional class instance will be
- present in the tuple. If the element is a string or sql.ColumnElement,
- the corresponding result column from each row will be present in the tuple.
-
- Note that when \*mappers_or_columns is present, "uniquing" for the result set
- is *disabled*, so that the resulting tuples contain entities as they actually
- correspond. this indicates that multiple results may be present if this
- option is used.
+ The \*mappers_or_columns and \**kwargs arguments are deprecated.
+ To add instances or columns to the results, use add_entity()
+ and add_column().
"""
self.__log_debug("instances()")
session = self.session
- context = SelectionContext(self.select_mapper, session, self.extension, with_options=self.with_options, **kwargs)
+ kwargs.setdefault('populate_existing', self._populate_existing)
+ kwargs.setdefault('version_check', self._version_check)
+
+ context = SelectionContext(self.select_mapper, session, self._extension, with_options=self._with_options, **kwargs)
process = []
mappers_or_columns = tuple(self._entities) + mappers_or_columns
if mappers_or_columns:
- for m in mappers_or_columns:
+ for tup in mappers_or_columns:
+ if isinstance(tup, tuple):
+ (m, alias, alias_id) = tup
+ clauses = self._get_entity_clauses(tup)
+ else:
+ clauses = alias = alias_id = None
+ m = tup
if isinstance(m, type):
m = mapper.class_mapper(m)
if isinstance(m, mapper.Mapper):
def x(m):
+ row_adapter = clauses is not None and clauses.row_decorator or (lambda row: row)
appender = []
def proc(context, row):
- if not m._instance(context, row, appender):
+ if not m._instance(context, row_adapter(row), appender):
appender.append(None)
process.append((proc, appender))
x(m)
- elif isinstance(m, sql.ColumnElement) or isinstance(m, basestring):
+ elif isinstance(m, (sql.ColumnElement, basestring)):
def y(m):
+ row_adapter = clauses is not None and clauses.row_decorator or (lambda row: row)
res = []
def proc(context, row):
- res.append(row[m])
+ res.append(row_adapter(row)[m])
process.append((proc, res))
y(m)
result = []
@@ -1039,9 +677,12 @@ class Query(object):
for proc in process:
proc[0](context, row)
+ for instance in context.identity_map.values():
+ context.attributes.get(('populating_mapper', instance), object_mapper(instance))._post_instance(context, instance)
+
# store new stuff in the identity map
- for value in context.identity_map.values():
- session._register_persistent(value)
+ for instance in context.identity_map.values():
+ session._register_persistent(instance)
if mappers_or_columns:
return list(util.OrderedSet(zip(*([result] + [o[1] for o in process]))))
@@ -1050,8 +691,8 @@ class Query(object):
def _get(self, key, ident=None, reload=False, lockmode=None):
- lockmode = lockmode or self.lockmode
- if not reload and not self.always_refresh and lockmode is None:
+ lockmode = lockmode or self._lockmode
+ if not reload and not self.mapper.always_refresh and lockmode is None:
try:
return self.session._get(key)
except KeyError:
@@ -1062,21 +703,22 @@ class Query(object):
else:
ident = util.to_list(ident)
params = {}
- try:
- for i, primary_key in enumerate(self.primary_key_columns):
+
+ for i, primary_key in enumerate(self.primary_key_columns):
+ try:
params[primary_key._label] = ident[i]
- except IndexError:
- raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in self.primary_key_columns]))
+ except IndexError:
+ raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in self.primary_key_columns]))
try:
- statement = self.compile(self._get_clause, lockmode=lockmode)
- return self._select_statement(statement, params=params, populate_existing=reload, version_check=(lockmode is not None))[0]
+ q = self
+ if lockmode is not None:
+ q = q.with_lockmode(lockmode)
+ q = q.filter(self.select_mapper._get_clause)
+ q = q.params(**params)._select_context_options(populate_existing=reload, version_check=(lockmode is not None))
+ return q.first()
except IndexError:
return None
- def _select_statement(self, statement, params=None, **kwargs):
- statement.use_labels = True
- return self.execute(statement, params=params, **kwargs)
-
def _should_nest(self, querycontext):
"""Return True if the given statement options indicate that we
should *nest* the generated query as a subquery inside of a
@@ -1094,21 +736,56 @@ class Query(object):
return (kwargs.get('limit') is not None or kwargs.get('offset') is not None or kwargs.get('distinct', False))
- def compile(self, whereclause = None, **kwargs):
- """Given a WHERE criterion, produce a ClauseElement-based
- statement suitable for usage in the execute() method.
+ def count(self, whereclause=None, params=None, **kwargs):
+ """Apply this query's criterion to a SELECT COUNT statement.
+
+ the whereclause, params and \**kwargs arguments are deprecated. use filter()
+ and other generative methods to establish modifiers.
+ """
+
+ q = self
+ if whereclause is not None:
+ q = q.filter(whereclause)
+ if params is not None:
+ q = q.params(**params)
+ q = q._legacy_select_kwargs(**kwargs)
+ return q._count()
- the arguments to this function are deprecated and are removed in version 0.4.
+ def _count(self):
+ """Apply this query's criterion to a SELECT COUNT statement.
+
+ this is the purely generative version which will become
+ the public method in version 0.5.
"""
+ whereclause = self._criterion
+
+ context = QueryContext(self)
+ from_obj = context.from_obj
+
+ alltables = []
+ for l in [sql_util.TableFinder(x) for x in from_obj]:
+ alltables += l
+
+ if self.table not in alltables:
+ from_obj.append(self.table)
+ if self._nestable(**context.select_args()):
+ s = sql.select([self.table], whereclause, from_obj=from_obj, **context.select_args()).alias('getcount').count()
+ else:
+ primary_key = self.primary_key_columns
+ s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **context.select_args())
+ return self.session.scalar(s, params=self._params, mapper=self.mapper)
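So the generative spelling, with criterion established beforehand, looks roughly like (names hypothetical):

    # count() called with no arguments defers straight to _count()
    n = sess.query(User).filter(users_table.c.user_name.like('e%')).count()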
+
+ def compile(self):
+ """compiles and returns a SQL statement based on the criterion and conditions within this Query."""
+
if self._statement:
self._statement.use_labels = True
return self._statement
+
+ whereclause = self._criterion
- if self._criterion:
- whereclause = sql.and_(self._criterion, whereclause)
-
- if whereclause is not None and self.is_polymorphic:
+ if whereclause is not None and (self.mapper is not self.select_mapper):
# adapt the given WHERECLAUSE to adjust instances of this query's mapped
# table to be that of our select_table,
# which may be the "polymorphic" selectable used by our mapper.
@@ -1124,16 +801,10 @@ class Query(object):
# get/create query context. get the ultimate compile arguments
# from there
- context = kwargs.pop('query_context', None)
- if context is None:
- context = QueryContext(self, kwargs)
+ context = QueryContext(self)
order_by = context.order_by
- group_by = context.group_by
from_obj = context.from_obj
lockmode = context.lockmode
- distinct = context.distinct
- limit = context.limit
- offset = context.offset
if order_by is False:
order_by = self.mapper.order_by
if order_by is False:
@@ -1161,31 +832,33 @@ class Query(object):
# if theres an order by, add those columns to the column list
# of the "rowcount" query we're going to make
if order_by:
- order_by = util.to_list(order_by) or []
+ order_by = [sql._literal_as_text(o) for o in util.to_list(order_by) or []]
cf = sql_util.ColumnFinder()
for o in order_by:
cf.traverse(o)
else:
cf = []
- s2 = sql.select(self.table.primary_key + list(cf), whereclause, use_labels=True, from_obj=from_obj, **context.select_args())
+ s2 = sql.select(self.primary_key_columns + list(cf), whereclause, use_labels=True, from_obj=from_obj, correlate=False, **context.select_args())
if order_by:
- s2.order_by(*util.to_list(order_by))
+ s2 = s2.order_by(*util.to_list(order_by))
s3 = s2.alias('tbl_row_count')
- crit = s3.primary_key==self.table.primary_key
+ crit = s3.primary_key==self.primary_key_columns
statement = sql.select([], crit, use_labels=True, for_update=for_update)
# now for the order by, convert the columns to their corresponding columns
# in the "rowcount" query, and tack that new order by onto the "rowcount" query
if order_by:
- statement.order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by))
+ statement.append_order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by))
else:
statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **context.select_args())
if order_by:
- statement.order_by(*util.to_list(order_by))
+ statement.append_order_by(*util.to_list(order_by))
+
# for a DISTINCT query, you need the columns explicitly specified in order
# to use it in "order_by". ensure they are in the column criterion (particularly oid).
# TODO: this should be done at the SQL level not the mapper level
- if kwargs.get('distinct', False) and order_by:
+ # TODO: need test coverage for this
+ if context.distinct and order_by:
[statement.append_column(c) for c in util.to_list(order_by)]
context.statement = statement
@@ -1197,20 +870,268 @@ class Query(object):
value.setup(context)
# additional entities/columns, add those to selection criterion
- for m in self._entities:
- if isinstance(m, type):
- m = mapper.class_mapper(m)
+ for tup in self._entities:
+ (m, alias, alias_id) = tup
+ clauses = self._get_entity_clauses(tup)
if isinstance(m, mapper.Mapper):
for value in m.iterate_properties:
- value.setup(context)
+ value.setup(context, parentclauses=clauses)
elif isinstance(m, sql.ColumnElement):
+ if clauses is not None:
+ m = clauses.adapt_clause(m)
statement.append_column(m)
return statement
+ def _get_entity_clauses(self, m):
+ """for tuples added via add_entity() or add_column(), attempt to locate
+ an AliasedClauses object which should be used to formulate the query as well
+ as to process result rows."""
+ (m, alias, alias_id) = m
+ if alias is not None:
+ return alias
+ if alias_id is not None:
+ try:
+ return self._alias_ids[alias_id]
+ except KeyError:
+ raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % alias_id)
+ if isinstance(m, type):
+ m = mapper.class_mapper(m)
+ if isinstance(m, mapper.Mapper):
+ l = self._alias_ids.get(m)
+ if l:
+ if len(l) > 1:
+ raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id=<someid> to query.join()/query.add_entity()" % str(m))
+ else:
+ return l[0]
+ else:
+ return None
+ elif isinstance(m, sql.ColumnElement):
+ aliases = []
+ for table in sql_util.TableFinder(m, check_columns=True):
+ for a in self._alias_ids.get(table, []):
+ aliases.append(a)
+ if len(aliases) > 1:
+ raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id=<someid> to query.join()/query.add_column()" % str(m))
+ elif len(aliases) == 1:
+ return aliases[0]
+ else:
+ return None
+
def __log_debug(self, msg):
self.logger.debug(msg)
+ def __str__(self):
+ return str(self.compile())
+
+ # DEPRECATED LAND !
+
+ def list(self):
+ """DEPRECATED. use all()"""
+
+ return list(self)
+
+ def scalar(self):
+ """DEPRECATED. use first()"""
+
+ return self.first()
+
+ def _legacy_filter_by(self, *args, **kwargs):
+ return self.filter(self._legacy_join_by(args, kwargs, start=self._joinpoint))
+
+ def count_by(self, *args, **params):
+ """DEPRECATED. use query.filter_by(\**params).count()"""
+
+ return self.count(self.join_by(*args, **params))
+
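# The migration these docstrings describe, sketched against a hypothetical
# User mapper:
#   old: session.query(User).list()
#   old: session.query(User).scalar()
#   old: session.query(User).count_by(name='ed')
session.query(User).all()
session.query(User).first()
session.query(User).filter_by(name='ed').count()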
+
+ def select_whereclause(self, whereclause=None, params=None, **kwargs):
+ """DEPRECATED. use query.filter(whereclause).all()"""
+
+ q = self.filter(whereclause)._legacy_select_kwargs(**kwargs)
+ if params is not None:
+ q = q.params(**params)
+ return list(q)
+
+ def _legacy_select_kwargs(self, **kwargs):
+ q = self
+ if "order_by" in kwargs and kwargs['order_by']:
+ q = q.order_by(kwargs['order_by'])
+ if "group_by" in kwargs:
+ q = q.group_by(kwargs['group_by'])
+ if "from_obj" in kwargs:
+ q = q.select_from(kwargs['from_obj'])
+ if "lockmode" in kwargs:
+ q = q.with_lockmode(kwargs['lockmode'])
+ if "distinct" in kwargs:
+ q = q.distinct()
+ if "limit" in kwargs:
+ q = q.limit(kwargs['limit'])
+ if "offset" in kwargs:
+ q = q.offset(kwargs['offset'])
+ return q
+
+
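# How the legacy keyword arguments map onto the generative API; `w` is a
# hypothetical whereclause and `users` a hypothetical Table:
#   old: q.select_whereclause(w, order_by=[users.c.name], limit=10)
q = session.query(User)
result = q.filter(w).order_by(users.c.name).limit(10).all()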
+ def get_by(self, *args, **params):
+ """DEPRECATED. use query.filter_by(\**params).first()"""
+
+ ret = self._extension.get_by(self, *args, **params)
+ if ret is not mapper.EXT_PASS:
+ return ret
+
+ return self._legacy_filter_by(*args, **params).first()
+
+ def select_by(self, *args, **params):
+ """DEPRECATED. use use query.filter_by(\**params).all()."""
+
+ ret = self._extension.select_by(self, *args, **params)
+ if ret is not mapper.EXT_PASS:
+ return ret
+
+ return self._legacy_filter_by(*args, **params).list()
+
+ def join_by(self, *args, **params):
+ """DEPRECATED. use join() to construct joins based on attribute names."""
+
+ return self._legacy_join_by(args, params, start=self._joinpoint)
+
+ def _build_select(self, arg=None, params=None, **kwargs):
+ if isinstance(arg, sql.FromClause) and arg.supports_execution():
+ return self.from_statement(arg)
+ else:
+ return self.filter(arg)._legacy_select_kwargs(**kwargs)
+
+ def selectfirst(self, arg=None, **kwargs):
+ """DEPRECATED. use query.filter(whereclause).first()"""
+
+ return self._build_select(arg, **kwargs).first()
+
+ def selectone(self, arg=None, **kwargs):
+ """DEPRECATED. use query.filter(whereclause).one()"""
+
+ return self._build_select(arg, **kwargs).one()
+
+ def select(self, arg=None, **kwargs):
+ """DEPRECATED. use query.filter(whereclause).all(), or query.from_statement(statement).all()"""
+
+ ret = self._extension.select(self, arg=arg, **kwargs)
+ if ret is not mapper.EXT_PASS:
+ return ret
+ return self._build_select(arg, **kwargs).all()
+
+ def execute(self, clauseelement, params=None, *args, **kwargs):
+ """DEPRECATED. use query.from_statement().all()"""
+
+ return self._select_statement(clauseelement, params, **kwargs)
+
+ def select_statement(self, statement, **params):
+ """DEPRECATED. Use query.from_statement(statement)"""
+
+ return self._select_statement(statement, params)
+
+ def select_text(self, text, **params):
+ """DEPRECATED. Use query.from_statement(statement)"""
+
+ return self._select_statement(text, params)
+
+ def _select_statement(self, statement, params=None, **kwargs):
+ q = self.from_statement(statement)
+ if params is not None:
+ q = q.params(**params)
+ q._select_context_options(**kwargs)
+ return list(q)
+
+ def _select_context_options(self, populate_existing=None, version_check=None):
+ if populate_existing is not None:
+ self._populate_existing = populate_existing
+ if version_check is not None:
+ self._version_check = version_check
+ return self
+
+ def join_to(self, key):
+ """DEPRECATED. use join() to create joins based on property names."""
+
+ [keys, p] = self._locate_prop(key)
+ return self.join_via(keys)
+
+ def join_via(self, keys):
+ """DEPRECATED. use join() to create joins based on property names."""
+
+ mapper = self._joinpoint
+ clause = None
+ for key in keys:
+ prop = mapper.get_property(key, resolve_synonyms=True)
+ if clause is None:
+ clause = prop.get_join(mapper)
+ else:
+ clause &= prop.get_join(mapper)
+ mapper = prop.mapper
+
+ return clause
+
+ def _legacy_join_by(self, args, params, start=None):
+ import properties
+
+ clause = None
+ for arg in args:
+ if clause is None:
+ clause = arg
+ else:
+ clause &= arg
+
+ for key, value in params.iteritems():
+ (keys, prop) = self._locate_prop(key, start=start)
+ if isinstance(prop, properties.PropertyLoader):
+ c = prop.compare(operator.eq, value) & self.join_via(keys[:-1])
+ else:
+ c = prop.compare(operator.eq, value) & self.join_via(keys)
+ if clause is None:
+ clause = c
+ else:
+ clause &= c
+ return clause
+
+ def _locate_prop(self, key, start=None):
+ import properties
+ keys = []
+ seen = util.Set()
+ def search_for_prop(mapper_):
+ if mapper_ in seen:
+ return None
+ seen.add(mapper_)
+
+ prop = mapper_.get_property(key, resolve_synonyms=True, raiseerr=False)
+ if prop is not None:
+ if isinstance(prop, properties.PropertyLoader):
+ keys.insert(0, prop.key)
+ return prop
+ else:
+ for prop in mapper_.iterate_properties:
+ if not isinstance(prop, properties.PropertyLoader):
+ continue
+ x = search_for_prop(prop.mapper)
+ if x:
+ keys.insert(0, prop.key)
+ return x
+ else:
+ return None
+ p = search_for_prop(start or self.mapper)
+ if p is None:
+ raise exceptions.InvalidRequestError("Can't locate property named '%s'" % key)
+ return [keys, p]
+
+ def selectfirst_by(self, *args, **params):
+ """DEPRECATED. Use query.filter_by(\**kwargs).first()"""
+
+ return self._legacy_filter_by(*args, **params).first()
+
+ def selectone_by(self, *args, **params):
+ """DEPRECATED. Use query.filter_by(\**kwargs).one()"""
+
+ return self._legacy_filter_by(*args, **params).one()
+
+
+
Query.logger = logging.class_logger(Query)
class QueryContext(OperationContext):
@@ -1219,25 +1140,25 @@ class QueryContext(OperationContext):
in a query construction.
"""
- def __init__(self, query, kwargs):
+ def __init__(self, query):
self.query = query
- self.order_by = kwargs.pop('order_by', query._order_by)
- self.group_by = kwargs.pop('group_by', query._group_by)
- self.from_obj = kwargs.pop('from_obj', query._from_obj)
- self.lockmode = kwargs.pop('lockmode', query.lockmode)
- self.distinct = kwargs.pop('distinct', query._distinct)
- self.limit = kwargs.pop('limit', query._limit)
- self.offset = kwargs.pop('offset', query._offset)
+ self.order_by = query._order_by
+ self.group_by = query._group_by
+ self.from_obj = query._from_obj
+ self.lockmode = query._lockmode
+ self.distinct = query._distinct
+ self.limit = query._limit
+ self.offset = query._offset
self.eager_loaders = util.Set([x for x in query.mapper._eager_loaders])
self.statement = None
- super(QueryContext, self).__init__(query.mapper, query.with_options, **kwargs)
+ super(QueryContext, self).__init__(query.mapper, query._with_options)
def select_args(self):
"""Return a dictionary of attributes from this
``QueryContext`` that can be applied to a ``sql.Select``
statement.
"""
- return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct, 'group_by':self.group_by}
+ return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct, 'group_by':self.group_by or None}
def accept_option(self, opt):
"""Accept a ``MapperOption`` which will process (modify) the
@@ -1265,8 +1186,10 @@ class SelectionContext(OperationContext):
yet been added as persistent to the Session.
attributes
- A dictionary to store arbitrary data; eager loaders use it to
- store additional result lists.
+ A dictionary to store arbitrary data; mappers, strategies, and
+ options all store various state information here in order
+ to communicate with each other and to themselves.
+
populate_existing
Indicates if it's OK to overwrite the attributes of instances
@@ -1284,6 +1207,7 @@ class SelectionContext(OperationContext):
self.session = session
self.extension = extension
self.identity_map = {}
+ self.stack = LoaderStack()
super(SelectionContext, self).__init__(mapper, kwargs.pop('with_options', []), **kwargs)
def accept_option(self, opt):
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 4e7453d84..6b5c4a072 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -4,12 +4,12 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+import weakref
+
from sqlalchemy import util, exceptions, sql, engine
-from sqlalchemy.orm import unitofwork, query
+from sqlalchemy.orm import unitofwork, query, util as mapperutil
from sqlalchemy.orm.mapper import object_mapper as _object_mapper
from sqlalchemy.orm.mapper import class_mapper as _class_mapper
-import weakref
-import sqlalchemy
class SessionTransaction(object):
"""Represents a Session-level Transaction.
@@ -21,70 +21,95 @@ class SessionTransaction(object):
The SessionTransaction object is **not** threadsafe.
"""
- def __init__(self, session, parent=None, autoflush=True):
+ def __init__(self, session, parent=None, autoflush=True, nested=False):
self.session = session
- self.connections = {}
- self.parent = parent
+ self.__connections = {}
+ self.__parent = parent
self.autoflush = autoflush
+ self.nested = nested
- def connection(self, mapper_or_class, entity_name=None):
+ def connection(self, mapper_or_class, entity_name=None, **kwargs):
if isinstance(mapper_or_class, type):
mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name)
- engine = self.session.get_bind(mapper_or_class)
+ engine = self.session.get_bind(mapper_or_class, **kwargs)
return self.get_or_add(engine)
- def _begin(self):
- return SessionTransaction(self.session, self)
+ def _begin(self, **kwargs):
+ return SessionTransaction(self.session, self, **kwargs)
def add(self, bind):
- if self.parent is not None:
- return self.parent.add(bind)
-
- if self.connections.has_key(bind.engine):
- raise exceptions.InvalidRequestError("Session already has a Connection associated for the given %sEngine" % (isinstance(bind, engine.Connection) and "Connection's " or ""))
+ if self.__parent is not None:
+ return self.__parent.add(bind)
+ if self.__connections.has_key(bind.engine):
+ raise exceptions.InvalidRequestError("Session already has a Connection associated for the given %sEngine" % (isinstance(bind, engine.Connection) and "Connection's " or ""))
return self.get_or_add(bind)
+ def _connection_dict(self):
+ if self.__parent is not None and not self.nested:
+ return self.__parent._connection_dict()
+ else:
+ return self.__connections
+
def get_or_add(self, bind):
- if self.parent is not None:
- return self.parent.get_or_add(bind)
+ if self.__parent is not None:
+ if not self.nested:
+ return self.__parent.get_or_add(bind)
+
+ if self.__connections.has_key(bind):
+ return self.__connections[bind][0]
+
+ if bind in self.__parent._connection_dict():
+ (conn, trans, autoclose) = self.__parent.__connections[bind]
+ self.__connections[conn] = self.__connections[bind.engine] = (conn, conn.begin_nested(), autoclose)
+ return conn
+ elif self.__connections.has_key(bind):
+ return self.__connections[bind][0]
- if self.connections.has_key(bind):
- return self.connections[bind][0]
-
if not isinstance(bind, engine.Connection):
e = bind
c = bind.contextual_connect()
else:
e = bind.engine
c = bind
- if e in self.connections:
+ if e in self.__connections:
raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine")
-
- self.connections[bind] = self.connections[e] = (c, c.begin(), c is not bind)
- return self.connections[bind][0]
+ if self.nested:
+ trans = c.begin_nested()
+ elif self.session.twophase:
+ trans = c.begin_twophase()
+ else:
+ trans = c.begin()
+ self.__connections[c] = self.__connections[e] = (c, trans, c is not bind)
+ return self.__connections[c][0]
def commit(self):
- if self.parent is not None:
- return
+ if self.__parent is not None and not self.nested:
+ return self.__parent
if self.autoflush:
self.session.flush()
- for t in util.Set(self.connections.values()):
+
+ if self.session.twophase:
+ for t in util.Set(self.__connections.values()):
+ t[1].prepare()
+
+ for t in util.Set(self.__connections.values()):
t[1].commit()
self.close()
+ return self.__parent
def rollback(self):
- if self.parent is not None:
- self.parent.rollback()
- return
- for k, t in self.connections.iteritems():
+ if self.__parent is not None and not self.nested:
+ return self.__parent.rollback()
+ for t in util.Set(self.__connections.values()):
t[1].rollback()
self.close()
-
+ return self.__parent
+
def close(self):
- if self.parent is not None:
+ if self.__parent is not None:
return
- for t in self.connections.values():
+ for t in util.Set(self.__connections.values()):
if t[2]:
t[0].close()
self.session.transaction = None
@@ -108,23 +133,24 @@ class Session(object):
of Sessions, see the ``sqlalchemy.ext.sessioncontext`` module.
"""
- def __init__(self, bind=None, bind_to=None, hash_key=None, import_session=None, echo_uow=False, weak_identity_map=False):
- if import_session is not None:
- self.uow = unitofwork.UnitOfWork(identity_map=import_session.uow.identity_map, weak_identity_map=weak_identity_map)
- else:
- self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map)
+ def __init__(self, bind=None, autoflush=False, transactional=False, twophase=False, echo_uow=False, weak_identity_map=False):
+ self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map)
- self.bind = bind or bind_to
- self.binds = {}
+ self.bind = bind
+ self.__binds = {}
self.echo_uow = echo_uow
self.weak_identity_map = weak_identity_map
self.transaction = None
- if hash_key is None:
- self.hash_key = id(self)
- else:
- self.hash_key = hash_key
+ self.hash_key = id(self)
+ self.autoflush = autoflush
+ self.transactional = transactional or autoflush
+ self.twophase = twophase
+ self._query_cls = query.Query
+ self._mapper_flush_opts = {}
+ if self.transactional:
+ self.begin()
_sessions[self.hash_key] = self
-
+
def _get_echo_uow(self):
return self.uow.echo
@@ -132,37 +158,39 @@ class Session(object):
self.uow.echo = value
echo_uow = property(_get_echo_uow,_set_echo_uow)
- bind_to = property(lambda self:self.bind)
-
- def create_transaction(self, **kwargs):
- """Return a new ``SessionTransaction`` corresponding to an
- existing or new transaction.
-
- If the transaction is new, the returned ``SessionTransaction``
- will have commit control over the underlying transaction, else
- will have rollback control only.
- """
+ def begin(self, **kwargs):
+ """Begin a transaction on this Session."""
if self.transaction is not None:
- return self.transaction._begin()
+ self.transaction = self.transaction._begin(**kwargs)
else:
self.transaction = SessionTransaction(self, **kwargs)
- return self.transaction
-
- def connect(self, mapper=None, **kwargs):
- """Return a unique connection corresponding to the given mapper.
-
- This connection will not be part of any pre-existing
- transactional context.
- """
-
- return self.get_bind(mapper).connect(**kwargs)
-
- def connection(self, mapper, **kwargs):
- """Return a ``Connection`` corresponding to the given mapper.
+ return self.transaction
+
+ create_transaction = begin
- Used by the ``execute()`` method which performs select
- operations for ``Mapper`` and ``Query``.
+ def begin_nested(self):
+ return self.begin(nested=True)
+
+ def rollback(self):
+ if self.transaction is None:
+ raise exceptions.InvalidRequestError("No transaction is begun.")
+ else:
+ self.transaction = self.transaction.rollback()
+ if self.transaction is None and self.transactional:
+ self.begin()
+
+ def commit(self):
+ if self.transaction is None:
+ raise exceptions.InvalidRequestError("No transaction is begun.")
+ else:
+ self.transaction = self.transaction.commit()
+ if self.transaction is None and self.transactional:
+ self.begin()
+
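# A minimal sketch of the new transactional API; `engine` and the mapped
# User class are hypothetical:
sess = Session(bind=engine)
sess.begin()                    # explicit transaction
try:
    sess.save(User())
    sess.flush()
    sess.commit()               # commits; self.transaction is None again
except:
    sess.rollback()
    raise
# alternatively, Session(bind=engine, transactional=True) begins a
# transaction up front and re-begins one after each commit()/rollback().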
+ def connection(self, mapper=None, **kwargs):
+ """Return a ``Connection`` corresponding to this session's
+ transactional context, if any.
If this ``Session`` is transactional, the connection will be in
the context of this session's transaction. Otherwise, the
@@ -173,6 +201,9 @@ class Session(object):
The given `**kwargs` will be sent to the engine's
``contextual_connect()`` method, if no transaction is in
progress.
+
+ the "mapper" argument is a class or mapper to which a bound engine
+ will be located; use this when the Session itself is unbound.
"""
if self.transaction is not None:
@@ -180,7 +211,7 @@ class Session(object):
else:
return self.get_bind(mapper).contextual_connect(**kwargs)
- def execute(self, mapper, clause, params, **kwargs):
+ def execute(self, clause, params=None, mapper=None, **kwargs):
"""Using the given mapper to identify the appropriate ``Engine``
or ``Connection`` to be used for statement execution, execute the
given ``ClauseElement`` using the provided parameter dictionary.
@@ -191,12 +222,12 @@ class Session(object):
then the ``ResultProxy`` 's ``close()`` method will release the
resources of the underlying ``Connection``, otherwise it's a no-op.
"""
- return self.connection(mapper, close_with_result=True).execute(clause, params, **kwargs)
+ return self.connection(mapper, close_with_result=True).execute(clause, params or {}, **kwargs)
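# With the reordered signature the clause comes first and the mapper is an
# optional hint for locating a bind; `users` and User are hypothetical:
result = session.execute(users.select(), mapper=User)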
- def scalar(self, mapper, clause, params, **kwargs):
+ def scalar(self, clause, params=None, mapper=None, **kwargs):
"""Like execute() but return a scalar result."""
- return self.connection(mapper, close_with_result=True).scalar(clause, params, **kwargs)
+ return self.connection(mapper, close_with_result=True).scalar(clause, params or {}, **kwargs)
def close(self):
"""Close this Session."""
@@ -224,14 +255,17 @@ class Session(object):
return _class_mapper(class_, entity_name = entity_name)
- def bind_mapper(self, mapper, bind):
- """Bind the given `mapper` to the given ``Engine`` or ``Connection``.
+ def bind_mapper(self, mapper, bind, entity_name=None):
+ """Bind the given `mapper` or `class` to the given ``Engine`` or ``Connection``.
All subsequent operations involving this ``Mapper`` will use the
given `bind`.
"""
+
+ if isinstance(mapper, type):
+ mapper = _class_mapper(mapper, entity_name=entity_name)
- self.binds[mapper] = bind
+ self.__binds[mapper] = bind
def bind_table(self, table, bind):
"""Bind the given `table` to the given ``Engine`` or ``Connection``.
@@ -240,7 +274,7 @@ class Session(object):
given `bind`.
"""
- self.binds[table] = bind
+ self.__binds[table] = bind
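# Routing different mappers and tables to different engines on a single
# Session; the engines, User class and log_table are hypothetical:
sess = Session()
sess.bind_mapper(User, engine_one)      # a plain class is accepted here too
sess.bind_table(log_table, engine_two)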
def get_bind(self, mapper):
"""Return the ``Engine`` or ``Connection`` which is used to execute
@@ -270,15 +304,18 @@ class Session(object):
"""
if mapper is None:
- return self.bind
- elif self.binds.has_key(mapper):
- return self.binds[mapper]
- elif self.binds.has_key(mapper.mapped_table):
- return self.binds[mapper.mapped_table]
+ if self.bind is not None:
+ return self.bind
+ else:
+ raise exceptions.InvalidRequestError("This session is unbound to any Engine or Connection; specify a mapper to get_bind()")
+ elif self.__binds.has_key(mapper):
+ return self.__binds[mapper]
+ elif self.__binds.has_key(mapper.mapped_table):
+ return self.__binds[mapper.mapped_table]
elif self.bind is not None:
return self.bind
else:
- e = mapper.mapped_table.engine
+ e = mapper.mapped_table.bind
if e is None:
raise exceptions.InvalidRequestError("Could not locate any Engine or Connection bound to mapper '%s'" % str(mapper))
return e
@@ -291,9 +328,9 @@ class Session(object):
entity_name = kwargs.pop('entity_name', None)
if isinstance(mapper_or_class, type):
- q = query.Query(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs)
+ q = self._query_cls(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs)
else:
- q = query.Query(mapper_or_class, self, **kwargs)
+ q = self._query_cls(mapper_or_class, self, **kwargs)
for ent in addtl_entities:
q = q.add_entity(ent)
@@ -499,7 +536,7 @@ class Session(object):
merged = self.get(mapper.class_, key[1])
if merged is None:
raise exceptions.AssertionError("Instance %s has an instance key but is not persisted" % mapperutil.instance_str(object))
- for prop in mapper.props.values():
+ for prop in mapper.iterate_properties:
prop.merge(self, object, merged, _recursive)
if key is None:
self.save(merged, entity_name=mapper.entity_name)
@@ -611,12 +648,12 @@ class Session(object):
def _attach(self, obj):
"""Attach the given object to this ``Session``."""
- if getattr(obj, '_sa_session_id', None) != self.hash_key:
- old = getattr(obj, '_sa_session_id', None)
- if old is not None and _sessions.has_key(old):
+ old_id = getattr(obj, '_sa_session_id', None)
+ if old_id != self.hash_key:
+ if old_id is not None and _sessions.has_key(old_id):
raise exceptions.InvalidRequestError("Object '%s' is already attached "
"to session '%s' (this is '%s')" %
- (repr(obj), old, id(self)))
+ (repr(obj), old_id, id(self)))
# auto-removal from the old session is disabled. but if we decide to
# turn it back on, do it as below: gingerly since _sessions is a WeakValueDict
@@ -695,6 +732,7 @@ def object_session(obj):
return _sessions.get(hashkey)
return None
+# Lazy initialization to avoid circular imports
unitofwork.object_session = object_session
from sqlalchemy.orm import mapper
mapper.attribute_manager = attribute_manager
diff --git a/lib/sqlalchemy/orm/shard.py b/lib/sqlalchemy/orm/shard.py
new file mode 100644
index 000000000..cc13f8c1f
--- /dev/null
+++ b/lib/sqlalchemy/orm/shard.py
@@ -0,0 +1,113 @@
+from sqlalchemy import exceptions
+from sqlalchemy.orm.session import Session
+from sqlalchemy.orm import Query
+
+class ShardedSession(Session):
+ def __init__(self, shard_chooser, id_chooser, query_chooser, **kwargs):
+ """construct a ShardedSession.
+
+ shard_chooser
+ a callable which, passed a Mapper and a mapped instance, returns a
+ shard ID. This ID may be based on the attributes present on the
+ object, or on some round-robin scheme. If the scheme is based on
+ a selection, it should set whatever state is needed on the instance
+ to mark it as participating in that shard in the future.
+
+ id_chooser
+ a callable, passed a tuple of identity values, which should return
+ a list of shard ids where the ID might reside. The databases will
+ be queried in the order of this listing.
+
+ query_chooser
+ for a given Query, returns the list of shard_ids where the query
+ should be issued. Results from all shards returned will be
+ combined into a single listing.
+
+ """
+ super(ShardedSession, self).__init__(**kwargs)
+ self.shard_chooser = shard_chooser
+ self.id_chooser = id_chooser
+ self.query_chooser = query_chooser
+ self.__binds = {}
+ self._mapper_flush_opts = {'connection_callable':self.connection}
+ self._query_cls = ShardedQuery
+
+ def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
+ if shard_id is None:
+ shard_id = self.shard_chooser(mapper, instance)
+
+ if self.transaction is not None:
+ return self.transaction.connection(mapper, shard_id=shard_id)
+ else:
+ return self.get_bind(mapper, shard_id=shard_id, instance=instance).contextual_connect(**kwargs)
+
+ def get_bind(self, mapper, shard_id=None, instance=None):
+ if shard_id is None:
+ shard_id = self.shard_chooser(mapper, instance)
+ return self.__binds[shard_id]
+
+ def bind_shard(self, shard_id, bind):
+ self.__binds[shard_id] = bind
+
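# A minimal setup sketch for the choosers described above; the shard ids,
# engines and the `region` attribute are hypothetical:
def shard_chooser(mapper, instance):
    if instance.region == 'east':
        return 'east'
    return 'west'
def id_chooser(ident):
    return ['east', 'west']     # a primary key could live in either shard
def query_chooser(query):
    return ['east', 'west']     # issue the query against every shard
sess = ShardedSession(shard_chooser, id_chooser, query_chooser)
sess.bind_shard('east', east_engine)
sess.bind_shard('west', west_engine)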
+class ShardedQuery(Query):
+ def __init__(self, *args, **kwargs):
+ super(ShardedQuery, self).__init__(*args, **kwargs)
+ self.id_chooser = self.session.id_chooser
+ self.query_chooser = self.session.query_chooser
+ self._shard_id = None
+
+ def _clone(self):
+ q = ShardedQuery.__new__(ShardedQuery)
+ q.__dict__ = self.__dict__.copy()
+ return q
+
+ def set_shard(self, shard_id):
+ """return a new query, limited to a single shard ID.
+
+ all subsequent operations with the returned query will
+ be against the single shard regardless of other state.
+ """
+
+ q = self._clone()
+ q._shard_id = shard_id
+ return q
+
+ def _execute_and_instances(self, statement):
+ if self._shard_id is not None:
+ result = self.session.connection(mapper=self.mapper, shard_id=self._shard_id).execute(statement, **self._params)
+ try:
+ return iter(self.instances(result))
+ finally:
+ result.close()
+ else:
+ partial = []
+ for shard_id in self.query_chooser(self):
+ result = self.session.connection(mapper=self.mapper, shard_id=shard_id).execute(statement, **self._params)
+ try:
+ partial = partial + list(self.instances(result))
+ finally:
+ result.close()
+ # if some kind of in-memory 'sorting' were done, this is where it would happen
+ return iter(partial)
+
+ def get(self, ident, **kwargs):
+ if self._shard_id is not None:
+ return super(ShardedQuery, self).get(ident)
+ else:
+ for shard_id in self.id_chooser(ident):
+ o = self.set_shard(shard_id).get(ident, **kwargs)
+ if o is not None:
+ return o
+ else:
+ return None
+
+ def load(self, ident, **kwargs):
+ if self._shard_id is not None:
+ return super(ShardedQuery, self).load(ident)
+ else:
+ for shard_id in self.id_chooser(ident):
+ o = self.set_shard(shard_id).load(ident, raiseerr=False, **kwargs)
+ if o is not None:
+ return o
+ else:
+ raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident))
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 462954f6b..babd6e4c0 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -6,12 +6,11 @@
"""sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions."""
-from sqlalchemy import sql, schema, util, exceptions, sql_util, logging
-from sqlalchemy.orm import mapper, query
-from sqlalchemy.orm.interfaces import *
+from sqlalchemy import sql, util, exceptions, sql_util, logging
+from sqlalchemy.orm import mapper, attributes
+from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption
from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
-import random
class ColumnLoader(LoaderStrategy):
@@ -19,8 +18,9 @@ class ColumnLoader(LoaderStrategy):
super(ColumnLoader, self).init()
self.columns = self.parent_property.columns
self._should_log_debug = logging.is_debug_enabled(self.logger)
+ self.is_composite = hasattr(self.parent_property, 'composite_class')
- def setup_query(self, context, eagertable=None, parentclauses=None, **kwargs):
+ def setup_query(self, context, parentclauses=None, **kwargs):
for c in self.columns:
if parentclauses is not None:
context.statement.append_column(parentclauses.aliased_column(c))
@@ -28,16 +28,93 @@ class ColumnLoader(LoaderStrategy):
context.statement.append_column(c)
def init_class_attribute(self):
+ if self.is_composite:
+ self._init_composite_attribute()
+ else:
+ self._init_scalar_attribute()
+
+ def _init_composite_attribute(self):
+ self.logger.info("register managed composite attribute %s on class %s" % (self.key, self.parent.class_.__name__))
+ def copy(obj):
+ return self.parent_property.composite_class(*obj.__colset__())
+ def compare(a, b):
+ for col, aprop, bprop in zip(self.columns, a.__colset__(), b.__colset__()):
+ if not col.type.compare_values(aprop, bprop):
+ return False
+ else:
+ return True
+ sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator)
+
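# Sketch of what a composite_class is expected to look like, given the
# copy/compare functions above; the Point class is hypothetical:
class Point(object):
    def __init__(self, x, y):
        self.x = x
        self.y = y
    def __colset__(self):
        # attribute values, in the same order as the mapped columns
        return [self.x, self.y]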
+ def _init_scalar_attribute(self):
self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__))
coltype = self.columns[0].type
- sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable())
+ sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator)
+
+ def create_row_processor(self, selectcontext, mapper, row):
+ if self.is_composite:
+ for c in self.columns:
+ if c not in row:
+ break
+ else:
+ def execute(instance, row, isnew, ispostselect=None, **flags):
+ if isnew or ispostselect:
+ if self._should_log_debug:
+ self.logger.debug("populating %s with %s/%s..." % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key))
+ instance.__dict__[self.key] = self.parent_property.composite_class(*[row[c] for c in self.columns])
+ self.logger.debug("Returning active composite column fetcher for %s %s" % (mapper, self.key))
+ return (execute, None)
+
+ elif self.columns[0] in row:
+ def execute(instance, row, isnew, ispostselect=None, **flags):
+ if isnew or ispostselect:
+ if self._should_log_debug:
+ self.logger.debug("populating %s with %s/%s" % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key))
+ instance.__dict__[self.key] = row[self.columns[0]]
+ self.logger.debug("Returning active column fetcher for %s %s" % (mapper, self.key))
+ return (execute, None)
+
+ (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', mapper), (None, None))
+ if hosted_mapper is None:
+ return (None, None)
+
+ if hosted_mapper.polymorphic_fetch == 'deferred':
+ def execute(instance, row, isnew, **flags):
+ if isnew:
+ sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self._get_deferred_loader(instance, mapper, needs_tables))
+ self.logger.debug("Returning deferred column fetcher for %s %s" % (mapper, self.key))
+ return (execute, None)
+ else:
+ self.logger.debug("Returning no column fetcher for %s %s" % (mapper, self.key))
+ return (None, None)
+
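# The create_row_processor() contract used above, reduced to a no-op
# sketch: it returns a pair whose first element is a per-row callable
# (or None); the second element is None in the strategies shown here.
def create_row_processor(self, selectcontext, mapper, row):
    def execute(instance, row, isnew, **flags):
        pass    # populate `instance` from `row` here
    return (execute, None)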
+ def _get_deferred_loader(self, instance, mapper, needs_tables):
+ def load():
+ group = [p for p in mapper.iterate_properties if isinstance(p.strategy, ColumnLoader) and p.columns[0].table in needs_tables]
- def process_row(self, selectcontext, instance, row, identitykey, isnew):
- if isnew:
if self._should_log_debug:
- self.logger.debug("populating %s with %s/%s" % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key))
- instance.__dict__[self.key] = row[self.columns[0]]
-
+ self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), group and ','.join([p.key for p in group]) or 'None'))
+
+ session = sessionlib.object_session(instance)
+ if session is None:
+ raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
+
+ cond, param_names = mapper._deferred_inheritance_condition(needs_tables)
+ statement = sql.select(needs_tables, cond, use_labels=True)
+ params = {}
+ for c in param_names:
+ params[c.name] = mapper.get_attr_by_column(instance, c)
+
+ result = session.execute(statement, params, mapper=mapper)
+ try:
+ row = result.fetchone()
+ for prop in group:
+ sessionlib.attribute_manager.get_attribute(instance, prop.key).set_committed_value(instance, row[prop.columns[0]])
+ return attributes.ATTR_WAS_SET
+ finally:
+ result.close()
+
+ return load
+
ColumnLoader.logger = logging.class_logger(ColumnLoader)
class DeferredColumnLoader(LoaderStrategy):
@@ -47,74 +124,86 @@ class DeferredColumnLoader(LoaderStrategy):
This is per-column lazy loading.
"""
+ def create_row_processor(self, selectcontext, mapper, row):
+ if self.group is not None and selectcontext.attributes.get(('undefer', self.group), False):
+ return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, mapper, row)
+ elif not self.is_default or len(selectcontext.options):
+ def execute(instance, row, isnew, **flags):
+ if isnew:
+ if self._should_log_debug:
+ self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key))
+ sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self.setup_loader(instance))
+ return (execute, None)
+ else:
+ def execute(instance, row, isnew, **flags):
+ if isnew:
+ if self._should_log_debug:
+ self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key))
+ sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
+ return (execute, None)
+
def init(self):
super(DeferredColumnLoader, self).init()
+ if hasattr(self.parent_property, 'composite_class'):
+ raise NotImplementedError("Deferred loading for composite types not implemented yet")
self.columns = self.parent_property.columns
self.group = self.parent_property.group
self._should_log_debug = logging.is_debug_enabled(self.logger)
def init_class_attribute(self):
self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__))
- sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, callable_=lambda i:self.setup_loader(i), copy_function=lambda x: self.columns[0].type.copy_value(x), compare_function=lambda x,y:self.columns[0].type.compare_values(x,y), mutable_scalars=self.columns[0].type.is_mutable())
+ sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, callable_=self.setup_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator)
def setup_query(self, context, **kwargs):
- pass
+ if self.group is not None and context.attributes.get(('undefer', self.group), False):
+ self.parent_property._get_strategy(ColumnLoader).setup_query(context, **kwargs)
- def process_row(self, selectcontext, instance, row, identitykey, isnew):
- if isnew:
- if not self.is_default or len(selectcontext.options):
- sessionlib.attribute_manager.init_instance_attribute(instance, self.key, False, callable_=self.setup_loader(instance, selectcontext.options))
- else:
- sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
-
- def setup_loader(self, instance, options=None):
- if not mapper.has_mapper(instance):
+ def setup_loader(self, instance):
+ localparent = mapper.object_mapper(instance, raiseerror=False)
+ if localparent is None:
return None
- else:
- prop = mapper.object_mapper(instance).props[self.key]
- if prop is not self.parent_property:
- return prop._get_strategy(DeferredColumnLoader).setup_loader(instance)
- def lazyload():
- if self._should_log_debug:
- self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), str(self.group)))
+
+ prop = localparent.get_property(self.key)
+ if prop is not self.parent_property:
+ return prop._get_strategy(DeferredColumnLoader).setup_loader(instance)
+ def lazyload():
if not mapper.has_identity(instance):
return None
- try:
- pk = self.parent.pks_by_table[self.columns[0].table]
- except KeyError:
- pk = self.columns[0].table.primary_key
-
- clause = sql.and_()
- for primary_key in pk:
- attr = self.parent.get_attr_by_column(instance, primary_key)
- if not attr:
- return None
- clause.clauses.append(primary_key == attr)
+ if self.group is not None:
+ group = [p for p in localparent.iterate_properties if isinstance(p.strategy, DeferredColumnLoader) and p.group==self.group]
+ else:
+ group = None
+
+ if self._should_log_debug:
+ self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), group and ','.join([p.key for p in group]) or 'None'))
session = sessionlib.object_session(instance)
if session is None:
raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
-
- localparent = mapper.object_mapper(instance)
- if self.group is not None:
- groupcols = [p for p in localparent.props.values() if isinstance(p.strategy, DeferredColumnLoader) and p.group==self.group]
- result = session.execute(localparent, sql.select([g.columns[0] for g in groupcols], clause, use_labels=True), None)
+
+ clause = localparent._get_clause
+ ident = instance._instance_key[1]
+ params = {}
+ for i, primary_key in enumerate(localparent.primary_key):
+ params[primary_key._label] = ident[i]
+ if group is not None:
+ statement = sql.select([p.columns[0] for p in group], clause, from_obj=[localparent.mapped_table], use_labels=True)
+ else:
+ statement = sql.select([self.columns[0]], clause, from_obj=[localparent.mapped_table], use_labels=True)
+
+ if group is not None:
+ result = session.execute(statement, params, mapper=localparent)
try:
row = result.fetchone()
- for prop in groupcols:
- if prop is self:
- continue
- # set a scalar object instance directly on the object,
- # bypassing SmartProperty event handlers.
- sessionlib.attribute_manager.init_instance_attribute(instance, prop.key, uselist=False)
- instance.__dict__[prop.key] = row[prop.columns[0]]
- return row[self.columns[0]]
+ for prop in group:
+ sessionlib.attribute_manager.get_attribute(instance, prop.key).set_committed_value(instance, row[prop.columns[0]])
+ return attributes.ATTR_WAS_SET
finally:
result.close()
else:
- return session.scalar(localparent, sql.select([self.columns[0]], clause, use_labels=True),None)
+ return session.scalar(sql.select([self.columns[0]], clause, from_obj=[localparent.mapped_table], use_labels=True),params, mapper=localparent)
return lazyload
@@ -131,6 +220,15 @@ class DeferredOption(StrategizedOption):
else:
return ColumnLoader
+class UndeferGroupOption(MapperOption):
+ def __init__(self, group):
+ self.group = group
+ def process_query_context(self, context):
+ context.attributes[('undefer', self.group)] = True
+
+ def process_selection_context(self, context):
+ context.attributes[('undefer', self.group)] = True
+
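# Usage sketch for the option above; the Book class and the 'photos'
# deferred group are hypothetical:
session.query(Book).options(UndeferGroupOption('photos')).all()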
class AbstractRelationLoader(LoaderStrategy):
def init(self):
super(AbstractRelationLoader, self).init()
@@ -139,22 +237,26 @@ class AbstractRelationLoader(LoaderStrategy):
self._should_log_debug = logging.is_debug_enabled(self.logger)
def _init_instance_attribute(self, instance, callable_=None):
- return sessionlib.attribute_manager.init_instance_attribute(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True, callable_=callable_)
+ return sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=callable_)
def _register_attribute(self, class_, callable_=None):
self.logger.info("register managed %s attribute %s on class %s" % ((self.uselist and "list-holding" or "scalar"), self.key, self.parent.class_.__name__))
- sessionlib.attribute_manager.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_)
+ sessionlib.attribute_manager.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator)
class NoLoader(AbstractRelationLoader):
def init_class_attribute(self):
self._register_attribute(self.parent.class_)
- def process_row(self, selectcontext, instance, row, identitykey, isnew):
- if isnew:
- if not self.is_default or len(selectcontext.options):
- if self._should_log_debug:
- self.logger.debug("set instance-level no loader on %s" % mapperutil.attribute_str(instance, self.key))
- self._init_instance_attribute(instance)
+ def create_row_processor(self, selectcontext, mapper, row):
+ if not self.is_default or len(selectcontext.options):
+ def execute(instance, row, isnew, **flags):
+ if isnew:
+ if self._should_log_debug:
+ self.logger.debug("set instance-level no loader on %s" % mapperutil.attribute_str(instance, self.key))
+ self._init_instance_attribute(instance)
+ return (execute, None)
+ else:
+ return (None, None)
NoLoader.logger = logging.class_logger(NoLoader)
@@ -167,7 +269,8 @@ class LazyLoader(AbstractRelationLoader):
# determine if our "lazywhere" clause is the same as the mapper's
# get() clause. then we can just use mapper.get()
- self.use_get = not self.uselist and query.Query(self.mapper)._get_clause.compare(self.lazywhere)
+ #from sqlalchemy.orm import query
+ self.use_get = not self.uselist and self.mapper._get_clause.compare(self.lazywhere)
if self.use_get:
self.logger.info(str(self.parent_property) + " will use query.get() to optimize instance loads")
@@ -178,7 +281,7 @@ class LazyLoader(AbstractRelationLoader):
if not mapper.has_mapper(instance):
return None
else:
- prop = mapper.object_mapper(instance).props[self.key]
+ prop = mapper.object_mapper(instance).get_property(self.key)
if prop is not self.parent_property:
return prop._get_strategy(LazyLoader).setup_loader(instance)
def lazyload():
@@ -211,20 +314,27 @@ class LazyLoader(AbstractRelationLoader):
# if we have a simple straight-primary key load, use mapper.get()
# to possibly save a DB round trip
+ q = session.query(self.mapper)
if self.use_get:
ident = []
- for primary_key in self.select_mapper.pks_by_table[self.select_mapper.mapped_table]:
+ # TODO: when options are added to allow switching between union-based and non-union
+ # based polymorphic loads on a per-query basis, this code needs to switch between "mapper" and "select_mapper",
+ # probably via the query's own "mapper" property, and also use one of two "lazy" clauses,
+ # one against the "union" the other not
+ for primary_key in self.select_mapper.primary_key:
bind = self.lazyreverse[primary_key]
ident.append(params[bind.key])
- return session.query(self.mapper).get(ident)
+ return q.get(ident)
elif self.order_by is not False:
- order_by = self.order_by
+ q = q.order_by(self.order_by)
elif self.secondary is not None and self.secondary.default_order_by() is not None:
- order_by = self.secondary.default_order_by()
- else:
- order_by = False
- result = session.query(self.mapper, with_options=options).select_whereclause(self.lazywhere, order_by=order_by, params=params)
+ q = q.order_by(self.secondary.default_order_by())
+ if options:
+ q = q.options(*options)
+ q = q.filter(self.lazywhere).params(**params)
+
+ result = q.all()
if self.uselist:
return result
else:
@@ -232,25 +342,37 @@ class LazyLoader(AbstractRelationLoader):
return result[0]
else:
return None
+
+ if self.uselist:
+ return q.all()
+ else:
+ return q.first()
+
return lazyload
- def process_row(self, selectcontext, instance, row, identitykey, isnew):
- if isnew:
- # new object instance being loaded from a result row
- if not self.is_default or len(selectcontext.options):
- self.logger.debug("set instance-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
- # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader,
- # which will override the clareset_instance_attributess-level behavior
- self._init_instance_attribute(instance, callable_=self.setup_loader(instance, selectcontext.options))
- else:
- self.logger.debug("set class-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
- # we are the primary manager for this attribute on this class - reset its per-instance attribute state,
- # so that the class-level lazy loader is executed when next referenced on this instance.
- # this usually is not needed unless the constructor of the object referenced the attribute before we got
- # to load data into it.
- sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
-
- def _create_lazy_clause(cls, prop, reverse_direction=False):
+ def create_row_processor(self, selectcontext, mapper, row):
+ if not self.is_default or len(selectcontext.options):
+ def execute(instance, row, isnew, **flags):
+ if isnew:
+ if self._should_log_debug:
+ self.logger.debug("set instance-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
+ # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader,
+ # which will override the class-level behavior
+ self._init_instance_attribute(instance, callable_=self.setup_loader(instance, selectcontext.options))
+ return (execute, None)
+ else:
+ def execute(instance, row, isnew, **flags):
+ if isnew:
+ if self._should_log_debug:
+ self.logger.debug("set class-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
+ # we are the primary manager for this attribute on this class - reset its per-instance attribute state,
+ # so that the class-level lazy loader is executed when next referenced on this instance.
+ # this usually is not needed unless the constructor of the object referenced the attribute before we got
+ # to load data into it.
+ sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
+ return (execute, None)
+
+ def _create_lazy_clause(cls, prop, reverse_direction=False, op='=='):
(primaryjoin, secondaryjoin, remote_side) = (prop.polymorphic_primaryjoin, prop.polymorphic_secondaryjoin, prop.remote_side)
binds = {}
@@ -272,19 +394,16 @@ class LazyLoader(AbstractRelationLoader):
FindColumnInColumnClause().traverse(expr)
return len(columns) and columns[0] or None
- def bind_label():
- # TODO: make this generation deterministic
- return "lazy_" + hex(random.randint(0, 65535))[2:]
-
def visit_binary(binary):
leftcol = find_column_in_expr(binary.left)
rightcol = find_column_in_expr(binary.right)
if leftcol is None or rightcol is None:
return
+
if should_bind(leftcol, rightcol):
col = leftcol
binary.left = binds.setdefault(leftcol,
- sql.bindparam(bind_label(), None, shortname=leftcol.name, type=binary.right.type, unique=True))
+ sql.bindparam(None, None, shortname=leftcol.name, type_=binary.right.type, unique=True))
reverse[rightcol] = binds[col]
# the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1",
@@ -292,21 +411,19 @@ class LazyLoader(AbstractRelationLoader):
if leftcol is not rightcol and should_bind(rightcol, leftcol):
col = rightcol
binary.right = binds.setdefault(rightcol,
- sql.bindparam(bind_label(), None, shortname=rightcol.name, type=binary.left.type, unique=True))
+ sql.bindparam(None, None, shortname=rightcol.name, type_=binary.left.type, unique=True))
reverse[leftcol] = binds[col]
- lazywhere = primaryjoin.copy_container()
+ lazywhere = primaryjoin
li = mapperutil.BinaryVisitor(visit_binary)
if not secondaryjoin or not reverse_direction:
- li.traverse(lazywhere)
+ lazywhere = li.traverse(lazywhere, clone=True)
if secondaryjoin is not None:
- secondaryjoin = secondaryjoin.copy_container()
if reverse_direction:
- li.traverse(secondaryjoin)
+ secondaryjoin = li.traverse(secondaryjoin, clone=True)
lazywhere = sql.and_(lazywhere, secondaryjoin)
-
return (lazywhere, binds, reverse)
_create_lazy_clause = classmethod(_create_lazy_clause)
@@ -318,154 +435,42 @@ class EagerLoader(AbstractRelationLoader):
def init(self):
super(EagerLoader, self).init()
- if self.parent.isa(self.mapper):
- raise exceptions.ArgumentError(
- "Error creating eager relationship '%s' on parent class '%s' "
- "to child class '%s': Cant use eager loading on a self "
- "referential relationship." %
- (self.key, repr(self.parent.class_), repr(self.mapper.class_)))
if self.is_default:
self.parent._eager_loaders.add(self.parent_property)
self.clauses = {}
- self.clauses_by_lead_mapper = {}
-
- class AliasedClauses(object):
- """Defines a set of join conditions and table aliases which
- are aliased on a randomly-generated alias name, corresponding
- to the connection of an optional parent AliasedClauses object
- and a target mapper.
-
- EagerLoader has a distinct AliasedClauses object per parent
- AliasedClauses object, so that all paths from one mapper to
- another across a chain of eagerloaders generates a distinct
- chain of joins. The AliasedClauses objects are generated and
- cached on an as-needed basis.
-
- E.g.::
-
- mapper A -->
- (EagerLoader 'items') -->
- mapper B -->
- (EagerLoader 'keywords') -->
- mapper C
-
- will generate::
-
- EagerLoader 'items' --> {
- None : AliasedClauses(items, None, alias_suffix='AB34') # mappera JOIN mapperb_AB34
- }
-
- EagerLoader 'keywords' --> [
- None : AliasedClauses(keywords, None, alias_suffix='43EF') # mapperb JOIN mapperc_43EF
- AliasedClauses(items, None, alias_suffix='AB34') :
- AliasedClauses(keywords, items, alias_suffix='8F44') # mapperb_AB34 JOIN mapperc_8F44
- ]
- """
-
- def __init__(self, eagerloader, parentclauses=None):
- self.id = (parentclauses is not None and (parentclauses.id + "/") or '') + str(eagerloader.parent_property)
- self.parent = eagerloader
- self.target = eagerloader.select_table
- self.eagertarget = eagerloader.select_table.alias(self._aliashash("/target"))
- self.extra_cols = {}
-
- if eagerloader.secondary:
- self.eagersecondary = eagerloader.secondary.alias(self._aliashash("/secondary"))
- if parentclauses is not None:
- aliasizer = sql_util.ClauseAdapter(self.eagertarget).\
- chain(sql_util.ClauseAdapter(self.eagersecondary)).\
- chain(sql_util.ClauseAdapter(parentclauses.eagertarget))
- else:
- aliasizer = sql_util.ClauseAdapter(self.eagertarget).\
- chain(sql_util.ClauseAdapter(self.eagersecondary))
- self.eagersecondaryjoin = eagerloader.polymorphic_secondaryjoin.copy_container()
- aliasizer.traverse(self.eagersecondaryjoin)
- self.eagerprimary = eagerloader.polymorphic_primaryjoin.copy_container()
- aliasizer.traverse(self.eagerprimary)
- else:
- self.eagerprimary = eagerloader.polymorphic_primaryjoin.copy_container()
- if parentclauses is not None:
- aliasizer = sql_util.ClauseAdapter(self.eagertarget)
- aliasizer.chain(sql_util.ClauseAdapter(parentclauses.eagertarget, exclude=eagerloader.parent_property.remote_side))
- else:
- aliasizer = sql_util.ClauseAdapter(self.eagertarget)
- aliasizer.traverse(self.eagerprimary)
-
- if eagerloader.order_by:
- self.eager_order_by = sql_util.ClauseAdapter(self.eagertarget).copy_and_process(util.to_list(eagerloader.order_by))
- else:
- self.eager_order_by = None
-
- self._row_decorator = self._create_decorator_row()
-
- def aliased_column(self, column):
- """return the aliased version of the given column, creating a new label for it if not already
- present in this AliasedClauses eagertable."""
-
- conv = self.eagertarget.corresponding_column(column, raiseerr=False)
- if conv:
- return conv
-
- if column in self.extra_cols:
- return self.extra_cols[column]
-
- aliased_column = column.copy_container()
- sql_util.ClauseAdapter(self.eagertarget).traverse(aliased_column)
- alias = self._aliashash(column.name)
- aliased_column = aliased_column.label(alias)
- self._row_decorator.map[column] = alias
- self.extra_cols[column] = aliased_column
- return aliased_column
-
- def _aliashash(self, extra):
- """return a deterministic 4 digit hash value for this AliasedClause's id + extra."""
- # use the first 4 digits of an MD5 hash
- return "anon_" + util.hash(self.id + extra)[0:4]
-
- def _create_decorator_row(self):
- class EagerRowAdapter(object):
- def __init__(self, row):
- self.row = row
- def has_key(self, key):
- return map.has_key(key) or self.row.has_key(key)
- def __getitem__(self, key):
- if map.has_key(key):
- key = map[key]
- return self.row[key]
- def keys(self):
- return map.keys()
- map = {}
- for c in self.eagertarget.c:
- parent = self.target.corresponding_column(c)
- map[parent] = c
- map[parent._label] = c
- map[parent.name] = c
- EagerRowAdapter.map = map
- return EagerRowAdapter
-
- def _decorate_row(self, row):
- # adapts a row at row iteration time to transparently
- # convert plain columns into the aliased columns that were actually
- # added to the column clause of the SELECT.
- return self._row_decorator(row)
+ self.join_depth = self.parent_property.join_depth
def init_class_attribute(self):
self.parent_property._get_strategy(LazyLoader).init_class_attribute()
- def setup_query(self, context, eagertable=None, parentclauses=None, parentmapper=None, **kwargs):
+ def setup_query(self, context, parentclauses=None, parentmapper=None, **kwargs):
"""Add a left outer join to the statement thats being constructed."""
+ # build a path as we setup the query. the format of this path
+ # matches that of interfaces.LoaderStack, and will be used in the
+ # row-loading phase to match up AliasedClause objects with the current
+ # LoaderStack position.
+ if parentclauses:
+ path = parentclauses.path + (self.parent.base_mapper(), self.key)
+ else:
+ path = (self.parent.base_mapper(), self.key)
+
+
+ if self.join_depth:
+ if len(path) / 2 > self.join_depth:
+ return
+ else:
+ if self.mapper in path:
+ return
+
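# the join_depth check above is what newly permits self-referential eager
# loads; a hypothetical mapping that exercises it:
#     mapper(Node, nodes_table, properties={
#         'children':relation(Node, lazy=False, join_depth=2)
#     })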
+ #print "CREATING EAGER PATH FOR", "->".join([str(s) for s in path])
+
if parentmapper is None:
localparent = context.mapper
else:
localparent = parentmapper
- if self.mapper in context.recursion_stack:
- return
- else:
- context.recursion_stack.add(self.parent)
-
statement = context.statement
if hasattr(statement, '_outerjoin'):
@@ -487,55 +492,57 @@ class EagerLoader(AbstractRelationLoader):
break
else:
raise exceptions.InvalidRequestError("EagerLoader cannot locate a clause with which to outer join to, in query '%s' %s" % (str(statement), localparent.mapped_table))
-
+
try:
- clauses = self.clauses[parentclauses]
+ clauses = self.clauses[path]
except KeyError:
- clauses = EagerLoader.AliasedClauses(self, parentclauses)
- self.clauses[parentclauses] = clauses
-
- if context.mapper not in self.clauses_by_lead_mapper:
- self.clauses_by_lead_mapper[context.mapper] = clauses
-
+ clauses = mapperutil.PropertyAliasedClauses(self.parent_property, self.parent_property.polymorphic_primaryjoin, self.parent_property.polymorphic_secondaryjoin, parentclauses)
+ self.clauses[path] = clauses
+
if self.secondaryjoin is not None:
- statement._outerjoin = sql.outerjoin(towrap, clauses.eagersecondary, clauses.eagerprimary).outerjoin(clauses.eagertarget, clauses.eagersecondaryjoin)
+ statement._outerjoin = sql.outerjoin(towrap, clauses.secondary, clauses.primaryjoin).outerjoin(clauses.alias, clauses.secondaryjoin)
if self.order_by is False and self.secondary.default_order_by() is not None:
- statement.order_by(*clauses.eagersecondary.default_order_by())
+ statement.append_order_by(*clauses.secondary.default_order_by())
else:
- statement._outerjoin = towrap.outerjoin(clauses.eagertarget, clauses.eagerprimary)
- if self.order_by is False and clauses.eagertarget.default_order_by() is not None:
- statement.order_by(*clauses.eagertarget.default_order_by())
+ statement._outerjoin = towrap.outerjoin(clauses.alias, clauses.primaryjoin)
+ if self.order_by is False and clauses.alias.default_order_by() is not None:
+ statement.append_order_by(*clauses.alias.default_order_by())
- if clauses.eager_order_by:
- statement.order_by(*util.to_list(clauses.eager_order_by))
-
+ if clauses.order_by:
+ statement.append_order_by(*util.to_list(clauses.order_by))
+
statement.append_from(statement._outerjoin)
- for value in self.select_mapper.props.values():
- value.setup(context, eagertable=clauses.eagertarget, parentclauses=clauses, parentmapper=self.select_mapper)
- def _create_row_processor(self, selectcontext, row):
- """Create a *row processing* function that will apply eager
+ for value in self.select_mapper.iterate_properties:
+ value.setup(context, parentclauses=clauses, parentmapper=self.select_mapper)
+
+ def _create_row_decorator(self, selectcontext, row, path):
+ """Create a *row decorating* function that will apply eager
aliasing to the row.
Also check that an identity key can be retrieved from the row,
else return None.
"""
+ #print "creating row decorator for path ", "->".join([str(s) for s in path])
+
# check for a user-defined decorator in the SelectContext (which was set up by the contains_eager() option)
- if selectcontext.attributes.has_key((EagerLoader, self.parent_property)):
+ if selectcontext.attributes.has_key(("eager_row_processor", self.parent_property)):
# custom row decoration function, placed in the selectcontext by the
# contains_eager() mapper option
- decorator = selectcontext.attributes[(EagerLoader, self.parent_property)]
+ decorator = selectcontext.attributes[("eager_row_processor", self.parent_property)]
if decorator is None:
decorator = lambda row: row
else:
try:
# decorate the row according to the stored AliasedClauses for this eager load
- clauses = self.clauses_by_lead_mapper[selectcontext.mapper]
- decorator = clauses._row_decorator
+ clauses = self.clauses[path]
+ decorator = clauses.row_decorator
except KeyError, k:
# no stored AliasedClauses: eager loading was not set up in the query and
# AliasedClauses never got initialized
+ if self._should_log_debug:
+ self.logger.debug("Could not locate aliased clauses for key: " + str(path))
return None
try:
@@ -550,81 +557,80 @@ class EagerLoader(AbstractRelationLoader):
self.logger.debug("could not locate identity key from row '%s'; missing column '%s'" % (repr(decorated_row), str(k)))
return None
- def process_row(self, selectcontext, instance, row, identitykey, isnew):
- """Receive a row.
+ def create_row_processor(self, selectcontext, mapper, row):
+ selectcontext.stack.push_property(self.key)
+ path = selectcontext.stack.snapshot()
- Tell our mapper to look for a new object instance in the row,
- and attach it to a list on the parent instance.
- """
-
- if self in selectcontext.recursion_stack:
- return
-
- try:
- # check for row processor
- row_processor = selectcontext.attributes[id(self)]
- except KeyError:
- # create a row processor function and cache it in the context
- row_processor = self._create_row_processor(selectcontext, row)
- selectcontext.attributes[id(self)] = row_processor
-
- if row_processor is not None:
- decorated_row = row_processor(row)
- else:
- # row_processor was None: degrade to a lazy loader
- if self._should_log_debug:
- self.logger.debug("degrade to lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
- self.parent_property._get_strategy(LazyLoader).process_row(selectcontext, instance, row, identitykey, isnew)
- return
-
- # TODO: recursion check a speed hit...? try to get a "termination point" into the AliasedClauses
- # or EagerRowAdapter ?
- selectcontext.recursion_stack.add(self)
- try:
- if not self.uselist:
- if self._should_log_debug:
- self.logger.debug("eagerload scalar instance on %s" % mapperutil.attribute_str(instance, self.key))
- if isnew:
- # set a scalar object instance directly on the parent object,
- # bypassing SmartProperty event handlers.
- instance.__dict__[self.key] = self.mapper._instance(selectcontext, decorated_row, None)
+ row_decorator = self._create_row_decorator(selectcontext, row, path)
+ if row_decorator is not None:
+ def execute(instance, row, isnew, **flags):
+ decorated_row = row_decorator(row)
+
+ selectcontext.stack.push_property(self.key)
+
+ if not self.uselist:
+ if self._should_log_debug:
+ self.logger.debug("eagerload scalar instance on %s" % mapperutil.attribute_str(instance, self.key))
+ if isnew:
+ # set a scalar object instance directly on the
+ # parent object, bypassing InstrumentedAttribute
+ # event handlers.
+ #
+ # FIXME: instead of...
+ sessionlib.attribute_manager.get_attribute(instance, self.key).set_raw_value(instance, self.mapper._instance(selectcontext, decorated_row, None))
+ # bypass and set directly:
+ #instance.__dict__[self.key] = ...
+ else:
+ # call _instance on the row, even though the object has been created,
+ # so that we further descend into properties
+ self.mapper._instance(selectcontext, decorated_row, None)
else:
- # call _instance on the row, even though the object has been created,
- # so that we further descend into properties
- self.mapper._instance(selectcontext, decorated_row, None)
- else:
- if isnew:
+ if isnew:
+ if self._should_log_debug:
+ self.logger.debug("initialize UniqueAppender on %s" % mapperutil.attribute_str(instance, self.key))
+
+ collection = sessionlib.attribute_manager.init_collection(instance, self.key)
+ appender = util.UniqueAppender(collection, 'append_without_event')
+
+ # store it in the "scratch" area, which is local to this load operation.
+ selectcontext.attributes[(instance, self.key)] = appender
+ result_list = selectcontext.attributes[(instance, self.key)]
if self._should_log_debug:
- self.logger.debug("initialize UniqueAppender on %s" % mapperutil.attribute_str(instance, self.key))
- # call the SmartProperty's initialize() method to create a new, blank list
- l = getattr(instance.__class__, self.key).initialize(instance)
-
- # create an appender object which will add set-like semantics to the list
- appender = util.UniqueAppender(l.data)
-
- # store it in the "scratch" area, which is local to this load operation.
- selectcontext.attributes[(instance, self.key)] = appender
- result_list = selectcontext.attributes[(instance, self.key)]
- if self._should_log_debug:
- self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key))
- self.select_mapper._instance(selectcontext, decorated_row, result_list)
- finally:
- selectcontext.recursion_stack.remove(self)
+ self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key))
+
+ self.select_mapper._instance(selectcontext, decorated_row, result_list)
+ selectcontext.stack.pop()
+ selectcontext.stack.pop()
+ return (execute, None)
+ else:
+ self.logger.debug("eager loader %s degrading to lazy loader" % str(self))
+ selectcontext.stack.pop()
+ return self.parent_property._get_strategy(LazyLoader).create_row_processor(selectcontext, mapper, row)
+
+
+ def __str__(self):
+ return str(self.parent) + "." + self.key
+
EagerLoader.logger = logging.class_logger(EagerLoader)
class EagerLazyOption(StrategizedOption):
- def __init__(self, key, lazy=True):
+ def __init__(self, key, lazy=True, chained=False):
super(EagerLazyOption, self).__init__(key)
self.lazy = lazy
-
- def process_query_property(self, context, prop):
+ self.chained = chained
+
+ def is_chained(self):
+ return not self.lazy and self.chained
+
+ def process_query_property(self, context, properties):
if self.lazy:
- if prop in context.eager_loaders:
- context.eager_loaders.remove(prop)
+ if properties[-1] in context.eager_loaders:
+ context.eager_loaders.remove(properties[-1])
else:
- context.eager_loaders.add(prop)
- super(EagerLazyOption, self).process_query_property(context, prop)
+ for prop in properties:
+ context.eager_loaders.add(prop)
+ super(EagerLazyOption, self).process_query_property(context, properties)
def get_strategy_class(self):
if self.lazy:
@@ -636,24 +642,39 @@ class EagerLazyOption(StrategizedOption):
EagerLazyOption.logger = logging.class_logger(EagerLazyOption)
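
A minimal sketch of how the new ``chained`` flag is meant to be used (the mapping, session, and dotted relation path are hypothetical; the dotted key is assumed to resolve to the list of properties along the chain):

    # hedged sketch: lazy=False plus chained=True applies the eager
    # strategy to every relation along the path, not just the last one
    opt = EagerLazyOption('orders.items', lazy=False, chained=True)
    users = session.query(User).options(opt).all()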
+# TODO: enable FetchMode option. currently
+# this class does nothing. will require Query
+# to switch between using its "polymorphic" selectable
+# and its regular selectable in order to make decisions
+# (therefore might require that FetchModeOption is performed
+# only as the first operation on a Query.)
+class FetchModeOption(PropertyOption):
+ def __init__(self, key, type):
+ super(FetchModeOption, self).__init__(key)
+ if type not in ('join', 'select'):
+ raise exceptions.ArgumentError("Fetchmode must be one of 'join' or 'select'")
+ self.type = type
+
+ def process_selection_property(self, context, properties):
+ context.attributes[('fetchmode', properties[-1])] = self.type
+
class RowDecorateOption(PropertyOption):
def __init__(self, key, decorator=None, alias=None):
super(RowDecorateOption, self).__init__(key)
self.decorator = decorator
self.alias = alias
- def process_selection_property(self, context, property):
+ def process_selection_property(self, context, properties):
if self.alias is not None and self.decorator is None:
if isinstance(self.alias, basestring):
- self.alias = property.target.alias(self.alias)
+ self.alias = properties[-1].target.alias(self.alias)
def decorate(row):
d = {}
- for c in property.target.columns:
+ for c in properties[-1].target.columns:
d[c] = row[self.alias.corresponding_column(c)]
return d
self.decorator = decorate
- context.attributes[(EagerLoader, property)] = self.decorator
+ context.attributes[("eager_row_processor", properties[-1])] = self.decorator
RowDecorateOption.logger = logging.class_logger(RowDecorateOption)
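
A minimal sketch of the pattern this option backs, assuming a hypothetical User/addresses mapping; per the comments above, this row decorator is what contains_eager() places in the select context to steer pre-joined rows into the eager loader:

    adalias = addresses_table.alias('adalias')
    stmt = users_table.outerjoin(adalias).select(use_labels=True)
    # decorate rows so the eager loader reads columns from the alias
    opt = RowDecorateOption('addresses', alias=adalias)
    users = session.query(User).options(opt).instances(stmt.execute())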
-
diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py
index 8c70f8cf8..cf48202b0 100644
--- a/lib/sqlalchemy/orm/sync.py
+++ b/lib/sqlalchemy/orm/sync.py
@@ -4,17 +4,16 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-
-
-from sqlalchemy import sql, schema, exceptions
-from sqlalchemy import logging
-from sqlalchemy.orm import util as mapperutil
-
"""Contains the ClauseSynchronizer class, which is used to map
attributes between two objects in a manner corresponding to a SQL
clause that compares column values.
"""
+from sqlalchemy import sql, schema, exceptions
+from sqlalchemy import logging
+from sqlalchemy.orm import util as mapperutil
+import operator
+
ONETOMANY = 0
MANYTOONE = 1
MANYTOMANY = 2
@@ -44,7 +43,7 @@ class ClauseSynchronizer(object):
def compile_binary(binary):
"""Assemble a SyncRule given a single binary condition."""
- if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+ if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
source_column = None
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index c6b0b2689..f59042810 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -19,15 +19,17 @@ new, dirty, or deleted and provides the capability to flush all those
changes at once.
"""
-from sqlalchemy import util, logging, topological
-from sqlalchemy.orm import attributes
+from sqlalchemy import util, logging, topological, exceptions
+from sqlalchemy.orm import attributes, interfaces
from sqlalchemy.orm import util as mapperutil
-from sqlalchemy.orm.mapper import object_mapper, class_mapper
-from sqlalchemy.exceptions import *
+from sqlalchemy.orm.mapper import object_mapper
import StringIO
import weakref
-class UOWEventHandler(attributes.AttributeExtension):
+# Load lazily
+object_session = None
+
+class UOWEventHandler(interfaces.AttributeExtension):
"""An event handler added to all class attributes which handles
session operations.
"""
@@ -37,52 +39,46 @@ class UOWEventHandler(attributes.AttributeExtension):
self.class_ = class_
self.cascade = cascade
- def append(self, event, obj, item):
+ def append(self, obj, item, initiator):
# process "save_update" cascade rules for when an instance is appended to the list of another instance
sess = object_session(obj)
if sess is not None:
if self.cascade is not None and self.cascade.save_update and item not in sess:
mapper = object_mapper(obj)
- prop = mapper.props[self.key]
+ prop = mapper.get_property(self.key)
ename = prop.mapper.entity_name
sess.save_or_update(item, entity_name=ename)
- def delete(self, event, obj, item):
+ def remove(self, obj, item, initiator):
# currently no cascade rules for removing an item from a list
# (i.e. it stays in the Session)
pass
- def set(self, event, obj, newvalue, oldvalue):
+ def set(self, obj, newvalue, oldvalue, initiator):
# process "save_update" cascade rules for when an instance is attached to another instance
sess = object_session(obj)
if sess is not None:
if newvalue is not None and self.cascade is not None and self.cascade.save_update and newvalue not in sess:
mapper = object_mapper(obj)
- prop = mapper.props[self.key]
+ prop = mapper.get_property(self.key)
ename = prop.mapper.entity_name
sess.save_or_update(newvalue, entity_name=ename)
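
A minimal sketch of the behavior the handler provides, assuming a hypothetical User/Address mapping whose relation() carries the default save-update cascade:

    session = create_session()
    user = User()
    session.save(user)                 # user is pending in the session
    # the append event fires UOWEventHandler.append(), which calls
    # session.save_or_update() on the new Address automatically
    user.addresses.append(Address())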
-class UOWProperty(attributes.InstrumentedAttribute):
- """Override ``InstrumentedAttribute`` to provide an extra
- ``AttributeExtension`` to all managed attributes as well as the
- `property` property.
- """
-
- def __init__(self, manager, class_, key, uselist, callable_, typecallable, cascade=None, extension=None, **kwargs):
- extension = util.to_list(extension or [])
- extension.insert(0, UOWEventHandler(key, class_, cascade=cascade))
- super(UOWProperty, self).__init__(manager, key, uselist, callable_, typecallable, extension=extension,**kwargs)
- self.class_ = class_
-
- property = property(lambda s:class_mapper(s.class_).props[s.key], doc="returns the MapperProperty object associated with this property")
class UOWAttributeManager(attributes.AttributeManager):
"""Override ``AttributeManager`` to provide the ``UOWProperty``
instance for all ``InstrumentedAttributes``.
"""
- def create_prop(self, class_, key, uselist, callable_, typecallable, **kwargs):
- return UOWProperty(self, class_, key, uselist, callable_, typecallable, **kwargs)
+ def create_prop(self, class_, key, uselist, callable_, typecallable,
+ cascade=None, extension=None, **kwargs):
+ extension = util.to_list(extension or [])
+ extension.insert(0, UOWEventHandler(key, class_, cascade=cascade))
+
+ return super(UOWAttributeManager, self).create_prop(
+ class_, key, uselist, callable_, typecallable,
+ extension=extension, **kwargs)
+
class UnitOfWork(object):
"""Main UOW object which stores lists of dirty/new/deleted objects.
@@ -122,7 +118,7 @@ class UnitOfWork(object):
def _validate_obj(self, obj):
if (hasattr(obj, '_instance_key') and not self.identity_map.has_key(obj._instance_key)) or \
(not hasattr(obj, '_instance_key') and obj not in self.new):
- raise InvalidRequestError("Instance '%s' is not attached or pending within this session" % repr(obj))
+ raise exceptions.InvalidRequestError("Instance '%s' is not attached or pending within this session" % repr(obj))
def _is_valid(self, obj):
if (hasattr(obj, '_instance_key') and not self.identity_map.has_key(obj._instance_key)) or \
@@ -138,7 +134,7 @@ class UnitOfWork(object):
self.new.remove(obj)
if not hasattr(obj, '_instance_key'):
mapper = object_mapper(obj)
- obj._instance_key = mapper.instance_key(obj)
+ obj._instance_key = mapper.identity_key_from_instance(obj)
if hasattr(obj, '_sa_insert_order'):
delattr(obj, '_sa_insert_order')
self.identity_map[obj._instance_key] = obj
@@ -148,7 +144,7 @@ class UnitOfWork(object):
"""register the given object as 'new' (i.e. unsaved) within this unit of work."""
if hasattr(obj, '_instance_key'):
- raise InvalidRequestError("Object '%s' already has an identity - it can't be registered as new" % repr(obj))
+ raise exceptions.InvalidRequestError("Object '%s' already has an identity - it can't be registered as new" % repr(obj))
if obj not in self.new:
self.new.add(obj)
obj._sa_insert_order = len(self.new)
@@ -204,14 +200,14 @@ class UnitOfWork(object):
for obj in self.deleted.intersection(objset).difference(processed):
flush_context.register_object(obj, isdelete=True)
- trans = session.create_transaction(autoflush=False)
- flush_context.transaction = trans
+ session.create_transaction(autoflush=False)
+ flush_context.transaction = session.transaction
try:
flush_context.execute()
except:
- trans.rollback()
+ session.rollback()
raise
- trans.commit()
+ session.commit()
flush_context.post_exec()
@@ -228,6 +224,7 @@ class UOWTransaction(object):
def __init__(self, uow, session):
self.uow = uow
self.session = session
+ self.mapper_flush_opts = session._mapper_flush_opts
# stores tuples of mapper/dependent mapper pairs,
# representing a partial ordering fed into topological sort
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 3b3b9b7ed..d248c0dd0 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -4,7 +4,8 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import sql, util, exceptions
+from sqlalchemy import sql, util, exceptions, sql_util
+from sqlalchemy.orm.interfaces import MapperExtension, EXT_PASS
all_cascades = util.Set(["delete", "delete-orphan", "all", "merge",
"expunge", "save-update", "refresh-expire", "none"])
@@ -89,8 +90,6 @@ class TranslatingDict(dict):
def __translate_col(self, col):
ourcol = self.selectable.corresponding_column(col, keys_ok=False, raiseerr=False)
-# if col is not ourcol and ourcol is not None:
-# print "TD TRANSLATING ", col, "TO", ourcol
if ourcol is None:
return col
else:
@@ -111,6 +110,56 @@ class TranslatingDict(dict):
def setdefault(self, col, value):
return super(TranslatingDict, self).setdefault(self.__translate_col(col), value)
+class ExtensionCarrier(MapperExtension):
+ def __init__(self, _elements=None):
+ self.__elements = _elements or []
+
+ def copy(self):
+ return ExtensionCarrier(list(self.__elements))
+
+ def __iter__(self):
+ return iter(self.__elements)
+
+ def insert(self, extension):
+ """Insert a MapperExtension at the beginning of this ExtensionCarrier's list."""
+
+ self.__elements.insert(0, extension)
+
+ def append(self, extension):
+ """Append a MapperExtension at the end of this ExtensionCarrier's list."""
+
+ self.__elements.append(extension)
+
+ def _create_do(funcname):
+ def _do(self, *args, **kwargs):
+ for elem in self.__elements:
+ ret = getattr(elem, funcname)(*args, **kwargs)
+ if ret is not EXT_PASS:
+ return ret
+ else:
+ return EXT_PASS
+ return _do
+
+ init_instance = _create_do('init_instance')
+ init_failed = _create_do('init_failed')
+ dispose_class = _create_do('dispose_class')
+ get_session = _create_do('get_session')
+ load = _create_do('load')
+ get = _create_do('get')
+ get_by = _create_do('get_by')
+ select_by = _create_do('select_by')
+ select = _create_do('select')
+ translate_row = _create_do('translate_row')
+ create_instance = _create_do('create_instance')
+ append_result = _create_do('append_result')
+ populate_instance = _create_do('populate_instance')
+ before_insert = _create_do('before_insert')
+ before_update = _create_do('before_update')
+ after_update = _create_do('after_update')
+ after_insert = _create_do('after_insert')
+ before_delete = _create_do('before_delete')
+ after_delete = _create_do('after_delete')
+
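
A minimal usage sketch (the extension subclasses are hypothetical); each generated hook dispatches through the carrier's elements in order until one returns something other than EXT_PASS:

    carrier = ExtensionCarrier()
    carrier.append(AuditExtension())      # consulted second
    carrier.insert(SecurityExtension())   # consulted first
    # runs each element's before_insert() until a non-EXT_PASS return
    carrier.before_insert(mapper, connection, instance)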
class BinaryVisitor(sql.ClauseVisitor):
def __init__(self, func):
self.func = func
@@ -118,6 +167,138 @@ class BinaryVisitor(sql.ClauseVisitor):
def visit_binary(self, binary):
self.func(binary)
+class AliasedClauses(object):
+ """Creates aliases of a mapped tables for usage in ORM queries.
+ """
+
+ def __init__(self, mapped_table, alias=None):
+ if alias:
+ self.alias = alias
+ else:
+ self.alias = mapped_table.alias()
+ self.mapped_table = mapped_table
+ self.extra_cols = {}
+ self.row_decorator = self._create_row_adapter()
+
+ def aliased_column(self, column):
+ """return the aliased version of the given column, creating a new label for it if not already
+ present in this AliasedClauses."""
+
+ conv = self.alias.corresponding_column(column, raiseerr=False)
+ if conv:
+ return conv
+
+ if column in self.extra_cols:
+ return self.extra_cols[column]
+
+ aliased_column = column
+ # for column-level subqueries, swap out its selectable with our
+ # eager version as appropriate, and manually build the
+ # "correlation" list of the subquery.
+ class ModifySubquery(sql.ClauseVisitor):
+ def visit_select(s, select):
+ select._should_correlate = False
+ select.append_correlation(self.alias)
+ aliased_column = sql_util.ClauseAdapter(self.alias).chain(ModifySubquery()).traverse(aliased_column, clone=True)
+ aliased_column = aliased_column.label(None)
+ self.row_decorator.map[column] = aliased_column
+ # TODO: this is a little hacky
+ for attr in ('name', '_label'):
+ if hasattr(column, attr):
+ self.row_decorator.map[getattr(column, attr)] = aliased_column
+ self.extra_cols[column] = aliased_column
+ return aliased_column
+
+ def adapt_clause(self, clause):
+ return self.aliased_column(clause)
+# return sql_util.ClauseAdapter(self.alias).traverse(clause, clone=True)
+
+ def _create_row_adapter(self):
+ """Return a callable which,
+ when passed a RowProxy, will return a new dict-like object
+        that translates Column objects to those of this object's Alias before indexing into the row.
+
+        This allows a regular Table to target columns in a row that was actually generated from an alias
+        of that table, so that the row can be passed to logic which knows nothing about the aliased form
+        of the table.
+ """
+ class AliasedRowAdapter(object):
+ def __init__(self, row):
+ self.row = row
+ def __contains__(self, key):
+ return key in map or key in self.row
+ def has_key(self, key):
+ return key in self
+ def __getitem__(self, key):
+ if key in map:
+ key = map[key]
+ return self.row[key]
+ def keys(self):
+ return map.keys()
+ map = {}
+ for c in self.alias.c:
+ parent = self.mapped_table.corresponding_column(c)
+ map[parent] = c
+ map[parent._label] = c
+ map[parent.name] = c
+ for c in self.extra_cols:
+ map[c] = self.extra_cols[c]
+ # TODO: this is a little hacky
+ for attr in ('name', '_label'):
+ if hasattr(c, attr):
+ map[getattr(c, attr)] = self.extra_cols[c]
+
+ AliasedRowAdapter.map = map
+ return AliasedRowAdapter
+
+
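
A minimal sketch of what the generated decorator accomplishes, assuming a hypothetical ``users_table`` bound to an engine; rows selected from the alias become addressable by the original table's columns:

    ac = AliasedClauses(users_table)
    row = ac.alias.select().execute().fetchone()
    adapted = ac.row_decorator(row)
    # users_table.c.name is translated to the corresponding aliased
    # column before indexing into the row
    name = adapted[users_table.c.name]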
+class PropertyAliasedClauses(AliasedClauses):
+ """extends AliasedClauses to add support for primary/secondary joins on a relation()."""
+
+ def __init__(self, prop, primaryjoin, secondaryjoin, parentclauses=None):
+ super(PropertyAliasedClauses, self).__init__(prop.select_table)
+
+ self.parentclauses = parentclauses
+ if parentclauses is not None:
+ self.path = parentclauses.path + (prop.parent, prop.key)
+ else:
+ self.path = (prop.parent, prop.key)
+
+ self.prop = prop
+
+ if prop.secondary:
+ self.secondary = prop.secondary.alias()
+ if parentclauses is not None:
+ aliasizer = sql_util.ClauseAdapter(self.alias).\
+ chain(sql_util.ClauseAdapter(self.secondary)).\
+ chain(sql_util.ClauseAdapter(parentclauses.alias))
+ else:
+ aliasizer = sql_util.ClauseAdapter(self.alias).\
+ chain(sql_util.ClauseAdapter(self.secondary))
+ self.secondaryjoin = aliasizer.traverse(secondaryjoin, clone=True)
+ self.primaryjoin = aliasizer.traverse(primaryjoin, clone=True)
+ else:
+ if parentclauses is not None:
+ aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side)
+ aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, exclude=prop.remote_side))
+ else:
+ aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side)
+ self.primaryjoin = aliasizer.traverse(primaryjoin, clone=True)
+ self.secondary = None
+ self.secondaryjoin = None
+
+ if prop.order_by:
+ self.order_by = sql_util.ClauseAdapter(self.alias).copy_and_process(util.to_list(prop.order_by))
+ else:
+ self.order_by = None
+
+ mapper = property(lambda self:self.prop.mapper)
+ table = property(lambda self:self.prop.select_table)
+
+ def __str__(self):
+ return "->".join([str(s) for s in self.path])
+
+
def instance_str(instance):
"""Return a string describing an instance."""
diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py
index 8670464a0..f86e14ab1 100644
--- a/lib/sqlalchemy/pool.py
+++ b/lib/sqlalchemy/pool.py
@@ -13,7 +13,7 @@ automatically, based on module type and connect arguments, simply by
calling regular DBAPI connect() methods.
"""
-import weakref, string, time, sys, traceback
+import weakref, time
try:
import cPickle as pickle
except:
@@ -190,6 +190,7 @@ class _ConnectionRecord(object):
def __init__(self, pool):
self.__pool = pool
self.connection = self.__connect()
+ self.properties = {}
def close(self):
if self.connection is not None:
@@ -207,10 +208,12 @@ class _ConnectionRecord(object):
def get_connection(self):
if self.connection is None:
self.connection = self.__connect()
+ self.properties.clear()
elif (self.__pool._recycle > -1 and time.time() - self.starttime > self.__pool._recycle):
self.__pool.log("Connection %s exceeded timeout; recycling" % repr(self.connection))
self.__close()
self.connection = self.__connect()
+ self.properties.clear()
return self.connection
def __close(self):
@@ -257,6 +260,21 @@ class _ConnectionFairy(object):
_logger = property(lambda self: self._pool.logger)
is_valid = property(lambda self:self.connection is not None)
+
+ def _get_properties(self):
+ """A property collection unique to this DBAPI connection."""
+
+ try:
+ return self._connection_record.properties
+ except AttributeError:
+ if self.connection is None:
+ raise exceptions.InvalidRequestError("This connection is closed")
+ try:
+ return self._detatched_properties
+ except AttributeError:
+ self._detatched_properties = value = {}
+ return value
+ properties = property(_get_properties)
def invalidate(self, e=None):
"""Mark this connection as invalidated.
@@ -301,6 +319,8 @@ class _ConnectionFairy(object):
if self._connection_record is not None:
self._connection_record.connection = None
self._pool.do_return_conn(self._connection_record)
+ self._detatched_properties = \
+ self._connection_record.properties.copy()
self._connection_record = None
def close_open_cursors(self):
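
A minimal sketch of the new ``properties`` dictionary, assuming a hypothetical pooled ``engine``; the dictionary lives on the underlying connection record, so it persists across checkouts of the same DBAPI connection and is cleared when the connection is recycled:

    conn = engine.connect()
    conn.connection.properties['prepared'] = True   # via _ConnectionFairy
    conn.close()                                    # record returns to the pool
    conn2 = engine.connect()                        # may reuse the same record
    flag = conn2.connection.properties.get('prepared')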
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index f6ad52adc..3faa3b89c 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -17,17 +17,19 @@ objects as well as the visitor interface, so that the schema package
*plugs in* to the SQL package.
"""
-from sqlalchemy import sql, types, exceptions, util, databases
+from sqlalchemy import sql, types, exceptions,util, databases
import sqlalchemy
-import copy, re, string
+import re, string, inspect
__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', 'ForeignKeyConstraint',
'PrimaryKeyConstraint', 'CheckConstraint', 'UniqueConstraint', 'DefaultGenerator', 'Constraint',
- 'MetaData', 'ThreadLocalMetaData', 'BoundMetaData', 'DynamicMetaData', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault']
+ 'MetaData', 'ThreadLocalMetaData', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault']
class SchemaItem(object):
"""Base class for items that define a database schema."""
+ __metaclass__ = sql._FigureVisitName
+
def _init_items(self, *args):
"""Initialize the list of child items for this SchemaItem."""
@@ -69,15 +71,7 @@ class SchemaItem(object):
m = self._derived_metadata()
return m and m.bind or None
- def get_engine(self):
- """Return the engine or raise an error if no engine.
-
- Deprecated. use the "bind" attribute.
- """
-
- return self._get_engine(raiseerr=True)
-
- def _set_casing_strategy(self, name, kwargs, keyname='case_sensitive'):
+ def _set_casing_strategy(self, kwargs, keyname='case_sensitive'):
"""Set the "case_sensitive" argument sent via keywords to the item's constructor.
For the purposes of Table's 'schema' property, the name of the
@@ -85,7 +79,7 @@ class SchemaItem(object):
"""
setattr(self, '_%s_setting' % keyname, kwargs.pop(keyname, None))
- def _determine_case_sensitive(self, name, keyname='case_sensitive'):
+ def _determine_case_sensitive(self, keyname='case_sensitive'):
"""Determine the `case_sensitive` value for this item.
For the purposes of Table's `schema` property, the name of the
@@ -111,16 +105,22 @@ class SchemaItem(object):
return True
def _get_case_sensitive(self):
+ """late-compile the 'case-sensitive' setting when first accessed.
+
+ typically the SchemaItem will be assembled into its final structure
+ of other SchemaItems at this point, whereby it can attain this setting
+ from its containing SchemaItem if not defined locally.
+ """
+
try:
return self.__case_sensitive
except AttributeError:
- self.__case_sensitive = self._determine_case_sensitive(self.name)
+ self.__case_sensitive = self._determine_case_sensitive()
return self.__case_sensitive
case_sensitive = property(_get_case_sensitive)
- engine = property(lambda s:s._get_engine())
metadata = property(lambda s:s._derived_metadata())
- bind = property(lambda s:s.engine)
+ bind = property(lambda s:s._get_engine())
def _get_table_key(name, schema):
if schema is None:
@@ -128,30 +128,16 @@ def _get_table_key(name, schema):
else:
return schema + "." + name
-class _TableSingleton(type):
+class _TableSingleton(sql._FigureVisitName):
"""A metaclass used by the ``Table`` object to provide singleton behavior."""
def __call__(self, name, metadata, *args, **kwargs):
- if isinstance(metadata, sql.Executor):
- # backwards compatibility - get a BoundSchema associated with the engine
- engine = metadata
- if not hasattr(engine, '_legacy_metadata'):
- engine._legacy_metadata = MetaData(engine)
- metadata = engine._legacy_metadata
- elif metadata is not None and not isinstance(metadata, MetaData):
- # they left MetaData out, so assume its another SchemaItem, add it to *args
- args = list(args)
- args.insert(0, metadata)
- metadata = None
-
- if metadata is None:
- metadata = default_metadata
-
schema = kwargs.get('schema', None)
autoload = kwargs.pop('autoload', False)
autoload_with = kwargs.pop('autoload_with', False)
mustexist = kwargs.pop('mustexist', False)
useexisting = kwargs.pop('useexisting', False)
+ include_columns = kwargs.pop('include_columns', None)
key = _get_table_key(name, schema)
try:
table = metadata.tables[key]
@@ -170,9 +156,9 @@ class _TableSingleton(type):
if autoload:
try:
if autoload_with:
- autoload_with.reflecttable(table)
+ autoload_with.reflecttable(table, include_columns=include_columns)
else:
- metadata._get_engine(raiseerr=True).reflecttable(table)
+ metadata._get_engine(raiseerr=True).reflecttable(table, include_columns=include_columns)
except exceptions.NoSuchTableError:
del metadata.tables[key]
raise
@@ -187,7 +173,7 @@ class Table(SchemaItem, sql.TableClause):
This subclasses ``sql.TableClause`` to provide a table that is
associated with an instance of ``MetaData``, which in turn
- may be associated with an instance of ``SQLEngine``.
+ may be associated with an instance of ``Engine``.
Whereas ``TableClause`` represents a table as its used in an SQL
expression, ``Table`` represents a table as it exists in a
@@ -232,16 +218,28 @@ class Table(SchemaItem, sql.TableClause):
options include:
schema
- Defaults to None: the *schema name* for this table, which is
+ The *schema name* for this table, which is
required if the table resides in a schema other than the
default selected schema for the engine's database
- connection.
+ connection. Defaults to ``None``.
autoload
Defaults to False: the Columns for this table should be
reflected from the database. Usually there will be no
Column objects in the constructor if this property is set.
+ autoload_with
+ if autoload==True, this is an optional Engine or Connection
+ instance to be used for the table reflection. If ``None``,
+ the underlying MetaData's bound connectable will be used.
+
+ include_columns
+ A list of strings indicating a subset of columns to be
+        loaded via the ``autoload`` operation; table columns that
+        aren't present in this list will not be represented on the resulting
+        ``Table`` object. Defaults to ``None``, which indicates that all
+ columns should be reflected.
+
mustexist
Defaults to False: indicates that this Table must already
have been defined elsewhere in the application, else an
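
A minimal sketch of partial reflection with the new ``include_columns`` argument (the engine and table are hypothetical):

    meta = MetaData(engine)
    users = Table('users', meta, autoload=True,
                  include_columns=['id', 'name'])
    # columns other than 'id' and 'name' are omitted from users.c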
@@ -293,8 +291,8 @@ class Table(SchemaItem, sql.TableClause):
self.fullname = self.name
self.owner = kwargs.pop('owner', None)
- self._set_casing_strategy(name, kwargs)
- self._set_casing_strategy(self.schema or '', kwargs, keyname='case_sensitive_schema')
+ self._set_casing_strategy(kwargs)
+ self._set_casing_strategy(kwargs, keyname='case_sensitive_schema')
if len([k for k in kwargs if not re.match(r'^(?:%s)_' % '|'.join(databases.__all__), k)]):
raise TypeError("Invalid argument(s) for Table: %s" % repr(kwargs.keys()))
@@ -302,6 +300,8 @@ class Table(SchemaItem, sql.TableClause):
# store extra kwargs, which should only contain db-specific options
self.kwargs = kwargs
+ key = property(lambda self:_get_table_key(self.name, self.schema))
+
def _export_columns(self, columns=None):
# override FromClause's collection initialization logic; TableClause and Table
# implement it differently
@@ -311,7 +311,7 @@ class Table(SchemaItem, sql.TableClause):
try:
return getattr(self, '_case_sensitive_schema')
except AttributeError:
- setattr(self, '_case_sensitive_schema', self._determine_case_sensitive(self.schema or '', keyname='case_sensitive_schema'))
+ setattr(self, '_case_sensitive_schema', self._determine_case_sensitive(keyname='case_sensitive_schema'))
return getattr(self, '_case_sensitive_schema')
case_sensitive_schema = property(_get_case_sensitive_schema)
@@ -361,36 +361,28 @@ class Table(SchemaItem, sql.TableClause):
else:
return []
- def exists(self, bind=None, connectable=None):
+ def exists(self, bind=None):
"""Return True if this table exists."""
- if connectable is not None:
- bind = connectable
-
if bind is None:
bind = self._get_engine(raiseerr=True)
def do(conn):
- e = conn.engine
- return e.dialect.has_table(conn, self.name, schema=self.schema)
+ return conn.dialect.has_table(conn, self.name, schema=self.schema)
return bind.run_callable(do)
- def create(self, bind=None, checkfirst=False, connectable=None):
+ def create(self, bind=None, checkfirst=False):
"""Issue a ``CREATE`` statement for this table.
See also ``metadata.create_all()``."""
- if connectable is not None:
- bind = connectable
self.metadata.create_all(bind=bind, checkfirst=checkfirst, tables=[self])
- def drop(self, bind=None, checkfirst=False, connectable=None):
+ def drop(self, bind=None, checkfirst=False):
"""Issue a ``DROP`` statement for this table.
See also ``metadata.drop_all()``."""
- if connectable is not None:
- bind = connectable
self.metadata.drop_all(bind=bind, checkfirst=checkfirst, tables=[self])
def tometadata(self, metadata, schema=None):
@@ -417,7 +409,7 @@ class Column(SchemaItem, sql._ColumnClause):
``TableClause``/``Table``.
"""
- def __init__(self, name, type, *args, **kwargs):
+ def __init__(self, name, type_, *args, **kwargs):
"""Construct a new ``Column`` object.
Arguments are:
@@ -426,7 +418,7 @@ class Column(SchemaItem, sql._ColumnClause):
The name of this column. This should be the identical name
as it appears, or will appear, in the database.
- type
+ type\_
The ``TypeEngine`` for this column. This can be any
subclass of ``types.AbstractType``, including the
database-agnostic types defined in the types module,
@@ -516,7 +508,7 @@ class Column(SchemaItem, sql._ColumnClause):
identifier contains mixed case.
"""
- super(Column, self).__init__(name, None, type)
+ super(Column, self).__init__(name, None, type_)
self.args = args
self.key = kwargs.pop('key', name)
self._primary_key = kwargs.pop('primary_key', False)
@@ -526,7 +518,7 @@ class Column(SchemaItem, sql._ColumnClause):
self.index = kwargs.pop('index', None)
self.unique = kwargs.pop('unique', None)
self.quote = kwargs.pop('quote', False)
- self._set_casing_strategy(name, kwargs)
+ self._set_casing_strategy(kwargs)
self.onupdate = kwargs.pop('onupdate', None)
self.autoincrement = kwargs.pop('autoincrement', True)
self.constraints = util.Set()
@@ -631,12 +623,13 @@ class Column(SchemaItem, sql._ColumnClause):
This is a copy of this ``Column`` referenced by a different parent
(such as an alias or select statement).
"""
+
fk = [ForeignKey(f._colspec) for f in self.foreign_keys]
c = Column(name or self.name, self.type, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, _is_oid = self._is_oid, quote=self.quote, *fk)
c.table = selectable
c.orig_set = self.orig_set
- c._distance = self._distance + 1
c.__originating_column = self.__originating_column
+ c._distance = self._distance + 1
if not c._is_oid:
selectable.columns.add(c)
if self.primary_key:
@@ -749,18 +742,14 @@ class ForeignKey(SchemaItem):
raise exceptions.ArgumentError("Could not create ForeignKey '%s' on table '%s': table '%s' has no column named '%s'" % (self._colspec, parenttable.name, table.name, str(e)))
else:
self._column = self._colspec
+
# propagate TypeEngine to parent if it didn't have one
- if self.parent.type is types.NULLTYPE:
+ if isinstance(self.parent.type, types.NullType):
self.parent.type = self._column.type
return self._column
column = property(lambda s: s._init_column())
- def accept_visitor(self, visitor):
- """Call the `visit_foreign_key` method on the given visitor."""
-
- visitor.visit_foreign_key(self)
-
def _get_parent(self):
return self.parent
@@ -802,7 +791,7 @@ class DefaultGenerator(SchemaItem):
def execute(self, bind=None, **kwargs):
if bind is None:
bind = self._get_engine(raiseerr=True)
- return bind.execute_default(self, **kwargs)
+ return bind._execute_default(self, **kwargs)
def __repr__(self):
return "DefaultGenerator()"
@@ -814,9 +803,6 @@ class PassiveDefault(DefaultGenerator):
super(PassiveDefault, self).__init__(**kwargs)
self.arg = arg
- def accept_visitor(self, visitor):
- return visitor.visit_passive_default(self)
-
def __repr__(self):
return "PassiveDefault(%s)" % repr(self.arg)
@@ -829,15 +815,26 @@ class ColumnDefault(DefaultGenerator):
def __init__(self, arg, **kwargs):
super(ColumnDefault, self).__init__(**kwargs)
- self.arg = arg
-
- def accept_visitor(self, visitor):
- """Call the visit_column_default method on the given visitor."""
+ if callable(arg):
+ if not inspect.isfunction(arg):
+ self.arg = lambda ctx: arg()
+ else:
+ argspec = inspect.getargspec(arg)
+ if len(argspec[0]) == 0:
+ self.arg = lambda ctx: arg()
+ elif len(argspec[0]) != 1:
+ raise exceptions.ArgumentError("ColumnDefault Python function takes zero or one positional arguments")
+ else:
+ self.arg = arg
+ else:
+ self.arg = arg
+ def _visit_name(self):
if self.for_update:
- return visitor.visit_column_onupdate(self)
+ return "column_onupdate"
else:
- return visitor.visit_column_default(self)
+ return "column_default"
+ __visit_name__ = property(_visit_name)
def __repr__(self):
return "ColumnDefault(%s)" % repr(self.arg)
@@ -852,7 +849,7 @@ class Sequence(DefaultGenerator):
self.increment = increment
self.optional=optional
self.quote = quote
- self._set_casing_strategy(name, kwargs)
+ self._set_casing_strategy(kwargs)
def __repr__(self):
return "Sequence(%s)" % string.join(
@@ -864,20 +861,16 @@ class Sequence(DefaultGenerator):
super(Sequence, self)._set_parent(column)
column.sequence = self
- def create(self, bind=None):
+ def create(self, bind=None, checkfirst=True):
if bind is None:
bind = self._get_engine(raiseerr=True)
- bind.create(self)
+ bind.create(self, checkfirst=checkfirst)
- def drop(self, bind=None):
+ def drop(self, bind=None, checkfirst=True):
if bind is None:
bind = self._get_engine(raiseerr=True)
- bind.drop(self)
-
- def accept_visitor(self, visitor):
- """Call the visit_seauence method on the given visitor."""
+ bind.drop(self, checkfirst=checkfirst)
- return visitor.visit_sequence(self)
class Constraint(SchemaItem):
"""Represent a table-level ``Constraint`` such as a composite primary key, foreign key, or unique constraint.
@@ -891,7 +884,7 @@ class Constraint(SchemaItem):
self.columns = sql.ColumnCollection()
def __contains__(self, x):
- return x in self.columns
+ return self.columns.contains_column(x)
def keys(self):
return self.columns.keys()
@@ -916,11 +909,12 @@ class CheckConstraint(Constraint):
super(CheckConstraint, self).__init__(name)
self.sqltext = sqltext
- def accept_visitor(self, visitor):
+ def _visit_name(self):
if isinstance(self.parent, Table):
- visitor.visit_check_constraint(self)
+ return "check_constraint"
else:
- visitor.visit_column_check_constraint(self)
+ return "column_check_constraint"
+ __visit_name__ = property(_visit_name)
def _set_parent(self, parent):
self.parent = parent
@@ -949,9 +943,6 @@ class ForeignKeyConstraint(Constraint):
for (c, r) in zip(self.__colnames, self.__refcolnames):
self.append_element(c,r)
- def accept_visitor(self, visitor):
- visitor.visit_foreign_key_constraint(self)
-
def append_element(self, col, refcol):
fk = ForeignKey(refcol, constraint=self, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter)
fk._set_parent(self.table.c[col])
@@ -975,9 +966,6 @@ class PrimaryKeyConstraint(Constraint):
for c in self.__colnames:
self.append_column(table.c[c])
- def accept_visitor(self, visitor):
- visitor.visit_primary_key_constraint(self)
-
def add(self, col):
self.append_column(col)
@@ -1009,9 +997,6 @@ class UniqueConstraint(Constraint):
def append_column(self, col):
self.columns.add(col)
- def accept_visitor(self, visitor):
- visitor.visit_unique_constraint(self)
-
def copy(self):
return UniqueConstraint(name=self.name, *self.__colnames)
@@ -1075,22 +1060,19 @@ class Index(SchemaItem):
% (self.name, column))
self.columns.append(column)
- def create(self, connectable=None):
- if connectable is not None:
- connectable.create(self)
+ def create(self, bind=None):
+ if bind is not None:
+ bind.create(self)
else:
self._get_engine(raiseerr=True).create(self)
return self
- def drop(self, connectable=None):
- if connectable is not None:
- connectable.drop(self)
+ def drop(self, bind=None):
+ if bind is not None:
+ bind.drop(self)
else:
self._get_engine(raiseerr=True).drop(self)
- def accept_visitor(self, visitor):
- visitor.visit_index(self)
-
def __str__(self):
return repr(self)
@@ -1103,69 +1085,34 @@ class Index(SchemaItem):
class MetaData(SchemaItem):
"""Represent a collection of Tables and their associated schema constructs."""
- def __init__(self, engine_or_url=None, url=None, bind=None, engine=None, **kwargs):
+ __visit_name__ = 'metadata'
+
+ def __init__(self, bind=None, **kwargs):
"""create a new MetaData object.
bind
an Engine, or a string or URL instance which will be passed
- to create_engine(), along with \**kwargs - this MetaData will
- be bound to the resulting engine.
+ to create_engine(), this MetaData will be bound to the resulting
+ engine.
- engine_or_url
- deprecated; a synonym for 'bind'
-
- url
- deprecated. a string or URL instance which will be passed to
- create_engine(), along with \**kwargs - this MetaData will be
- bound to the resulting engine.
-
- engine
- deprecated. an Engine instance to which this MetaData will
- be bound.
-
case_sensitive
- popped from \**kwargs, indicates default case sensitive
- setting for all contained objects. defaults to True.
+ popped from \**kwargs, indicates default case sensitive setting for
+ all contained objects. defaults to True.
- name
- deprecated, optional name for this MetaData instance.
-
- """
-
- # transition from <= 0.3.8 signature:
- # MetaData(name=None, url=None, engine=None)
- # to 0.4 signature:
- # MetaData(engine_or_url=None)
- name = kwargs.get('name', None)
- if engine_or_url is None:
- engine_or_url = url or bind or engine
- elif 'name' in kwargs:
- engine_or_url = engine_or_url or bind or engine or url
- else:
- import sqlalchemy.engine as engine
- import sqlalchemy.engine.url as url
- if (not isinstance(engine_or_url, url.URL) and
- not isinstance(engine_or_url, engine.Connectable)):
- try:
- url.make_url(engine_or_url)
- except exceptions.ArgumentError:
- # nope, must have been a name as 1st positional
- name, engine_or_url = engine_or_url, (url or engine or bind)
- kwargs.pop('name', None)
+ """
self.tables = {}
- self.name = name
- self._bind = None
- self._set_casing_strategy(name, kwargs)
- if engine_or_url:
- self.connect(engine_or_url, **kwargs)
+ self._set_casing_strategy(kwargs)
+ self.bind = bind
+
+ def __repr__(self):
+ return 'MetaData(%r)' % self.bind
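
A minimal sketch of the simplified constructor; the deprecated name/url/engine keywords are gone, and ``bind`` may be an Engine or anything create_engine() accepts (the URL here is illustrative):

    from sqlalchemy import MetaData, create_engine
    meta = MetaData()                                  # unbound
    bound = MetaData(bind=create_engine('sqlite://'))  # bound at construction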
def __getstate__(self):
- return {'tables':self.tables, 'name':self.name, 'casesensitive':self._case_sensitive_setting}
-
+ return {'tables':self.tables, 'casesensitive':self._case_sensitive_setting}
+
def __setstate__(self, state):
self.tables = state['tables']
- self.name = state['name']
self._case_sensitive_setting = state['casesensitive']
self._bind = None
@@ -1173,7 +1120,7 @@ class MetaData(SchemaItem):
"""return True if this MetaData is bound to an Engine."""
return self._bind is not None
- def connect(self, bind=None, **kwargs):
+ def connect(self, bind, **kwargs):
"""bind this MetaData to an Engine.
DEPRECATED. use metadata.bind = <engine> or metadata.bind = <url>.
@@ -1184,13 +1131,8 @@ class MetaData(SchemaItem):
produce the engine which to connect to. otherwise connects
directly to the given Engine.
- engine_or_url
- deprecated. synonymous with "bind"
-
"""
- if bind is None:
- bind = kwargs.pop('engine_or_url', None)
from sqlalchemy.engine.url import URL
if isinstance(bind, (basestring, URL)):
self._bind = sqlalchemy.create_engine(bind, **kwargs)
@@ -1202,6 +1144,10 @@ class MetaData(SchemaItem):
def clear(self):
self.tables.clear()
+ def remove(self, table):
+ # TODO: scan all other tables and remove FK _column
+ del self.tables[table.key]
+
def table_iterator(self, reverse=True, tables=None):
import sqlalchemy.sql_util
if tables is None:
@@ -1214,7 +1160,7 @@ class MetaData(SchemaItem):
def _get_parent(self):
return None
- def create_all(self, bind=None, tables=None, checkfirst=True, connectable=None):
+ def create_all(self, bind=None, tables=None, checkfirst=True):
"""Create all tables stored in this metadata.
This will conditionally create tables depending on if they do
@@ -1224,21 +1170,16 @@ class MetaData(SchemaItem):
A ``Connectable`` used to access the database; if None, uses
the existing bind on this ``MetaData``, if any.
- connectable
- deprecated. synonymous with "bind"
-
tables
Optional list of tables, which is a subset of the total
tables in the ``MetaData`` (others are ignored).
"""
- if connectable is not None:
- bind = connectable
if bind is None:
bind = self._get_engine(raiseerr=True)
bind.create(self, checkfirst=checkfirst, tables=tables)
- def drop_all(self, bind=None, tables=None, checkfirst=True, connectable=None):
+ def drop_all(self, bind=None, tables=None, checkfirst=True):
"""Drop all tables stored in this metadata.
This will conditionally drop tables depending on if they
@@ -1248,23 +1189,15 @@ class MetaData(SchemaItem):
A ``Connectable`` used to access the database; if None, uses
the existing bind on this ``MetaData``, if any.
- connectable
- deprecated. synonymous with "bind"
-
tables
Optional list of tables, which is a subset of the total
tables in the ``MetaData`` (others are ignored).
"""
- if connectable is not None:
- bind = connectable
if bind is None:
bind = self._get_engine(raiseerr=True)
bind.drop(self, checkfirst=checkfirst, tables=tables)
- def accept_visitor(self, visitor):
- visitor.visit_metadata(self)
-
def _derived_metadata(self):
return self
@@ -1276,27 +1209,22 @@ class MetaData(SchemaItem):
return None
return self._bind
-
-class BoundMetaData(MetaData):
- """Deprecated. Use ``MetaData``."""
-
- def __init__(self, engine_or_url, name=None, **kwargs):
- super(BoundMetaData, self).__init__(engine_or_url=engine_or_url,
- name=name, **kwargs)
-
-
class ThreadLocalMetaData(MetaData):
- """A ``MetaData`` that binds to multiple ``Engine`` implementations on a thread-local basis."""
+ """Build upon ``MetaData`` to provide the capability to bind to
+multiple ``Engine`` implementations on a dynamically alterable,
+thread-local basis.
+ """
+
+ __visit_name__ = 'metadata'
- def __init__(self, name=None, **kwargs):
+ def __init__(self, **kwargs):
self.context = util.ThreadLocal()
self.__engines = {}
- super(ThreadLocalMetaData, self).__init__(engine_or_url=None,
- name=name, **kwargs)
+ super(ThreadLocalMetaData, self).__init__(**kwargs)
def connect(self, engine_or_url, **kwargs):
from sqlalchemy.engine.url import URL
- if isinstance(engine_or_url, basestring) or isinstance(engine_or_url, URL):
+ if isinstance(engine_or_url, (basestring, URL)):
try:
self.context._engine = self.__engines[engine_or_url]
except KeyError:
@@ -1304,6 +1232,8 @@ class ThreadLocalMetaData(MetaData):
self.__engines[engine_or_url] = e
self.context._engine = e
else:
+            # TODO: this is squirrelly. we shouldn't have to hold onto engines
+ # in a case like this
if not self.__engines.has_key(engine_or_url):
self.__engines[engine_or_url] = engine_or_url
self.context._engine = engine_or_url
@@ -1325,77 +1255,10 @@ class ThreadLocalMetaData(MetaData):
raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.")
else:
return None
- engine=property(_get_engine)
bind = property(_get_engine, connect)
-def DynamicMetaData(name=None, threadlocal=True, **kw):
- """Deprecated. Use ``MetaData`` or ``ThreadLocalMetaData``."""
-
- if threadlocal:
- return ThreadLocalMetaData(name=name, **kw)
- else:
- return MetaData(name=name, **kw)
-
class SchemaVisitor(sql.ClauseVisitor):
"""Define the visiting for ``SchemaItem`` objects."""
__traverse_options__ = {'schema_visitor':True}
-
- def visit_schema(self, schema):
- """Visit a generic ``SchemaItem``."""
- pass
-
- def visit_table(self, table):
- """Visit a ``Table``."""
- pass
-
- def visit_column(self, column):
- """Visit a ``Column``."""
- pass
-
- def visit_foreign_key(self, join):
- """Visit a ``ForeignKey``."""
- pass
-
- def visit_index(self, index):
- """Visit an ``Index``."""
- pass
-
- def visit_passive_default(self, default):
- """Visit a passive default."""
- pass
-
- def visit_column_default(self, default):
- """Visit a ``ColumnDefault``."""
- pass
-
- def visit_column_onupdate(self, onupdate):
- """Visit a ``ColumnDefault`` with the `for_update` flag set."""
- pass
-
- def visit_sequence(self, sequence):
- """Visit a ``Sequence``."""
- pass
-
- def visit_primary_key_constraint(self, constraint):
- """Visit a ``PrimaryKeyConstraint``."""
- pass
-
- def visit_foreign_key_constraint(self, constraint):
- """Visit a ``ForeignKeyConstraint``."""
- pass
-
- def visit_unique_constraint(self, constraint):
- """Visit a ``UniqueConstraint``."""
- pass
-
- def visit_check_constraint(self, constraint):
- """Visit a ``CheckConstraint``."""
- pass
-
- def visit_column_check_constraint(self, constraint):
- """Visit a ``CheckConstraint`` on a ``Column``."""
- pass
-
-default_metadata = ThreadLocalMetaData(name='default')
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index 8b454947e..01588e92d 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -24,54 +24,21 @@ are less guaranteed to stay the same in future releases.
"""
-from sqlalchemy import util, exceptions, logging
+from sqlalchemy import util, exceptions
from sqlalchemy import types as sqltypes
-import string, re, random, sets
+import re, operator
-
-__all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters',
+__all__ = ['Alias', 'ClauseElement', 'ClauseParameters',
'ClauseVisitor', 'ColumnCollection', 'ColumnElement',
- 'Compiled', 'CompoundSelect', 'Executor', 'FromClause', 'Join',
- 'Select', 'Selectable', 'TableClause', 'alias', 'and_', 'asc',
- 'between_', 'between', 'bindparam', 'case', 'cast', 'column', 'delete',
+ 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join',
+ 'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc',
+ 'between', 'bindparam', 'case', 'cast', 'column', 'delete',
'desc', 'distinct', 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier',
'insert', 'intersect', 'intersect_all', 'join', 'literal',
- 'literal_column', 'not_', 'null', 'or_', 'outerjoin', 'select',
+ 'literal_column', 'not_', 'null', 'or_', 'outparam', 'outerjoin', 'select',
'subquery', 'table', 'text', 'union', 'union_all', 'update',]
-# precedence ordering for common operators. if an operator is not present in this list,
-# it will be parenthesized when grouped against other operators
-PRECEDENCE = {
- 'FROM':15,
- '*':7,
- '/':7,
- '%':7,
- '+':6,
- '-':6,
- 'ILIKE':5,
- 'NOT ILIKE':5,
- 'LIKE':5,
- 'NOT LIKE':5,
- 'IN':5,
- 'NOT IN':5,
- 'IS':5,
- 'IS NOT':5,
- '=':5,
- '!=':5,
- '>':5,
- '<':5,
- '>=':5,
- '<=':5,
- 'BETWEEN':5,
- 'NOT':4,
- 'AND':3,
- 'OR':2,
- ',':-1,
- 'AS':-1,
- 'EXISTS':0,
- '_smallest': -1000,
- '_largest': 1000
-}
+BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
def desc(column):
"""Return a descending ``ORDER BY`` clause element.
@@ -141,7 +108,7 @@ def join(left, right, onclause=None, **kwargs):
return Join(left, right, onclause, **kwargs)
-def select(columns=None, whereclause = None, from_obj = [], **kwargs):
+def select(columns=None, whereclause=None, from_obj=[], **kwargs):
"""Returns a ``SELECT`` clause element.
Similar functionality is also available via the ``select()`` method on any
@@ -224,9 +191,6 @@ def select(columns=None, whereclause = None, from_obj = [], **kwargs):
automatically bind to whatever ``Connectable`` instances can be located
within its contained ``ClauseElement`` members.
- engine=None
- deprecated. a synonym for "bind".
-
limit=None
a numerical value which usually compiles to a ``LIMIT`` expression
in the resulting select. Databases that don't support ``LIMIT``
@@ -238,12 +202,8 @@ def select(columns=None, whereclause = None, from_obj = [], **kwargs):
will attempt to provide similar functionality.
scalar=False
- when ``True``, indicates that the resulting ``Select`` object
- is to be used in the "columns" clause of another select statement,
- where the evaluated value of the column is the scalar result of
- this statement. Normally, placing any ``Selectable`` within the
- columns clause of a ``select()`` call will expand the member
- columns of the ``Selectable`` individually.
+ deprecated. use select(...).as_scalar() to create a "scalar column"
+ proxy for an existing Select object.
correlate=True
indicates that this ``Select`` object should have its contained
@@ -254,8 +214,12 @@ def select(columns=None, whereclause = None, from_obj = [], **kwargs):
rendered in the ``FROM`` clause of this select statement.
"""
-
- return Select(columns, whereclause = whereclause, from_obj = from_obj, **kwargs)
+ scalar = kwargs.pop('scalar', False)
+ s = Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs)
+ if scalar:
+ return s.as_scalar()
+ else:
+ return s
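
A minimal sketch of the replacement for ``scalar=True``, assuming hypothetical ``users`` and ``addresses`` tables; build the Select as usual, then wrap it with ``as_scalar()`` for use as a column expression:

    subq = select([func.count(addresses.c.id)],
                  addresses.c.user_id == users.c.id).as_scalar()
    stmt = select([users.c.name, subq])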
def subquery(alias, *args, **kwargs):
"""Return an [sqlalchemy.sql#Alias] object derived from a [sqlalchemy.sql#Select].
@@ -271,7 +235,7 @@ def subquery(alias, *args, **kwargs):
return Select(*args, **kwargs).alias(alias)
def insert(table, values = None, **kwargs):
- """Return an [sqlalchemy.sql#_Insert] clause element.
+ """Return an [sqlalchemy.sql#Insert] clause element.
Similar functionality is available via the ``insert()``
method on [sqlalchemy.schema#Table].
@@ -304,10 +268,10 @@ def insert(table, values = None, **kwargs):
against the ``INSERT`` statement.
"""
- return _Insert(table, values, **kwargs)
+ return Insert(table, values, **kwargs)
def update(table, whereclause = None, values = None, **kwargs):
- """Return an [sqlalchemy.sql#_Update] clause element.
+ """Return an [sqlalchemy.sql#Update] clause element.
Similar functionality is available via the ``update()``
method on [sqlalchemy.schema#Table].
@@ -344,10 +308,10 @@ def update(table, whereclause = None, values = None, **kwargs):
against the ``UPDATE`` statement.
"""
- return _Update(table, whereclause, values, **kwargs)
+ return Update(table, whereclause, values, **kwargs)
def delete(table, whereclause = None, **kwargs):
- """Return a [sqlalchemy.sql#_Delete] clause element.
+ """Return a [sqlalchemy.sql#Delete] clause element.
Similar functionality is available via the ``delete()``
method on [sqlalchemy.schema#Table].
@@ -361,7 +325,7 @@ def delete(table, whereclause = None, **kwargs):
"""
- return _Delete(table, whereclause, **kwargs)
+ return Delete(table, whereclause, **kwargs)
def and_(*clauses):
"""Join a list of clauses together using the ``AND`` operator.
@@ -371,7 +335,7 @@ def and_(*clauses):
"""
if len(clauses) == 1:
return clauses[0]
- return ClauseList(operator='AND', *clauses)
+ return ClauseList(operator=operator.and_, *clauses)
def or_(*clauses):
"""Join a list of clauses together using the ``OR`` operator.
@@ -382,7 +346,7 @@ def or_(*clauses):
if len(clauses) == 1:
return clauses[0]
- return ClauseList(operator='OR', *clauses)
+ return ClauseList(operator=operator.or_, *clauses)
def not_(clause):
"""Return a negation of the given clause, i.e. ``NOT(clause)``.
@@ -391,7 +355,7 @@ def not_(clause):
subclasses to produce the same result.
"""
- return clause._negate()
+ return operator.inv(clause)
def distinct(expr):
"""return a ``DISTINCT`` clause."""
@@ -407,12 +371,9 @@ def between(ctest, cleft, cright):
provides similar functionality.
"""
- return _BinaryExpression(ctest, ClauseList(_literals_as_binds(cleft, type=ctest.type), _literals_as_binds(cright, type=ctest.type), operator='AND', group=False), 'BETWEEN')
+ ctest = _literal_as_binds(ctest)
+ return _BinaryExpression(ctest, ClauseList(_literal_as_binds(cleft, type_=ctest.type), _literal_as_binds(cright, type_=ctest.type), operator=operator.and_, group=False), ColumnOperators.between_op)
-def between_(*args, **kwargs):
- """synonym for [sqlalchemy.sql#between()] (deprecated)."""
-
- return between(*args, **kwargs)
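
For illustration, ``between()`` now coerces plain endpoints to bind parameters typed from the test expression (``users`` is hypothetical):

    crit = between(users.c.id, 1, 10)
    # generative equivalent: users.c.id.between(1, 10)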
def case(whens, value=None, else_=None):
"""Produce a ``CASE`` statement.
@@ -435,7 +396,7 @@ def case(whens, value=None, else_=None):
type = list(whenlist[-1])[-1].type
else:
type = None
- cc = _CalculatedClause(None, 'CASE', value, type=type, operator=None, group_contents=False, *whenlist + ['END'])
+ cc = _CalculatedClause(None, 'CASE', value, type_=type, operator=None, group_contents=False, *whenlist + ['END'])
return cc
def cast(clause, totype, **kwargs):
@@ -457,7 +418,7 @@ def cast(clause, totype, **kwargs):
def extract(field, expr):
"""Return the clause ``extract(field FROM expr)``."""
- expr = _BinaryExpression(text(field), expr, "FROM")
+ expr = _BinaryExpression(text(field), expr, Operators.from_)
return func.extract(expr)
def exists(*args, **kwargs):
@@ -587,7 +548,7 @@ def alias(selectable, alias=None):
return Alias(selectable, alias=alias)
-def literal(value, type=None):
+def literal(value, type_=None):
"""Return a literal clause, bound to a bind parameter.
Literal clauses are created automatically when non-
@@ -603,13 +564,13 @@ def literal(value, type=None):
the underlying DBAPI, or is translatable via the given type
argument.
- type
+ type\_
an optional [sqlalchemy.types#TypeEngine] which will provide
bind-parameter translation for this literal.
"""
- return _BindParamClause('literal', value, type=type, unique=True)
+ return _BindParamClause('literal', value, type_=type_, unique=True)
def label(name, obj):
"""Return a [sqlalchemy.sql#_Label] object for the given [sqlalchemy.sql#ColumnElement].
@@ -630,7 +591,7 @@ def label(name, obj):
return _Label(name, obj)
-def column(text, type=None):
+def column(text, type_=None):
"""Return a textual column clause, as would be in the columns
clause of a ``SELECT`` statement.
@@ -644,15 +605,15 @@ def column(text, type=None):
constructs that are not to be quoted, use the [sqlalchemy.sql#literal_column()]
function.
- type
+ type\_
an optional [sqlalchemy.types#TypeEngine] object which will provide
result-set translation for this column.
"""
- return _ColumnClause(text, type=type)
+ return _ColumnClause(text, type_=type_)
-def literal_column(text, type=None):
+def literal_column(text, type_=None):
"""Return a textual column clause, as would be in the columns
clause of a ``SELECT`` statement.
@@ -674,7 +635,7 @@ def literal_column(text, type=None):
"""
- return _ColumnClause(text, type=type, is_literal=True)
+ return _ColumnClause(text, type_=type_, is_literal=True)
def table(name, *columns):
"""Return a [sqlalchemy.sql#Table] object.
@@ -685,7 +646,7 @@ def table(name, *columns):
return TableClause(name, *columns)
-def bindparam(key, value=None, type=None, shortname=None, unique=False):
+def bindparam(key, value=None, type_=None, shortname=None, unique=False):
"""Create a bind parameter clause with the given key.
value
@@ -707,11 +668,22 @@ def bindparam(key, value=None, type=None, shortname=None, unique=False):
"""
if isinstance(key, _ColumnClause):
- return _BindParamClause(key.name, value, type=key.type, shortname=shortname, unique=unique)
+ return _BindParamClause(key.name, value, type_=key.type, shortname=shortname, unique=unique)
else:
- return _BindParamClause(key, value, type=type, shortname=shortname, unique=unique)
+ return _BindParamClause(key, value, type_=type_, shortname=shortname, unique=unique)
-def text(text, bind=None, engine=None, *args, **kwargs):
+def outparam(key, type_=None):
+ """create an 'OUT' parameter for usage in functions (stored procedures), for databases
+ whith support them.
+
+ The ``outparam`` can be used like a regular function parameter. The "output" value will
+ be available from the [sqlalchemy.engine#ResultProxy] object via its ``out_parameters``
+ attribute, which returns a dictionary containing the values.
+ """
+
+ return _BindParamClause(key, None, type_=type_, unique=False, isoutparam=True)
+
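A hypothetical usage sketch for ``outparam()``, assuming an Oracle-style stored procedure named get_balance and a bound `engine`, neither of which is part of this patch:

    from sqlalchemy import text, bindparam, outparam, Integer, Numeric

    stmt = text("BEGIN get_balance(:acct, :balance); END;",
                bindparams=[bindparam('acct', type_=Integer),
                            outparam('balance', type_=Numeric)])
    result = engine.execute(stmt, acct=7)
    print result.out_parameters['balance']
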
+def text(text, bind=None, *args, **kwargs):
"""Create literal text to be inserted into a query.
When constructing a query from a ``select()``, ``update()``,
@@ -729,9 +701,6 @@ def text(text, bind=None, engine=None, *args, **kwargs):
bind
An optional connection or engine to be used for this text query.
- engine
- deprecated. a synonym for 'bind'.
-
bindparams
A list of ``bindparam()`` instances which can be used to define
the types and/or initial values for the bind parameters within
@@ -748,7 +717,7 @@ def text(text, bind=None, engine=None, *args, **kwargs):
"""
- return _TextClause(text, engine=engine, bind=bind, *args, **kwargs)
+ return _TextClause(text, bind=bind, *args, **kwargs)
def null():
"""Return a ``_Null`` object, which compiles to ``NULL`` in a sql statement."""
@@ -786,30 +755,44 @@ def _compound_select(keyword, *selects, **kwargs):
def _is_literal(element):
return not isinstance(element, ClauseElement)
-def _literals_as_text(element):
- if _is_literal(element):
+def _literal_as_text(element):
+ if isinstance(element, Operators):
+ return element.expression_element()
+ elif _is_literal(element):
return _TextClause(unicode(element))
else:
return element
-def _literals_as_binds(element, name='literal', type=None):
- if _is_literal(element):
+def _literal_as_column(element):
+ if isinstance(element, Operators):
+ return element.clause_element()
+ elif _is_literal(element):
+ return literal_column(str(element))
+ else:
+ return element
+
+def _literal_as_binds(element, name='literal', type_=None):
+ if isinstance(element, Operators):
+ return element.expression_element()
+ elif _is_literal(element):
if element is None:
return null()
else:
- return _BindParamClause(name, element, shortname=name, type=type, unique=True)
+ return _BindParamClause(name, element, shortname=name, type_=type_, unique=True)
else:
return element
+
+def _selectable(element):
+ if hasattr(element, '__selectable__'):
+ return element.__selectable__()
+ elif isinstance(element, Selectable):
+ return element
+ else:
+ raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element))
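
The new ``_selectable()`` helper adds a small adaptation protocol; a sketch of a non-Selectable object (hypothetical, as are the `users` and `addresses` tables) that ``Join`` can now accept because it implements ``__selectable__()``:

    class UserWrapper(object):
        """not a Selectable itself, but can produce one on demand"""
        def __init__(self, table):
            self.table = table
        def __selectable__(self):
            return self.table

    # Join.__init__ now runs both sides through _selectable(), so:
    j = users.join(UserWrapper(addresses),
                   users.c.id == addresses.c.user_id)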
def is_column(col):
return isinstance(col, ColumnElement)
-class AbstractDialect(object):
- """Represent the behavior of a particular database.
-
- Used by ``Compiled`` objects."""
- pass
-
class ClauseParameters(object):
"""Represent a dictionary/iterator of bind parameter key names/values.
@@ -822,52 +805,51 @@ class ClauseParameters(object):
def __init__(self, dialect, positional=None):
super(ClauseParameters, self).__init__()
self.dialect = dialect
- self.binds = {}
- self.binds_to_names = {}
- self.binds_to_values = {}
+ self.__binds = {}
self.positional = positional or []
+ def get_parameter(self, key):
+ return self.__binds[key]
+
def set_parameter(self, bindparam, value, name):
- self.binds[bindparam.key] = bindparam
- self.binds[name] = bindparam
- self.binds_to_names[bindparam] = name
- self.binds_to_values[bindparam] = value
+ self.__binds[name] = [bindparam, name, value]
def get_original(self, key):
- """Return the given parameter as it was originally placed in
- this ``ClauseParameters`` object, without any ``Type``
- conversion."""
- return self.binds_to_values[self.binds[key]]
+ return self.__binds[key][2]
+
+ def get_type(self, key):
+ return self.__binds[key][0].type
def get_processed(self, key):
- bind = self.binds[key]
- value = self.binds_to_values[bind]
+ (bind, name, value) = self.__binds[key]
return bind.typeprocess(value, self.dialect)
def keys(self):
- return self.binds_to_names.values()
+ return self.__binds.keys()
+
+ def __iter__(self):
+ return iter(self.keys())
def __getitem__(self, key):
return self.get_processed(key)
def __contains__(self, key):
- return key in self.binds
+ return key in self.__binds
def set_value(self, key, value):
- bind = self.binds[key]
- self.binds_to_values[bind] = value
+ self.__binds[key][2] = value
def get_original_dict(self):
- return dict([(self.binds_to_names[b], self.binds_to_values[b]) for b in self.binds_to_names.keys()])
+ return dict([(name, value) for (b, name, value) in self.__binds.values()])
def get_raw_list(self):
return [self.get_processed(key) for key in self.positional]
- def get_raw_dict(self):
- d = {}
- for k in self.binds_to_names.values():
- d[k] = self.get_processed(k)
- return d
+ def get_raw_dict(self, encode_keys=False):
+ if encode_keys:
+ return dict([(key.encode(self.dialect.encoding), self.get_processed(key)) for key in self.keys()])
+ else:
+ return dict([(key, self.get_processed(key)) for key in self.keys()])
def __repr__(self):
return self.__class__.__name__ + ":" + repr(self.get_original_dict())
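
The rewrite collapses the three parallel dicts into a single private mapping of name -> [bindparam, name, value]; a sketch of the resulting access paths, assuming `dialect` is a Dialect and `bp` a _BindParamClause built elsewhere:

    params = ClauseParameters(dialect)
    params.set_parameter(bp, 'ed', 'name')  # stores [bp, 'name', 'ed'] under 'name'
    params.get_original('name')             # 'ed', with no type conversion
    params.get_type('name')                 # bp.type
    params.get_processed('name')            # bp.typeprocess('ed', dialect)
    params.get_raw_dict(encode_keys=True)   # keys encoded per dialect.encoding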
@@ -876,8 +858,8 @@ class ClauseVisitor(object):
"""A class that knows how to traverse and visit
``ClauseElements``.
- Each ``ClauseElement``'s accept_visitor() method will call a
- corresponding visit_XXXX() method here. Traversal of a
+ Calls visit_XXX() methods dynamically, based on each particular
+ ``ClauseElement`` subclass encountered. Traversal of a
hierarchy of ``ClauseElements`` is achieved via the
``traverse()`` method, which is passed the lead
``ClauseElement``.
@@ -889,22 +871,44 @@ class ClauseVisitor(object):
these options can indicate modifications to the set of
elements returned, such as to not return column collections
(column_collections=False) or to return Schema-level items
- (schema_visitor=True)."""
+ (schema_visitor=True).
+
+ ``ClauseVisitor`` also supports a simultaneous copy-and-traverse
+ operation, which will produce a copy of a given ``ClauseElement``
+ structure while at the same time allowing ``ClauseVisitor`` subclasses
+ to modify the new structure in-place.
+
+ """
__traverse_options__ = {}
- def traverse(self, obj, stop_on=None):
- stack = [obj]
- traversal = []
- while len(stack) > 0:
- t = stack.pop()
- if stop_on is None or t not in stop_on:
- traversal.insert(0, t)
- for c in t.get_children(**self.__traverse_options__):
- stack.append(c)
- for target in traversal:
- v = self
- while v is not None:
- target.accept_visitor(v)
- v = getattr(v, '_next', None)
+
+ def traverse_single(self, obj, **kwargs):
+ meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
+ if meth:
+ return meth(obj, **kwargs)
+
+ def traverse(self, obj, stop_on=None, clone=False):
+ if clone:
+ obj = obj._clone()
+
+ v = self
+ visitors = []
+ while v is not None:
+ visitors.append(v)
+ v = getattr(v, '_next', None)
+
+ def _trav(obj):
+ if stop_on is not None and obj in stop_on:
+ return
+ if clone:
+ obj._copy_internals()
+ for c in obj.get_children(**self.__traverse_options__):
+ _trav(c)
+
+ for v in visitors:
+ meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
+ if meth:
+ meth(obj)
+ _trav(obj)
return obj
def chain(self, visitor):
@@ -916,78 +920,6 @@ class ClauseVisitor(object):
tail = tail._next
tail._next = visitor
return self
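
With ``accept_visitor()`` gone, a visitor simply defines visit_<visit_name> methods; a sketch of a subclass that redacts bind parameter values on a cloned copy, leaving the original expression untouched:

    class BindRedactor(ClauseVisitor):
        def visit_bindparam(self, bindparam):
            bindparam.value = 'REDACTED'

    expr = users.c.name == 'ed'
    redacted = BindRedactor().traverse(expr, clone=True)
    # traverse(clone=True) pairs _clone() with _copy_internals() to build
    # a deep copy; expr's bind parameter still carries 'ed'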
-
- def visit_column(self, column):
- pass
- def visit_table(self, table):
- pass
- def visit_fromclause(self, fromclause):
- pass
- def visit_bindparam(self, bindparam):
- pass
- def visit_textclause(self, textclause):
- pass
- def visit_compound(self, compound):
- pass
- def visit_compound_select(self, compound):
- pass
- def visit_binary(self, binary):
- pass
- def visit_unary(self, unary):
- pass
- def visit_alias(self, alias):
- pass
- def visit_select(self, select):
- pass
- def visit_join(self, join):
- pass
- def visit_null(self, null):
- pass
- def visit_clauselist(self, list):
- pass
- def visit_calculatedclause(self, calcclause):
- pass
- def visit_grouping(self, gr):
- pass
- def visit_function(self, func):
- pass
- def visit_cast(self, cast):
- pass
- def visit_label(self, label):
- pass
- def visit_typeclause(self, typeclause):
- pass
-
-class LoggingClauseVisitor(ClauseVisitor):
- """extends ClauseVisitor to include debug logging of all traversal.
-
- To install this visitor, set logging.DEBUG for
- 'sqlalchemy.sql.ClauseVisitor' **before** you import the
- sqlalchemy.sql module.
- """
-
- def traverse(self, obj, stop_on=None):
- stack = [(obj, "")]
- traversal = []
- while len(stack) > 0:
- (t, indent) = stack.pop()
- if stop_on is None or t not in stop_on:
- traversal.insert(0, (t, indent))
- for c in t.get_children(**self.__traverse_options__):
- stack.append((c, indent + " "))
-
- for (target, indent) in traversal:
- self.logger.debug(indent + repr(target))
- v = self
- while v is not None:
- target.accept_visitor(v)
- v = getattr(v, '_next', None)
- return obj
-
-LoggingClauseVisitor.logger = logging.class_logger(ClauseVisitor)
-
-if logging.is_debug_enabled(LoggingClauseVisitor.logger):
- ClauseVisitor=LoggingClauseVisitor
class NoColumnVisitor(ClauseVisitor):
"""a ClauseVisitor that will not traverse the exported Column
@@ -1000,113 +932,35 @@ class NoColumnVisitor(ClauseVisitor):
"""
__traverse_options__ = {'column_collections':False}
-
-class Executor(object):
- """Interface representing a "thing that can produce Compiled objects
- and execute them"."""
-
- def execute_compiled(self, compiled, parameters, echo=None, **kwargs):
- """Execute a Compiled object."""
-
- raise NotImplementedError()
-
- def compiler(self, statement, parameters, **kwargs):
- """Return a Compiled object for the given statement and parameters."""
-
- raise NotImplementedError()
-
-class Compiled(ClauseVisitor):
- """Represent a compiled SQL expression.
-
- The ``__str__`` method of the ``Compiled`` object should produce
- the actual text of the statement. ``Compiled`` objects are
- specific to their underlying database dialect, and also may
- or may not be specific to the columns referenced within a
- particular set of bind parameters. In no case should the
- ``Compiled`` object be dependent on the actual values of those
- bind parameters, even though it may reference those values as
- defaults.
- """
-
- def __init__(self, dialect, statement, parameters, bind=None, engine=None):
- """Construct a new ``Compiled`` object.
-
- statement
- ``ClauseElement`` to be compiled.
-
- parameters
- Optional dictionary indicating a set of bind parameters
- specified with this ``Compiled`` object. These parameters
- are the *default* values corresponding to the
- ``ClauseElement``'s ``_BindParamClauses`` when the
- ``Compiled`` is executed. In the case of an ``INSERT`` or
- ``UPDATE`` statement, these parameters will also result in
- the creation of new ``_BindParamClause`` objects for each
- key and will also affect the generated column list in an
- ``INSERT`` statement and the ``SET`` clauses of an
- ``UPDATE`` statement. The keys of the parameter dictionary
- can either be the string names of columns or
- ``_ColumnClause`` objects.
-
- bind
- optional engine or connection which will be bound to the
- compiled object.
-
- engine
- deprecated, a synonym for 'bind'
- """
- self.dialect = dialect
- self.statement = statement
- self.parameters = parameters
- self.bind = bind or engine
- self.can_execute = statement.supports_execution()
-
- def compile(self):
- self.traverse(self.statement)
- self.after_compile()
-
- def __str__(self):
- """Return the string text of the generated SQL statement."""
-
- raise NotImplementedError()
- def get_params(self, **params):
- """Deprecated. use construct_params(). (supports unicode names)
- """
- return self.construct_params(params)
-
- def construct_params(self, params):
- """Return the bind params for this compiled object.
-
- Will start with the default parameters specified when this
- ``Compiled`` object was first constructed, and will override
- those values with those sent via `**params`, which are
- key/value pairs. Each key should match one of the
- ``_BindParamClause`` objects compiled into this object; either
- the `key` or `shortname` property of the ``_BindParamClause``.
- """
- raise NotImplementedError()
+class _FigureVisitName(type):
+ def __init__(cls, clsname, bases, dict):
+ if '__visit_name__' not in cls.__dict__:
+ m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname)
+ x = m.group(1)
+ x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x)
+ cls.__visit_name__ = x.lower()
+ super(_FigureVisitName, cls).__init__(clsname, bases, dict)
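
A standalone sketch of the metaclass's name derivation (the same two regexes), which also shows why classes like ``_TextClause`` above still assign ``__visit_name__`` explicitly:

    import re

    def figure_visit_name(clsname):
        m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname)
        x = m.group(1)
        x = re.sub(r'(?!^)[A-Z]', lambda m: '_' + m.group(0).lower(), x)
        return x.lower()

    figure_visit_name('_BinaryExpression')  # 'binary'
    figure_visit_name('Join')               # 'join'
    figure_visit_name('_TextClause')        # 'text' -- hence the explicit
                                            # __visit_name__ = 'textclause'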
- def execute(self, *multiparams, **params):
- """Execute this compiled object."""
-
- e = self.bind
- if e is None:
- raise exceptions.InvalidRequestError("This Compiled object is not bound to any Engine or Connection.")
- return e.execute_compiled(self, *multiparams, **params)
-
- def scalar(self, *multiparams, **params):
- """Execute this compiled object and return the result's scalar value."""
-
- return self.execute(*multiparams, **params).scalar()
-
class ClauseElement(object):
"""Base class for elements of a programmatically constructed SQL
expression.
"""
+ __metaclass__ = _FigureVisitName
+
+ def _clone(self):
+ """create a shallow copy of this ClauseElement.
+
+ This method may be used by a generative API.
+ Its also used as part of the "deep" copy afforded
+ by a traversal that combines the _copy_internals()
+ method."""
+ c = self.__class__.__new__(self.__class__)
+ c.__dict__ = self.__dict__.copy()
+ return c
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
"""Return objects represented in this ``ClauseElement`` that
should be added to the ``FROM`` list of a query, when this
``ClauseElement`` is placed in the column clause of a
@@ -1115,7 +969,7 @@ class ClauseElement(object):
raise NotImplementedError(repr(self))
- def _hide_froms(self):
+ def _hide_froms(self, **modifiers):
"""Return a list of ``FROM`` clause elements which this
``ClauseElement`` replaces.
"""
@@ -1131,13 +985,14 @@ class ClauseElement(object):
return self is other
- def accept_visitor(self, visitor):
- """Accept a ``ClauseVisitor`` and call the appropriate
- ``visit_xxx`` method.
- """
-
- raise NotImplementedError(repr(self))
-
+ def _copy_internals(self):
+ """reassign internal elements to be clones of themselves.
+
+ called during a copy-and-traverse operation on newly
+ shallow-copied elements to create a deep copy."""
+
+ pass
+
def get_children(self, **kwargs):
"""return immediate child elements of this ``ClauseElement``.
@@ -1160,18 +1015,6 @@ class ClauseElement(object):
return False
- def copy_container(self):
- """Return a copy of this ``ClauseElement``, if this
- ``ClauseElement`` contains other ``ClauseElements``.
-
- If this ``ClauseElement`` is not a container, it should return
- self. This is used to create copies of expression trees that
- still reference the same *leaf nodes*. The new structure can
- then be restructured without affecting the original.
- """
-
- return self
-
def _find_engine(self):
"""Default strategy for locating an engine within the clause element.
@@ -1195,7 +1038,6 @@ class ClauseElement(object):
return None
bind = property(lambda s:s._find_engine(), doc="""Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""")
- engine = bind
def execute(self, *multiparams, **params):
"""Compile and execute this ``ClauseElement``."""
@@ -1213,7 +1055,7 @@ class ClauseElement(object):
return self.execute(*multiparams, **params).scalar()
- def compile(self, bind=None, engine=None, parameters=None, compiler=None, dialect=None):
+ def compile(self, bind=None, parameters=None, compiler=None, dialect=None):
"""Compile this SQL expression.
Uses the given ``Compiler``, or the given ``AbstractDialect``
@@ -1236,7 +1078,7 @@ class ClauseElement(object):
``SET`` and ``VALUES`` clause of those statements.
"""
- if (isinstance(parameters, list) or isinstance(parameters, tuple)):
+ if isinstance(parameters, (list, tuple)):
parameters = parameters[0]
if compiler is None:
@@ -1244,8 +1086,6 @@ class ClauseElement(object):
compiler = dialect.compiler(self, parameters)
elif bind is not None:
compiler = bind.compiler(self, parameters)
- elif engine is not None:
- compiler = engine.compiler(self, parameters)
elif self.bind is not None:
compiler = self.bind.compiler(self, parameters)
@@ -1268,49 +1108,257 @@ class ClauseElement(object):
return self._negate()
def _negate(self):
- return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None)
+ if hasattr(self, 'negation_clause'):
+ return self.negation_clause
+ else:
+ return _UnaryExpression(self.self_group(against=operator.inv), operator=operator.inv, negate=None)
+
-class _CompareMixin(object):
- """Defines comparison operations for ``ClauseElement`` instances.
+class Operators(object):
+ def from_():
+ raise NotImplementedError()
+ from_ = staticmethod(from_)
- This is a mixin class that adds the capability to produce ``ClauseElement``
- instances based on regular Python operators.
- These operations are achieved using Python's operator overload methods
- (i.e. ``__eq__()``, ``__ne__()``, etc.
+ def as_():
+ raise NotImplementedError()
+ as_ = staticmethod(as_)
- Overridden operators include all comparison operators (i.e. '==', '!=', '<'),
- math operators ('+', '-', '*', etc), the '&' and '|' operators which evaluate
- to ``AND`` and ``OR`` respectively.
+ def exists():
+ raise NotImplementedError()
+ exists = staticmethod(exists)
- Other methods exist to create additional SQL clauses such as ``IN``, ``LIKE``,
- ``DISTINCT``, etc.
+ def is_():
+ raise NotImplementedError()
+ is_ = staticmethod(is_)
- """
+ def isnot():
+ raise NotImplementedError()
+ isnot = staticmethod(isnot)
+
+ def __and__(self, other):
+ return self.operate(operator.and_, other)
+
+ def __or__(self, other):
+ return self.operate(operator.or_, other)
+
+ def __invert__(self):
+ return self.operate(operator.inv)
+
+ def clause_element(self):
+ raise NotImplementedError()
+
+ def operate(self, op, *other, **kwargs):
+ raise NotImplementedError()
+
+ def reverse_operate(self, op, *other, **kwargs):
+ raise NotImplementedError()
+
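Every Python operator overload now funnels through ``operate()``/``reverse_operate()``; a toy subclass, independent of the rest of the module, showing the dispatch:

    import operator

    class Recorder(Operators):
        def __init__(self):
            self.calls = []
        def operate(self, op, *other, **kwargs):
            self.calls.append((op, other))
            return self

    r = Recorder()
    r & 'x'   # __and__    -> operate(operator.and_, 'x')
    r | 'y'   # __or__     -> operate(operator.or_, 'y')
    ~r        # __invert__ -> operate(operator.inv)
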
+class ColumnOperators(Operators):
+ """defines comparison and math operations"""
+
+ def like_op(a, b):
+ return a.like(b)
+ like_op = staticmethod(like_op)
+
+ def notlike_op(a, b):
+ raise NotImplementedError()
+ notlike_op = staticmethod(notlike_op)
+ def ilike_op(a, b):
+ return a.ilike(b)
+ ilike_op = staticmethod(ilike_op)
+
+ def notilike_op(a, b):
+ raise NotImplementedError()
+ notilike_op = staticmethod(notilike_op)
+
+ def between_op(a, b):
+ return a.between(*b)
+ between_op = staticmethod(between_op)
+
+ def in_op(a, b):
+ return a.in_(*b)
+ in_op = staticmethod(in_op)
+
+ def notin_op(a, b):
+ raise NotImplementedError()
+ notin_op = staticmethod(notin_op)
+
+ def startswith_op(a, b):
+ return a.startswith(b)
+ startswith_op = staticmethod(startswith_op)
+
+ def endswith_op(a, b):
+ return a.endswith(b)
+ endswith_op = staticmethod(endswith_op)
+
+ def comma_op(a, b):
+ raise NotImplementedError()
+ comma_op = staticmethod(comma_op)
+
+ def concat_op(a, b):
+ return a.concat(b)
+ concat_op = staticmethod(concat_op)
+
def __lt__(self, other):
- return self._compare('<', other)
+ return self.operate(operator.lt, other)
def __le__(self, other):
- return self._compare('<=', other)
+ return self.operate(operator.le, other)
def __eq__(self, other):
- return self._compare('=', other)
+ return self.operate(operator.eq, other)
def __ne__(self, other):
- return self._compare('!=', other)
+ return self.operate(operator.ne, other)
def __gt__(self, other):
- return self._compare('>', other)
+ return self.operate(operator.gt, other)
def __ge__(self, other):
- return self._compare('>=', other)
+ return self.operate(operator.ge, other)
+ def concat(self, other):
+ return self.operate(ColumnOperators.concat_op, other)
+
def like(self, other):
- """produce a ``LIKE`` clause."""
- return self._compare('LIKE', other)
+ return self.operate(ColumnOperators.like_op, other)
+
+ def in_(self, *other):
+ return self.operate(ColumnOperators.in_op, other)
+
+ def startswith(self, other):
+ return self.operate(ColumnOperators.startswith_op, other)
+
+ def endswith(self, other):
+ return self.operate(ColumnOperators.endswith_op, other)
+
+ def __radd__(self, other):
+ return self.reverse_operate(operator.add, other)
+
+ def __rsub__(self, other):
+ return self.reverse_operate(operator.sub, other)
+
+ def __rmul__(self, other):
+ return self.reverse_operate(operator.mul, other)
+
+ def __rdiv__(self, other):
+ return self.reverse_operate(operator.div, other)
+
+ def between(self, cleft, cright):
+ return self.operate(ColumnOperators.between_op, (cleft, cright))
+
+ def __add__(self, other):
+ return self.operate(operator.add, other)
+
+ def __sub__(self, other):
+ return self.operate(operator.sub, other)
+
+ def __mul__(self, other):
+ return self.operate(operator.mul, other)
+
+ def __div__(self, other):
+ return self.operate(operator.div, other)
+
+ def __mod__(self, other):
+ return self.operate(operator.mod, other)
+
+ def __truediv__(self, other):
+ return self.operate(operator.truediv, other)
+
+# precedence ordering for common operators. if an operator is not present in this list,
+# it will be parenthesized when grouped against other operators
+_smallest = object()
+_largest = object()
+
+PRECEDENCE = {
+ Operators.from_:15,
+ operator.mul:7,
+ operator.div:7,
+ operator.mod:7,
+ operator.add:6,
+ operator.sub:6,
+ ColumnOperators.concat_op:6,
+ ColumnOperators.ilike_op:5,
+ ColumnOperators.notilike_op:5,
+ ColumnOperators.like_op:5,
+ ColumnOperators.notlike_op:5,
+ ColumnOperators.in_op:5,
+ ColumnOperators.notin_op:5,
+ Operators.is_:5,
+ Operators.isnot:5,
+ operator.eq:5,
+ operator.ne:5,
+ operator.gt:5,
+ operator.lt:5,
+ operator.ge:5,
+ operator.le:5,
+ ColumnOperators.between_op:5,
+ operator.inv:4,
+ operator.and_:3,
+ operator.or_:2,
+ ColumnOperators.comma_op:-1,
+ Operators.as_:-1,
+ Operators.exists:0,
+ _smallest: -1000,
+ _largest: 1000
+}
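
A sketch of how this table drives parenthesization: ``self_group()`` wraps a clause when its operator's precedence is less than or equal to that of the enclosing operator, with unknown operators falling back to _smallest/_largest so they are always grouped:

    import operator

    def needs_grouping(op, against):
        return (op != against and
                PRECEDENCE.get(op, PRECEDENCE[_smallest])
                <= PRECEDENCE.get(against, PRECEDENCE[_largest]))

    needs_grouping(operator.or_, operator.and_)   # True:  (a OR b) AND c
    needs_grouping(operator.and_, operator.or_)   # False:  a AND b OR c
    needs_grouping(operator.add, operator.mul)    # True:  (a + b) * c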
+
+class _CompareMixin(ColumnOperators):
+ """Defines comparison and math operations for ``ClauseElement`` instances."""
+
+ def __compare(self, op, obj, negate=None):
+ if obj is None or isinstance(obj, _Null):
+ if op == operator.eq:
+ return _BinaryExpression(self.expression_element(), null(), Operators.is_, negate=Operators.isnot)
+ elif op == operator.ne:
+ return _BinaryExpression(self.expression_element(), null(), Operators.isnot, negate=Operators.is_)
+ else:
+ raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
+ else:
+ obj = self._check_literal(obj)
+
+ return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate)
+
+ def __operate(self, op, obj):
+ obj = self._check_literal(obj)
+
+ type_ = self._compare_type(obj)
+
+ # TODO: generalize operator overloading like this out into the types module
+ if op == operator.add and isinstance(type_, (sqltypes.Concatenable)):
+ op = ColumnOperators.concat_op
+
+ return _BinaryExpression(self.expression_element(), obj, op, type_=type_)
+
+ operators = {
+ operator.add : (__operate,),
+ operator.mul : (__operate,),
+ operator.sub : (__operate,),
+ operator.div : (__operate,),
+ operator.mod : (__operate,),
+ operator.truediv : (__operate,),
+ operator.lt : (__compare, operator.ge),
+ operator.le : (__compare, operator.gt),
+ operator.ne : (__compare, operator.eq),
+ operator.gt : (__compare, operator.le),
+ operator.ge : (__compare, operator.lt),
+ operator.eq : (__compare, operator.ne),
+ ColumnOperators.like_op : (__compare, ColumnOperators.notlike_op),
+ }
+
+ def operate(self, op, other):
+ o = _CompareMixin.operators[op]
+ return o[0](self, op, other, *o[1:])
+
+ def reverse_operate(self, op, other):
+ return self._bind_param(other).operate(op, self)
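
The ``operators`` dict pairs each operator with its handler and, for comparisons, the operator that negates it; a sketch of one expansion:

    # users.c.id < 5
    #   -> operate(operator.lt, 5)
    #   -> __compare(operator.lt, 5, negate=operator.ge)
    #   -> _BinaryExpression(users.c.id, bind, operator.lt,
    #                        type_=Boolean, negate=operator.ge)
    expr = users.c.id < 5
    # because negate is recorded, ~expr can compile as "id >= 5"
    # rather than "NOT (id < 5)"
    negated = ~expr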
def in_(self, *other):
- """produce an ``IN`` clause."""
+ return self._in_impl(ColumnOperators.in_op, ColumnOperators.notin_op, *other)
+
+ def _in_impl(self, op, negate_op, *other):
if len(other) == 0:
return _Grouping(case([(self.__eq__(None), text('NULL'))], else_=text('0')).__eq__(text('1')))
elif len(other) == 1:
@@ -1318,8 +1366,8 @@ class _CompareMixin(object):
if _is_literal(o) or isinstance( o, _CompareMixin):
return self.__eq__( o) #single item -> ==
else:
- assert hasattr( o, '_selectable') #better check?
- return self._compare( 'IN', o, negate='NOT IN') #single selectable
+ assert isinstance(o, Selectable)
+ return self.__compare( op, o, negate=negate_op) #single selectable
args = []
for o in other:
@@ -1329,29 +1377,22 @@ class _CompareMixin(object):
else:
o = self._bind_param(o)
args.append(o)
- return self._compare( 'IN', ClauseList(*args).self_group(against='IN'), negate='NOT IN')
+ return self.__compare(op, ClauseList(*args).self_group(against=op), negate=negate_op)
def startswith(self, other):
"""produce the clause ``LIKE '<other>%'``"""
- perc = isinstance(other,(str,unicode)) and '%' or literal('%',type= sqltypes.String)
- return self._compare('LIKE', other + perc)
+
+ perc = isinstance(other, (str, unicode)) and '%' or literal('%', type_=sqltypes.String)
+ return self.__compare(ColumnOperators.like_op, other + perc)
def endswith(self, other):
"""produce the clause ``LIKE '%<other>'``"""
+
if isinstance(other,(str,unicode)): po = '%' + other
else:
- po = literal('%', type= sqltypes.String) + other
- po.type = sqltypes.to_instance( sqltypes.String) #force!
- return self._compare('LIKE', po)
-
- def __radd__(self, other):
- return self._bind_param(other)._operate('+', self)
- def __rsub__(self, other):
- return self._bind_param(other)._operate('-', self)
- def __rmul__(self, other):
- return self._bind_param(other)._operate('*', self)
- def __rdiv__(self, other):
- return self._bind_param(other)._operate('/', self)
+ po = literal('%', type_=sqltypes.String) + other
+ po.type = sqltypes.to_instance(sqltypes.String) #force!
+ return self.__compare(ColumnOperators.like_op, po)
def label(self, name):
"""produce a column label, i.e. ``<columnname> AS <name>``"""
@@ -1363,7 +1404,8 @@ class _CompareMixin(object):
def between(self, cleft, cright):
"""produce a BETWEEN clause, i.e. ``<column> BETWEEN <cleft> AND <cright>``"""
- return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator='AND', group=False), 'BETWEEN')
+
+ return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator=operator.and_, group=False), ColumnOperators.between_op)
def op(self, operator):
"""produce a generic operator function.
@@ -1382,59 +1424,25 @@ class _CompareMixin(object):
passed to the generated function.
"""
- return lambda other: self._operate(operator, other)
-
- # and here come the math operators:
-
- def __add__(self, other):
- return self._operate('+', other)
-
- def __sub__(self, other):
- return self._operate('-', other)
-
- def __mul__(self, other):
- return self._operate('*', other)
-
- def __div__(self, other):
- return self._operate('/', other)
-
- def __mod__(self, other):
- return self._operate('%', other)
-
- def __truediv__(self, other):
- return self._operate('/', other)
+ return lambda other: self.__operate(operator, other)
def _bind_param(self, obj):
- return _BindParamClause('literal', obj, shortname=None, type=self.type, unique=True)
+ return _BindParamClause('literal', obj, shortname=None, type_=self.type, unique=True)
def _check_literal(self, other):
- if _is_literal(other):
+ if isinstance(other, Operators):
+ return other.expression_element()
+ elif _is_literal(other):
return self._bind_param(other)
else:
return other
+
+ def clause_element(self):
+ """Allow ``_CompareMixins`` to return the underlying ``ClauseElement``, for non-``ClauseElement`` ``_CompareMixins``."""
+ return self
- def _compare(self, operator, obj, negate=None):
- if obj is None or isinstance(obj, _Null):
- if operator == '=':
- return _BinaryExpression(self._compare_self(), null(), 'IS', negate='IS NOT')
- elif operator == '!=':
- return _BinaryExpression(self._compare_self(), null(), 'IS NOT', negate='IS')
- else:
- raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
- else:
- obj = self._check_literal(obj)
-
- return _BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj), negate=negate)
-
- def _operate(self, operator, obj):
- if _is_literal(obj):
- obj = self._bind_param(obj)
- return _BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj))
-
- def _compare_self(self):
- """Allow ``ColumnImpl`` to return its ``Column`` object for
- usage in ``ClauseElements``, all others to just return self.
- """
+ def expression_element(self):
+ """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions."""
return self
@@ -1460,23 +1468,10 @@ class Selectable(ClauseElement):
columns = util.NotImplProperty("""a [sqlalchemy.sql#ColumnCollection] containing ``ColumnElement`` instances.""")
- def _selectable(self):
- return self
-
- def accept_visitor(self, visitor):
- raise NotImplementedError(repr(self))
-
def select(self, whereclauses = None, **params):
return select([self], whereclauses, **params)
- def _group_parenthesized(self):
- """Indicate if this ``Selectable`` requires parenthesis when
- grouped into a compound statement.
- """
- return True
-
-
class ColumnElement(Selectable, _CompareMixin):
"""Represent an element that is useable within the
"column clause" portion of a ``SELECT`` statement.
@@ -1616,8 +1611,10 @@ class ColumnCollection(util.OrderedProperties):
l.append(c==local)
return and_(*l)
- def __contains__(self, col):
- return self.contains_column(col)
+ def __contains__(self, other):
+ if not isinstance(other, basestring):
+ raise exceptions.ArgumentError("__contains__ requires a string argument")
+ return self.has_key(other)
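
A sketch of the two membership tests after this change; `users` is the hypothetical table from earlier:

    'name' in users.c                       # string key lookup, via has_key()
    users.c.contains_column(users.c.name)   # identity-based column membership
    # passing a Column to `in` now raises ArgumentError instead of
    # silently running contains_column()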
def contains_column(self, col):
# have to use a Set here, because it will compare the identity
@@ -1649,19 +1646,18 @@ class FromClause(Selectable):
clause of a ``SELECT`` statement.
"""
+ __visit_name__ = 'fromclause'
+
def __init__(self, name=None):
self.name = name
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
# this could also be [self], at the moment it doesn't matter to the Select object
return []
def default_order_by(self):
return [self.oid_column]
- def accept_visitor(self, visitor):
- visitor.visit_fromclause(self)
-
def count(self, whereclause=None, **params):
if len(self.primary_key):
col = list(self.primary_key)[0]
@@ -1703,6 +1699,13 @@ class FromClause(Selectable):
FindCols().traverse(self)
return ret
+ def is_derived_from(self, fromclause):
+ """return True if this FromClause is 'derived' from the given FromClause.
+
+ An example would be an Alias of a Table is derived from that Table."""
+
+ return False
+
def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False):
"""Given a ``ColumnElement``, return the exported
``ColumnElement`` object from this ``Selectable`` which
@@ -1730,9 +1733,10 @@ class FromClause(Selectable):
it merely shares a common ancestor with one of
the exported columns of this ``FromClause``.
"""
- if column in self.c:
+
+ if self.c.contains_column(column):
return column
-
+
if require_embedded and column not in util.Set(self._get_all_embedded_columns()):
if not raiseerr:
return None
@@ -1761,6 +1765,15 @@ class FromClause(Selectable):
self._export_columns()
return getattr(self, name)
+ def _clone_from_clause(self):
+ # delete all the "generated" collections of columns for a newly cloned FromClause,
+ # so that they will be re-derived from the item.
+ # this is because FromClause subclasses, when cloned, need to reestablish new "proxied"
+ # columns that are linked to the new item
+ for attr in ('_columns', '_primary_key', '_foreign_keys', '_orig_cols', '_oid_column'):
+ if hasattr(self, attr):
+ delattr(self, attr)
+
columns = property(lambda s:s._get_exported_attribute('_columns'))
c = property(lambda s:s._get_exported_attribute('_columns'))
primary_key = property(lambda s:s._get_exported_attribute('_primary_key'))
@@ -1791,8 +1804,9 @@ class FromClause(Selectable):
self._primary_key = ColumnSet()
self._foreign_keys = util.Set()
self._orig_cols = {}
+
if columns is None:
- columns = self._adjusted_exportable_columns()
+ columns = self._flatten_exportable_columns()
for co in columns:
cp = self._proxy_column(co)
for ci in cp.orig_set:
@@ -1806,13 +1820,14 @@ class FromClause(Selectable):
for ci in self.oid_column.orig_set:
self._orig_cols[ci] = self.oid_column
- def _adjusted_exportable_columns(self):
+ def _flatten_exportable_columns(self):
"""return the list of ColumnElements represented within this FromClause's _exportable_columns"""
export = self._exportable_columns()
for column in export:
- try:
- s = column._selectable()
- except AttributeError:
+ # TODO: is this conditional needed ?
+ if isinstance(column, Selectable):
+ s = column
+ else:
continue
for co in s.columns:
yield co
@@ -1829,7 +1844,9 @@ class _BindParamClause(ClauseElement, _CompareMixin):
Public constructor is the ``bindparam()`` function.
"""
- def __init__(self, key, value, shortname=None, type=None, unique=False):
+ __visit_name__ = 'bindparam'
+
+ def __init__(self, key, value, shortname=None, type_=None, unique=False, isoutparam=False):
"""Construct a _BindParamClause.
key
@@ -1852,7 +1869,7 @@ class _BindParamClause(ClauseElement, _CompareMixin):
execution may match either the key or the shortname of the
corresponding ``_BindParamClause`` objects.
- type
+ type\_
A ``TypeEngine`` object that will be used to pre-process the
value corresponding to this ``_BindParamClause`` at
execution time.
@@ -1862,23 +1879,34 @@ class _BindParamClause(ClauseElement, _CompareMixin):
modified if another ``_BindParamClause`` of the same
name already has been located within the containing
``ClauseElement``.
+
+ isoutparam
+ if True, the parameter should be treated like a stored procedure "OUT"
+ parameter.
"""
- self.key = key
+ self.key = key or "{ANON %d param}" % id(self)
self.value = value
self.shortname = shortname or key
self.unique = unique
- self.type = sqltypes.to_instance(type)
-
- def accept_visitor(self, visitor):
- visitor.visit_bindparam(self)
-
- def _get_from_objects(self):
+ self.isoutparam = isoutparam
+ type_ = sqltypes.to_instance(type_)
+ if isinstance(type_, sqltypes.NullType) and type(value) in _BindParamClause.type_map:
+ self.type = sqltypes.to_instance(_BindParamClause.type_map[type(value)])
+ else:
+ self.type = type_
+
+ # TODO: move to types module, obviously
+ type_map = {
+ str : sqltypes.String,
+ unicode : sqltypes.Unicode,
+ int : sqltypes.Integer,
+ float : sqltypes.Numeric
+ }
+
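With ``type_map`` in place, an untyped bind parameter infers its TypeEngine from the Python type of its value; a sketch using the public constructor:

    from sqlalchemy import bindparam, String

    bindparam('x', 5).type                # Integer(), inferred from int
    bindparam('x', u'ed').type            # Unicode()
    bindparam('x', 5, type_=String).type  # String(); an explicit type_ wins
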
+ def _get_from_objects(self, **modifiers):
return []
- def copy_container(self):
- return _BindParamClause(self.key, self.value, self.shortname, self.type, unique=self.unique)
-
def typeprocess(self, value, dialect):
return self.type.dialect_impl(dialect).convert_bind_param(value, dialect)
@@ -1893,7 +1921,7 @@ class _BindParamClause(ClauseElement, _CompareMixin):
return isinstance(other, _BindParamClause) and other.type.__class__ == self.type.__class__
def __repr__(self):
- return "_BindParamClause(%s, %s, type=%s)" % (repr(self.key), repr(self.value), repr(self.type))
+ return "_BindParamClause(%s, %s, type_=%s)" % (repr(self.key), repr(self.value), repr(self.type))
class _TypeClause(ClauseElement):
"""Handle a type keyword in a SQL statement.
@@ -1901,13 +1929,12 @@ class _TypeClause(ClauseElement):
Used by the ``Case`` statement.
"""
+ __visit_name__ = 'typeclause'
+
def __init__(self, type):
self.type = type
- def accept_visitor(self, visitor):
- visitor.visit_typeclause(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
return []
class _TextClause(ClauseElement):
@@ -1916,8 +1943,10 @@ class _TextClause(ClauseElement):
Public constructor is the ``text()`` function.
"""
- def __init__(self, text = "", bind=None, engine=None, bindparams=None, typemap=None):
- self._bind = bind or engine
+ __visit_name__ = 'textclause'
+
+ def __init__(self, text = "", bind=None, bindparams=None, typemap=None):
+ self._bind = bind
self.bindparams = {}
self.typemap = typemap
if typemap is not None:
@@ -1930,7 +1959,7 @@ class _TextClause(ClauseElement):
# scan the string and search for bind parameter names, add them
# to the list of bindparams
- self.text = re.compile(r'(?<!:):([\w_]+)', re.S).sub(repl, text)
+ self.text = BIND_PARAMS.sub(repl, text)
if bindparams is not None:
for b in bindparams:
self.bindparams[b.key] = b
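
``BIND_PARAMS`` is a module-level compilation of the same ':name' pattern previously compiled inline; a usage sketch with a hypothetical query:

    from sqlalchemy import text, bindparam, Integer

    t = text("SELECT * FROM users WHERE id = :user_id",
             bindparams=[bindparam('user_id', type_=Integer)])
    # ':user_id' is picked up by BIND_PARAMS; the (?<!:) lookbehind means
    # '::' sequences (e.g. PostgreSQL casts) are left alone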
@@ -1944,13 +1973,13 @@ class _TextClause(ClauseElement):
columns = property(lambda s:[])
+ def _copy_internals(self):
+ self.bindparams = dict([(b.key, b._clone()) for b in self.bindparams.values()])
+
def get_children(self, **kwargs):
return self.bindparams.values()
- def accept_visitor(self, visitor):
- visitor.visit_textclause(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
return []
def supports_execution(self):
@@ -1965,10 +1994,7 @@ class _Null(ColumnElement):
def __init__(self):
self.type = sqltypes.NULLTYPE
- def accept_visitor(self, visitor):
- visitor.visit_null(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
return []
class ClauseList(ClauseElement):
@@ -1976,14 +2002,16 @@ class ClauseList(ClauseElement):
By default, is comma-separated, such as a column listing.
"""
-
+ __visit_name__ = 'clauselist'
+
def __init__(self, *clauses, **kwargs):
self.clauses = []
- self.operator = kwargs.pop('operator', ',')
+ self.operator = kwargs.pop('operator', ColumnOperators.comma_op)
self.group = kwargs.pop('group', True)
self.group_contents = kwargs.pop('group_contents', True)
for c in clauses:
- if c is None: continue
+ if c is None:
+ continue
self.append(c)
def __iter__(self):
@@ -1991,32 +2019,28 @@ class ClauseList(ClauseElement):
def __len__(self):
return len(self.clauses)
- def copy_container(self):
- clauses = [clause.copy_container() for clause in self.clauses]
- return ClauseList(operator=self.operator, *clauses)
-
def append(self, clause):
# TODO: not sure if i like the 'group_contents' flag. need to define the difference between
# a ClauseList of ClauseLists, and a "flattened" ClauseList of ClauseLists. flatten() method ?
if self.group_contents:
- self.clauses.append(_literals_as_text(clause).self_group(against=self.operator))
+ self.clauses.append(_literal_as_text(clause).self_group(against=self.operator))
else:
- self.clauses.append(_literals_as_text(clause))
+ self.clauses.append(_literal_as_text(clause))
+
+ def _copy_internals(self):
+ self.clauses = [clause._clone() for clause in self.clauses]
def get_children(self, **kwargs):
return self.clauses
- def accept_visitor(self, visitor):
- visitor.visit_clauselist(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
f = []
for c in self.clauses:
- f += c._get_from_objects()
+ f += c._get_from_objects(**modifiers)
return f
def self_group(self, against=None):
- if self.group and self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']):
+ if self.group and self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest]):
return _Grouping(self)
else:
return self
@@ -2043,40 +2067,45 @@ class _CalculatedClause(ColumnElement):
Extends ``ColumnElement`` to provide column-level comparison
operators.
"""
-
+ __visit_name__ = 'calculatedclause'
+
def __init__(self, name, *clauses, **kwargs):
self.name = name
- self.type = sqltypes.to_instance(kwargs.get('type', None))
- self._bind = kwargs.get('bind', kwargs.get('engine', None))
+ self.type = sqltypes.to_instance(kwargs.get('type_', None))
+ self._bind = kwargs.get('bind', None)
self.group = kwargs.pop('group', True)
- self.clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses)
+ clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses)
if self.group:
- self.clause_expr = self.clauses.self_group()
+ self.clause_expr = clauses.self_group()
else:
- self.clause_expr = self.clauses
+ self.clause_expr = clauses
key = property(lambda self:self.name or "_calc_")
- def copy_container(self):
- clauses = [clause.copy_container() for clause in self.clauses]
- return _CalculatedClause(type=self.type, bind=self._bind, *clauses)
-
+ def _copy_internals(self):
+ self.clause_expr = self.clause_expr._clone()
+
+ def clauses(self):
+ if isinstance(self.clause_expr, _Grouping):
+ return self.clause_expr.elem
+ else:
+ return self.clause_expr
+ clauses = property(clauses)
+
def get_children(self, **kwargs):
return self.clause_expr,
- def accept_visitor(self, visitor):
- visitor.visit_calculatedclause(self)
- def _get_from_objects(self):
- return self.clauses._get_from_objects()
+ def _get_from_objects(self, **modifiers):
+ return self.clauses._get_from_objects(**modifiers)
def _bind_param(self, obj):
- return _BindParamClause(self.name, obj, type=self.type, unique=True)
+ return _BindParamClause(self.name, obj, type_=self.type, unique=True)
def select(self):
return select([self])
def scalar(self):
- return select([self]).scalar()
+ return select([self]).execute().scalar()
def execute(self):
return select([self]).execute()
@@ -2092,28 +2121,26 @@ class _Function(_CalculatedClause, FromClause):
"""
def __init__(self, name, *clauses, **kwargs):
- self.type = sqltypes.to_instance(kwargs.get('type', None))
self.packagenames = kwargs.get('packagenames', None) or []
- kwargs['operator'] = ','
- self._engine = kwargs.get('engine', None)
+ kwargs['operator'] = ColumnOperators.comma_op
_CalculatedClause.__init__(self, name, **kwargs)
for c in clauses:
self.append(c)
key = property(lambda self:self.name)
+ def _copy_internals(self):
+ _CalculatedClause._copy_internals(self)
+ self._clone_from_clause()
- def append(self, clause):
- self.clauses.append(_literals_as_binds(clause, self.name))
-
- def copy_container(self):
- clauses = [clause.copy_container() for clause in self.clauses]
- return _Function(self.name, type=self.type, packagenames=self.packagenames, bind=self._bind, *clauses)
+ def get_children(self, **kwargs):
+ return _CalculatedClause.get_children(self, **kwargs)
- def accept_visitor(self, visitor):
- visitor.visit_function(self)
+ def append(self, clause):
+ self.clauses.append(_literal_as_binds(clause, self.name))
class _Cast(ColumnElement):
+
def __init__(self, clause, totype, **kwargs):
if not hasattr(clause, 'label'):
clause = literal(clause)
@@ -2122,17 +2149,19 @@ class _Cast(ColumnElement):
self.typeclause = _TypeClause(self.type)
self._distance = 0
+ def _copy_internals(self):
+ self.clause = self.clause._clone()
+ self.typeclause = self.typeclause._clone()
+
def get_children(self, **kwargs):
return self.clause, self.typeclause
- def accept_visitor(self, visitor):
- visitor.visit_cast(self)
- def _get_from_objects(self):
- return self.clause._get_from_objects()
+ def _get_from_objects(self, **modifiers):
+ return self.clause._get_from_objects(**modifiers)
def _make_proxy(self, selectable, name=None):
if name is not None:
- co = _ColumnClause(name, selectable, type=self.type)
+ co = _ColumnClause(name, selectable, type_=self.type)
co._distance = self._distance + 1
co.orig_set = self.orig_set
selectable.columns[name]= co
@@ -2142,26 +2171,23 @@ class _Cast(ColumnElement):
class _UnaryExpression(ColumnElement):
- def __init__(self, element, operator=None, modifier=None, type=None, negate=None):
+ def __init__(self, element, operator=None, modifier=None, type_=None, negate=None):
self.operator = operator
self.modifier = modifier
- self.element = _literals_as_text(element).self_group(against=self.operator or self.modifier)
- self.type = sqltypes.to_instance(type)
+ self.element = _literal_as_text(element).self_group(against=self.operator or self.modifier)
+ self.type = sqltypes.to_instance(type_)
self.negate = negate
- def copy_container(self):
- return self.__class__(self.element.copy_container(), operator=self.operator, modifier=self.modifier, type=self.type, negate=self.negate)
+ def _get_from_objects(self, **modifiers):
+ return self.element._get_from_objects(**modifiers)
- def _get_from_objects(self):
- return self.element._get_from_objects()
+ def _copy_internals(self):
+ self.element = self.element._clone()
def get_children(self, **kwargs):
return self.element,
- def accept_visitor(self, visitor):
- visitor.visit_unary(self)
-
def compare(self, other):
"""Compare this ``_UnaryExpression`` against the given ``ClauseElement``."""
@@ -2170,14 +2196,15 @@ class _UnaryExpression(ColumnElement):
self.modifier == other.modifier and
self.element.compare(other.element)
)
+
def _negate(self):
if self.negate is not None:
- return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type=self.type)
+ return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type_=self.type)
else:
return super(_UnaryExpression, self)._negate()
def self_group(self, against):
- if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']):
+ if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest]):
return _Grouping(self)
else:
return self
@@ -2186,25 +2213,23 @@ class _UnaryExpression(ColumnElement):
class _BinaryExpression(ColumnElement):
"""Represent an expression that is ``LEFT <operator> RIGHT``."""
- def __init__(self, left, right, operator, type=None, negate=None):
- self.left = _literals_as_text(left).self_group(against=operator)
- self.right = _literals_as_text(right).self_group(against=operator)
+ def __init__(self, left, right, operator, type_=None, negate=None):
+ self.left = _literal_as_text(left).self_group(against=operator)
+ self.right = _literal_as_text(right).self_group(against=operator)
self.operator = operator
- self.type = sqltypes.to_instance(type)
+ self.type = sqltypes.to_instance(type_)
self.negate = negate
- def copy_container(self):
- return self.__class__(self.left.copy_container(), self.right.copy_container(), self.operator)
+ def _get_from_objects(self, **modifiers):
+ return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
- def _get_from_objects(self):
- return self.left._get_from_objects() + self.right._get_from_objects()
+ def _copy_internals(self):
+ self.left = self.left._clone()
+ self.right = self.right._clone()
def get_children(self, **kwargs):
return self.left, self.right
- def accept_visitor(self, visitor):
- visitor.visit_binary(self)
-
def compare(self, other):
"""Compare this ``_BinaryExpression`` against the given ``_BinaryExpression``."""
@@ -2213,7 +2238,7 @@ class _BinaryExpression(ColumnElement):
(
self.left.compare(other.left) and self.right.compare(other.right)
or (
- self.operator in ['=', '!=', '+', '*'] and
+ self.operator in [operator.eq, operator.ne, operator.add, operator.mul] and
self.left.compare(other.right) and self.right.compare(other.left)
)
)
@@ -2221,25 +2246,27 @@ class _BinaryExpression(ColumnElement):
def self_group(self, against=None):
# use small/large defaults for comparison so that unknown operators are always parenthesized
- if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest'])):
+ if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest])):
return _Grouping(self)
else:
return self
def _negate(self):
if self.negate is not None:
- return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type=self.type)
+ return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type_=self.type)
else:
return super(_BinaryExpression, self)._negate()
class _Exists(_UnaryExpression):
+ __visit_name__ = _UnaryExpression.__visit_name__
+
def __init__(self, *args, **kwargs):
kwargs['correlate'] = True
s = select(*args, **kwargs).self_group()
- _UnaryExpression.__init__(self, s, operator="EXISTS")
+ _UnaryExpression.__init__(self, s, operator=Operators.exists)
- def _hide_froms(self):
- return self._get_from_objects()
+ def _hide_froms(self, **modifiers):
+ return self._get_from_objects(**modifiers)
class Join(FromClause):
"""represent a ``JOIN`` construct between two ``FromClause``
@@ -2251,8 +2278,8 @@ class Join(FromClause):
"""
def __init__(self, left, right, onclause=None, isouter = False):
- self.left = left._selectable()
- self.right = right._selectable()
+ self.left = _selectable(left)
+ self.right = _selectable(right).self_group()
if onclause is None:
self.onclause = self._match_primaries(self.left, self.right)
else:
@@ -2265,8 +2292,8 @@ class Join(FromClause):
encodedname = property(lambda s: s.name.encode('ascii', 'backslashreplace'))
def _init_primary_key(self):
- pkcol = util.Set([c for c in self._adjusted_exportable_columns() if c.primary_key])
-
+ pkcol = util.Set([c for c in self._flatten_exportable_columns() if c.primary_key])
+
equivs = {}
def add_equiv(a, b):
for x, y in ((a, b), (b, a)):
@@ -2277,7 +2304,7 @@ class Join(FromClause):
class BinaryVisitor(ClauseVisitor):
def visit_binary(self, binary):
- if binary.operator == '=':
+ if binary.operator == operator.eq:
add_equiv(binary.left, binary.right)
BinaryVisitor().traverse(self.onclause)
@@ -2294,9 +2321,12 @@ class Join(FromClause):
omit.add(p)
p = c
- self.__primary_key = ColumnSet([c for c in self._adjusted_exportable_columns() if c.primary_key and c not in omit])
+ self.__primary_key = ColumnSet([c for c in self._flatten_exportable_columns() if c.primary_key and c not in omit])
primary_key = property(lambda s:s.__primary_key)
+
+ def self_group(self, against=None):
+ return _Grouping(self)
def _locate_oid_column(self):
return self.left.oid_column
@@ -2310,6 +2340,17 @@ class Join(FromClause):
self._foreign_keys.add(f)
return column
+ def _copy_internals(self):
+ self._clone_from_clause()
+ self.left = self.left._clone()
+ self.right = self.right._clone()
+ self.onclause = self.onclause._clone()
+ self.__folded_equivalents = None
+ self._init_primary_key()
+
+ def get_children(self, **kwargs):
+ return self.left, self.right, self.onclause
+
def _match_primaries(self, primary, secondary):
crit = []
constraints = util.Set()
@@ -2338,9 +2379,6 @@ class Join(FromClause):
else:
return and_(*crit)
- def _group_parenthesized(self):
- return True
-
def _get_folded_equivalents(self, equivs=None):
if self.__folded_equivalents is not None:
return self.__folded_equivalents
@@ -2348,7 +2386,7 @@ class Join(FromClause):
equivs = util.Set()
class LocateEquivs(NoColumnVisitor):
def visit_binary(self, binary):
- if binary.operator == '=' and binary.left.name == binary.right.name:
+ if binary.operator == operator.eq and binary.left.name == binary.right.name:
equivs.add(binary.right)
equivs.add(binary.left)
LocateEquivs().traverse(self.onclause)
@@ -2401,13 +2439,7 @@ class Join(FromClause):
return select(collist, whereclause, from_obj=[self], **kwargs)
- def get_children(self, **kwargs):
- return self.left, self.right, self.onclause
-
- def accept_visitor(self, visitor):
- visitor.visit_join(self)
-
- engine = property(lambda s:s.left.engine or s.right.engine)
+ bind = property(lambda s:s.left.bind or s.right.bind)
def alias(self, name=None):
"""Create a ``Select`` out of this ``Join`` clause and return an ``Alias`` of it.
@@ -2417,11 +2449,11 @@ class Join(FromClause):
return self.select(use_labels=True, correlate=False).alias(name)
- def _hide_froms(self):
- return self.left._get_from_objects() + self.right._get_from_objects()
+ def _hide_froms(self, **modifiers):
+ return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
- def _get_from_objects(self):
- return [self] + self.onclause._get_from_objects() + self.left._get_from_objects() + self.right._get_from_objects()
+ def _get_from_objects(self, **modifiers):
+ return [self] + self.onclause._get_from_objects(**modifiers) + self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
class Alias(FromClause):
"""represent an alias, as typically applied to any
@@ -2443,15 +2475,22 @@ class Alias(FromClause):
if alias is None:
if self.original.named_with_column():
alias = getattr(self.original, 'name', None)
- if alias is None:
- alias = 'anon'
- elif len(alias) > 15:
- alias = alias[0:15]
- alias = alias + "_" + hex(random.randint(0, 65535))[2:]
+ alias = '{ANON %d %s}' % (id(self), alias or 'anon')
self.name = alias
self.encodedname = alias.encode('ascii', 'backslashreplace')
self.case_sensitive = getattr(baseselectable, "case_sensitive", True)
+ def is_derived_from(self, fromclause):
+ x = self.selectable
+ while True:
+ if x is fromclause:
+ return True
+ if isinstance(x, Alias):
+ x = x.selectable
+ else:
+ break
+ return False
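
A sketch of ``is_derived_from()`` unwrapping nested aliases back to the base selectable:

    a1 = users.alias()
    a2 = a1.alias()
    a2.is_derived_from(a1)      # True: a2.selectable is a1
    a2.is_derived_from(users)   # True: unwraps a1, reaches users
    a1.is_derived_from(a2)      # False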
+
def supports_execution(self):
return self.original.supports_execution()
@@ -2468,46 +2507,57 @@ class Alias(FromClause):
#return self.selectable._exportable_columns()
return self.selectable.columns
+ def _copy_internals(self):
+ self._clone_from_clause()
+ self.selectable = self.selectable._clone()
+ baseselectable = self.selectable
+ while isinstance(baseselectable, Alias):
+ baseselectable = baseselectable.selectable
+ self.original = baseselectable
+
def get_children(self, **kwargs):
for c in self.c:
yield c
yield self.selectable
- def accept_visitor(self, visitor):
- visitor.visit_alias(self)
-
def _get_from_objects(self):
return [self]
- def _group_parenthesized(self):
- return False
-
bind = property(lambda s: s.selectable.bind)
- engine = bind
-class _Grouping(ColumnElement):
+class _ColumnElementAdapter(ColumnElement):
+ """adapts a ClauseElement which may or may not be a
+ ColumnElement subclass itself into an object which
+ acts like a ColumnElement.
+ """
+
def __init__(self, elem):
self.elem = elem
self.type = getattr(elem, 'type', None)
-
+ self.orig_set = getattr(elem, 'orig_set', util.Set())
+
key = property(lambda s: s.elem.key)
_label = property(lambda s: s.elem._label)
- orig_set = property(lambda s:s.elem.orig_set)
-
- def copy_container(self):
- return _Grouping(self.elem.copy_container())
-
- def accept_visitor(self, visitor):
- visitor.visit_grouping(self)
+ columns = c = property(lambda s:s.elem.columns)
+
+ def _copy_internals(self):
+ self.elem = self.elem._clone()
+
def get_children(self, **kwargs):
return self.elem,
- def _hide_froms(self):
- return self.elem._hide_froms()
- def _get_from_objects(self):
- return self.elem._get_from_objects()
+
+ def _hide_froms(self, **modifiers):
+ return self.elem._hide_froms(**modifiers)
+
+ def _get_from_objects(self, **modifiers):
+ return self.elem._get_from_objects(**modifiers)
+
def __getattr__(self, attr):
return getattr(self.elem, attr)
-
+
+class _Grouping(_ColumnElementAdapter):
+ pass
+
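+# _Grouping parenthesizes its element when rendered: e.g. in an expression
+# like (a + b) * c, the 'a + b' subexpression gets wrapped in a _Grouping
+# via self_group() so that operator precedence survives compilation.
+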
class _Label(ColumnElement):
"""represent a label, as typically applied to any column-level element
using the ``AS`` sql keyword.
@@ -2518,32 +2568,33 @@ class _Label(ColumnElement):
"""
- def __init__(self, name, obj, type=None):
- self.name = name
+ def __init__(self, name, obj, type_=None):
while isinstance(obj, _Label):
obj = obj.obj
- self.obj = obj.self_group(against='AS')
+ self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon'))
+
+ self.obj = obj.self_group(against=Operators.as_)
self.case_sensitive = getattr(obj, "case_sensitive", True)
- self.type = sqltypes.to_instance(type or getattr(obj, 'type', None))
+ self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None))
key = property(lambda s: s.name)
_label = property(lambda s: s.name)
orig_set = property(lambda s:s.obj.orig_set)
- def _compare_self(self):
+ def expression_element(self):
return self.obj
-
+
+ def _copy_internals(self):
+ self.obj = self.obj._clone()
+
def get_children(self, **kwargs):
return self.obj,
- def accept_visitor(self, visitor):
- visitor.visit_label(self)
+ def _get_from_objects(self, **modifiers):
+ return self.obj._get_from_objects(**modifiers)
- def _get_from_objects(self):
- return self.obj._get_from_objects()
-
- def _hide_froms(self):
- return self.obj._hide_froms()
+ def _hide_froms(self, **modifiers):
+ return self.obj._hide_froms(**modifiers)
def _make_proxy(self, selectable, name = None):
if isinstance(self.obj, Selectable):
@@ -2551,8 +2602,6 @@ class _Label(ColumnElement):
else:
return column(self.name)._make_proxy(selectable=selectable)
-legal_characters = util.Set(string.ascii_letters + string.digits + '_')
-
class _ColumnClause(ColumnElement):
"""Represents a generic column expression from any textual string.
This includes columns associated with tables, aliases and select
@@ -2584,17 +2633,21 @@ class _ColumnClause(ColumnElement):
"""
- def __init__(self, text, selectable=None, type=None, _is_oid=False, case_sensitive=True, is_literal=False):
+ def __init__(self, text, selectable=None, type_=None, _is_oid=False, case_sensitive=True, is_literal=False):
self.key = self.name = text
self.encodedname = isinstance(self.name, unicode) and self.name.encode('ascii', 'backslashreplace') or self.name
self.table = selectable
- self.type = sqltypes.to_instance(type)
+ self.type = sqltypes.to_instance(type_)
self._is_oid = _is_oid
self._distance = 0
self.__label = None
self.case_sensitive = case_sensitive
self.is_literal = is_literal
-
+
+ def _clone(self):
+ # ColumnClause is immutable
+ return self
+
def _get_label(self):
"""Generate a 'label' for this column.
@@ -2617,7 +2670,6 @@ class _ColumnClause(ColumnElement):
counter += 1
else:
self.__label = self.name
- self.__label = "".join([x for x in self.__label if x in legal_characters])
return self.__label
is_labeled = property(lambda self:self.name != list(self.orig_set)[0].name)
@@ -2632,23 +2684,20 @@ class _ColumnClause(ColumnElement):
else:
return super(_ColumnClause, self).label(name)
- def accept_visitor(self, visitor):
- visitor.visit_column(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
if self.table is not None:
return [self.table]
else:
return []
def _bind_param(self, obj):
- return _BindParamClause(self._label, obj, shortname = self.name, type=self.type, unique=True)
+ return _BindParamClause(self._label, obj, shortname=self.name, type_=self.type, unique=True)
def _make_proxy(self, selectable, name = None):
        # propagate the "is_literal" flag only if we are keeping our name,
        # otherwise it's considered to be a label
is_literal = self.is_literal and (name is None or name == self.name)
- c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type=self.type, is_literal=is_literal)
+ c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal)
c.orig_set = self.orig_set
c._distance = self._distance + 1
if not self._is_oid:
@@ -2658,9 +2707,6 @@ class _ColumnClause(ColumnElement):
def _compare_type(self, obj):
return self.type
- def _group_parenthesized(self):
- return False
-
class TableClause(FromClause):
"""represents a "table" construct.
@@ -2677,6 +2723,10 @@ class TableClause(FromClause):
self._oid_column = _ColumnClause('oid', self, _is_oid=True)
self._export_columns(columns)
+ def _clone(self):
+ # TableClause is immutable
+ return self
+
def named_with_column(self):
return True
@@ -2709,15 +2759,9 @@ class TableClause(FromClause):
else:
return []
- def accept_visitor(self, visitor):
- visitor.visit_table(self)
-
def _exportable_columns(self):
raise NotImplementedError()
- def _group_parenthesized(self):
- return False
-
def count(self, whereclause=None, **params):
if len(self.primary_key):
col = list(self.primary_key)[0]
@@ -2746,68 +2790,120 @@ class TableClause(FromClause):
def delete(self, whereclause = None):
return delete(self, whereclause)
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
return [self]
+
class _SelectBaseMixin(object):
"""Base class for ``Select`` and ``CompoundSelects``."""
+ def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None):
+ self.use_labels = use_labels
+ self.for_update = for_update
+ self._limit = limit
+ self._offset = offset
+ self._bind = bind
+
+ self.append_order_by(*util.to_list(order_by, []))
+ self.append_group_by(*util.to_list(group_by, []))
+
+ def as_scalar(self):
+ return _ScalarSelect(self)
+
+ def label(self, name):
+ return self.as_scalar().label(name)
+
def supports_execution(self):
return True
+ def _generate(self):
+ s = self._clone()
+ s._clone_from_clause()
+ return s
+
+ def limit(self, limit):
+ s = self._generate()
+ s._limit = limit
+ return s
+
+ def offset(self, offset):
+ s = self._generate()
+ s._offset = offset
+ return s
+
def order_by(self, *clauses):
- if len(clauses) == 1 and clauses[0] is None:
- self.order_by_clause = ClauseList()
- elif getattr(self, 'order_by_clause', None):
- self.order_by_clause = ClauseList(*(list(self.order_by_clause.clauses) + list(clauses)))
- else:
- self.order_by_clause = ClauseList(*clauses)
+ s = self._generate()
+ s.append_order_by(*clauses)
+ return s
def group_by(self, *clauses):
- if len(clauses) == 1 and clauses[0] is None:
- self.group_by_clause = ClauseList()
- elif getattr(self, 'group_by_clause', None):
- self.group_by_clause = ClauseList(*(list(clauses)+list(self.group_by_clause.clauses)))
+ s = self._generate()
+ s.append_group_by(*clauses)
+ return s
+
+ def append_order_by(self, *clauses):
+        if len(clauses) == 1 and clauses[0] is None:
+ self._order_by_clause = ClauseList()
else:
- self.group_by_clause = ClauseList(*clauses)
+ if getattr(self, '_order_by_clause', None):
+ clauses = list(self._order_by_clause) + list(clauses)
+ self._order_by_clause = ClauseList(*clauses)
+ def append_group_by(self, *clauses):
+        if len(clauses) == 1 and clauses[0] is None:
+ self._group_by_clause = ClauseList()
+ else:
+ if getattr(self, '_group_by_clause', None):
+ clauses = list(self._group_by_clause) + list(clauses)
+ self._group_by_clause = ClauseList(*clauses)
+
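+    # the methods above (limit(), offset(), order_by(), group_by()) are
+    # "generative": each copies the statement via _generate() and returns
+    # the modified copy.  a rough usage sketch, assuming a 'users' table:
+    #
+    #     s = users.select()
+    #     s2 = s.order_by(users.c.name).limit(10)
+    #     # s is unchanged; s2 carries the ORDER BY and LIMIT
+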
def select(self, whereclauses = None, **params):
return select([self], whereclauses, **params)
- def _get_from_objects(self):
- if self.is_where or self.is_scalar:
+ def _get_from_objects(self, is_where=False, **modifiers):
+ if is_where:
return []
else:
return [self]
+class _ScalarSelect(_Grouping):
+ __visit_name__ = 'grouping'
+
+ def __init__(self, elem):
+ super(_ScalarSelect, self).__init__(elem)
+ self.type = list(elem.inner_columns)[0].type
+
+ columns = property(lambda self:[self])
+
+ def self_group(self, **kwargs):
+ return self
+
+ def _make_proxy(self, selectable, name):
+ return list(self.inner_columns)[0]._make_proxy(selectable, name)
+
+ def _get_from_objects(self, **modifiers):
+ return []
+
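+# a sketch of _ScalarSelect in use (illustrative; 'users' is an assumed
+# table):
+#
+#     total = select([func.count(users.c.id)]).as_scalar()
+#     stmt = select([users.c.name, total])
+#
+# the scalar select renders as a parenthesized column expression, takes its
+# type from its single inner column, and contributes nothing to the
+# enclosing FROM list, as _get_from_objects() returns [].
+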
class CompoundSelect(_SelectBaseMixin, FromClause):
def __init__(self, keyword, *selects, **kwargs):
- _SelectBaseMixin.__init__(self)
+ self._should_correlate = kwargs.pop('correlate', False)
self.keyword = keyword
- self.use_labels = kwargs.pop('use_labels', False)
- self.should_correlate = kwargs.pop('correlate', False)
- self.for_update = kwargs.pop('for_update', False)
- self.nowait = kwargs.pop('nowait', False)
- self.limit = kwargs.pop('limit', None)
- self.offset = kwargs.pop('offset', None)
- self.is_compound = True
- self.is_where = False
- self.is_scalar = False
- self.is_subquery = False
-
- # unions group from left to right, so don't group first select
- self.selects = [n and select.self_group(self) or select for n,select in enumerate(selects)]
+ self.selects = []
# some DBs do not like ORDER BY in the inner queries of a UNION, etc.
- for s in selects:
- s.order_by(None)
+ for n, s in enumerate(selects):
+ if len(s._order_by_clause):
+ s = s.order_by(None)
+ # unions group from left to right, so don't group first select
+ if n:
+ self.selects.append(s.self_group(self))
+ else:
+ self.selects.append(s)
- self.group_by(*kwargs.pop('group_by', [None]))
- self.order_by(*kwargs.pop('order_by', [None]))
- if len(kwargs):
- raise TypeError("invalid keyword argument(s) for CompoundSelect: %s" % repr(kwargs.keys()))
self._col_map = {}
+ _SelectBaseMixin.__init__(self, **kwargs)
+
name = property(lambda s:s.keyword + " statement")
def self_group(self, against=None):
@@ -2835,12 +2931,18 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
col.orig_set = colset
return col
+ def _copy_internals(self):
+ self._clone_from_clause()
+ self._col_map = {}
+ self.selects = [s._clone() for s in self.selects]
+ for attr in ('_order_by_clause', '_group_by_clause'):
+ if getattr(self, attr) is not None:
+ setattr(self, attr, getattr(self, attr)._clone())
+
def get_children(self, column_collections=True, **kwargs):
return (column_collections and list(self.c) or []) + \
- [self.order_by_clause, self.group_by_clause] + list(self.selects)
- def accept_visitor(self, visitor):
- visitor.visit_compound_select(self)
-
+ [self._order_by_clause, self._group_by_clause] + list(self.selects)
+
def _find_engine(self):
for s in self.selects:
e = s._find_engine()
@@ -2855,160 +2957,287 @@ class Select(_SelectBaseMixin, FromClause):
"""
- def __init__(self, columns=None, whereclause=None, from_obj=[],
- order_by=None, group_by=None, having=None,
- use_labels=False, distinct=False, for_update=False,
- engine=None, bind=None, limit=None, offset=None, scalar=False,
- correlate=True):
+ def __init__(self, columns, whereclause=None, from_obj=None, distinct=False, having=None, correlate=True, prefixes=None, **kwargs):
"""construct a Select object.
The public constructor for Select is the [sqlalchemy.sql#select()] function;
see that function for argument descriptions.
"""
- _SelectBaseMixin.__init__(self)
- self.__froms = util.OrderedSet()
- self.__hide_froms = util.Set([self])
- self.use_labels = use_labels
- self.whereclause = None
- self.having = None
- self._bind = bind or engine
- self.limit = limit
- self.offset = offset
- self.for_update = for_update
- self.is_compound = False
-
- # indicates that this select statement should not expand its columns
- # into the column clause of an enclosing select, and should instead
- # act like a single scalar column
- self.is_scalar = scalar
- if scalar:
- # allow corresponding_column to return None
- self.orig_set = util.Set()
-
- # indicates if this select statement, as a subquery, should automatically correlate
- # its FROM clause to that of an enclosing select, update, or delete statement.
- # note that the "correlate" method can be used to explicitly add a value to be correlated.
- self.should_correlate = correlate
-
- # indicates if this select statement is a subquery inside another query
- self.is_subquery = False
-
- # indicates if this select statement is in the from clause of another query
- self.is_selected_from = False
-
- # indicates if this select statement is a subquery as a criterion
- # inside of a WHERE clause
- self.is_where = False
+
+ self._should_correlate = correlate
+ self._distinct = distinct
- self.distinct = distinct
self._raw_columns = []
- self.__correlated = {}
- self.__correlator = Select._CorrelatedVisitor(self, False)
- self.__wherecorrelator = Select._CorrelatedVisitor(self, True)
- self.__fromvisitor = Select._FromVisitor(self)
-
-
- self.order_by_clause = self.group_by_clause = None
+ self.__correlate = util.Set()
+ self._froms = util.OrderedSet()
+ self._whereclause = None
+ self._having = None
+ self._prefixes = []
if columns is not None:
for c in columns:
self.append_column(c)
- if order_by:
- order_by = util.to_list(order_by)
- if group_by:
- group_by = util.to_list(group_by)
- self.order_by(*(order_by or [None]))
- self.group_by(*(group_by or [None]))
- for c in self.order_by_clause:
- self.__correlator.traverse(c)
- for c in self.group_by_clause:
- self.__correlator.traverse(c)
-
- for f in from_obj:
- self.append_from(f)
-
- # whereclauses must be appended after the columns/FROM, since it affects
- # the correlation of subqueries. see test/sql/select.py SelectTest.testwheresubquery
+ if from_obj is not None:
+ for f in from_obj:
+ self.append_from(f)
+
if whereclause is not None:
self.append_whereclause(whereclause)
+
if having is not None:
self.append_having(having)
+ _SelectBaseMixin.__init__(self, **kwargs)
- class _CorrelatedVisitor(NoColumnVisitor):
- """Visit a clause, locate any ``Select`` clauses, and tell
- them that they should correlate their ``FROM`` list to that of
- their parent.
+ def _get_display_froms(self, correlation_state=None):
+ """return the full list of 'from' clauses to be displayed.
+
+        takes into account an optional 'correlation_state'
+        dictionary which contains information about this Select's
+        correlation to an enclosing select, and which may cause some
+        'from' clauses not to display in this Select's FROM clause.
+        this dictionary is generated at compile time by the
+        _calculate_correlations() method.
+
"""
+ froms = util.OrderedSet()
+ hide_froms = util.Set()
+
+ for col in self._raw_columns:
+ for f in col._hide_froms():
+ hide_froms.add(f)
+ for f in col._get_from_objects():
+ froms.add(f)
+
+ if self._whereclause is not None:
+ for f in self._whereclause._get_from_objects(is_where=True):
+ froms.add(f)
+
+ for elem in self._froms:
+ froms.add(elem)
+ for f in elem._get_from_objects():
+ froms.add(f)
+
+ for elem in froms:
+ for f in elem._hide_froms():
+ hide_froms.add(f)
+
+ froms = froms.difference(hide_froms)
+
+ if len(froms) > 1:
+ corr = self.__correlate
+ if correlation_state is not None:
+ corr = correlation_state[self].get('correlate', util.Set()).union(corr)
+ f = froms.difference(corr)
+ if len(f) == 0:
+ raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate()))
+ return f
+ else:
+ return froms
+
+ froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""")
+
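+    # correlation sketch (illustrative; 'users' and 'addresses' are
+    # assumed tables):
+    #
+    #     inner = select([addresses.c.id],
+    #                    users.c.id == addresses.c.user_id).as_scalar()
+    #     stmt = select([users.c.name, inner])
+    #
+    # since the inner select correlates to the enclosing statement, it
+    # omits 'users' from its rendered FROM list; _get_display_froms()
+    # computes exactly that omission.
+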
+ def locate_all_froms(self):
+ froms = util.Set()
+ for col in self._raw_columns:
+ for f in col._get_from_objects():
+ froms.add(f)
+
+ if self._whereclause is not None:
+ for f in self._whereclause._get_from_objects(is_where=True):
+ froms.add(f)
+
+ for elem in self._froms:
+ froms.add(elem)
+ for f in elem._get_from_objects():
+ froms.add(f)
+ return froms
+
+ def _calculate_correlations(self, correlation_state):
+ """generate a 'correlation_state' dictionary used by the _get_display_froms() method.
+
+ The dictionary is passed in initially empty, or already
+ containing the state information added by an enclosing
+ Select construct. The method will traverse through all
+ embedded Select statements and add information about their
+ position and "from" objects to the dictionary. Those Select
+ statements will later consult the 'correlation_state' dictionary
+        when their list of 'FROM' clauses is generated using their
+ _get_display_froms() method.
+ """
+
+ if self not in correlation_state:
+ correlation_state[self] = {}
- def __init__(self, select, is_where):
- NoColumnVisitor.__init__(self)
- self.select = select
- self.is_where = is_where
-
- def visit_compound_select(self, cs):
- self.visit_select(cs)
-
- def visit_column(self, c):
- pass
-
- def visit_table(self, c):
- pass
-
- def visit_select(self, select):
- if select is self.select:
- return
- select.is_where = self.is_where
- select.is_subquery = True
- if not select.should_correlate:
- return
- [select.correlate(x) for x in self.select._Select__froms]
+ display_froms = self._get_display_froms(correlation_state)
+
+ class CorrelatedVisitor(NoColumnVisitor):
+ def __init__(self, is_where=False, is_column=False, is_from=False):
+ self.is_where = is_where
+ self.is_column = is_column
+ self.is_from = is_from
+
+ def visit_compound_select(self, cs):
+ self.visit_select(cs)
- class _FromVisitor(NoColumnVisitor):
- def __init__(self, select):
- NoColumnVisitor.__init__(self)
- self.select = select
+ def visit_select(s, select):
+ if select not in correlation_state:
+ correlation_state[select] = {}
+
+ if select is self:
+ return
+
+ select_state = correlation_state[select]
+ if s.is_from:
+ select_state['is_selected_from'] = True
+ if s.is_where:
+ select_state['is_where'] = True
+ select_state['is_subquery'] = True
+
+ if select._should_correlate:
+ corr = select_state.setdefault('correlate', util.Set())
+ # not crazy about this part. need to be clearer on what elements in the
+ # subquery correspond to elements in the enclosing query.
+ for f in display_froms:
+ corr.add(f)
+ for f2 in f._get_from_objects():
+ corr.add(f2)
+
+ col_vis = CorrelatedVisitor(is_column=True)
+ where_vis = CorrelatedVisitor(is_where=True)
+ from_vis = CorrelatedVisitor(is_from=True)
+
+ for col in self._raw_columns:
+ col_vis.traverse(col)
+ for f in col._get_from_objects():
+ if f is not self:
+ from_vis.traverse(f)
+
+ for col in list(self._order_by_clause) + list(self._group_by_clause):
+ col_vis.traverse(col)
- def visit_select(self, select):
- if select is self.select:
- return
- select.is_selected_from = True
- select.is_subquery = True
+ if self._whereclause is not None:
+ where_vis.traverse(self._whereclause)
+ for f in self._whereclause._get_from_objects(is_where=True):
+ if f is not self:
+ from_vis.traverse(f)
+
+ for elem in self._froms:
+ from_vis.traverse(elem)
+
+ def _get_inner_columns(self):
+ for c in self._raw_columns:
+ if isinstance(c, Selectable):
+ for co in c.columns:
+ yield co
+ else:
+ yield c
+
+ inner_columns = property(_get_inner_columns)
+
+ def _copy_internals(self):
+ self._clone_from_clause()
+ self._raw_columns = [c._clone() for c in self._raw_columns]
+ self._recorrelate_froms([(f, f._clone()) for f in self._froms])
+ for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'):
+ if getattr(self, attr) is not None:
+ setattr(self, attr, getattr(self, attr)._clone())
+ def get_children(self, column_collections=True, **kwargs):
+ return (column_collections and list(self.columns) or []) + \
+ list(self._froms) + \
+ [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None]
+
+ def _recorrelate_froms(self, froms):
+ newcorrelate = util.Set()
+ newfroms = util.Set()
+ oldfroms = util.Set(self._froms)
+ for old, new in froms:
+ if old in self.__correlate:
+ newcorrelate.add(new)
+ self.__correlate.remove(old)
+ if old in oldfroms:
+ newfroms.add(new)
+ oldfroms.remove(old)
+ self.__correlate = self.__correlate.union(newcorrelate)
+ self._froms = [f for f in oldfroms.union(newfroms)]
+
+ def column(self, column):
+ s = self._generate()
+ s.append_column(column)
+ return s
+
+ def where(self, whereclause):
+ s = self._generate()
+ s.append_whereclause(whereclause)
+ return s
+
+ def having(self, having):
+ s = self._generate()
+ s.append_having(having)
+ return s
+
+ def distinct(self):
+ s = self._generate()
+ s.distinct = True
+ return s
+
+ def prefix_with(self, clause):
+ s = self._generate()
+ s.append_prefix(clause)
+ return s
+
+ def select_from(self, fromclause):
+ s = self._generate()
+ s.append_from(fromclause)
+ return s
+
+ def __dont_correlate(self):
+ s = self._generate()
+ s._should_correlate = False
+ return s
+
+ def correlate(self, fromclause):
+ s = self._generate()
+        s._should_correlate = False
+ if fromclause is None:
+ s.__correlate = util.Set()
+ else:
+ s.append_correlation(fromclause)
+ return s
+
+ def append_correlation(self, fromclause):
+ self.__correlate.add(fromclause)
+
def append_column(self, column):
- if _is_literal(column):
- column = literal_column(str(column))
+ column = _literal_as_column(column)
- if isinstance(column, Select) and column.is_scalar:
- column = column.self_group(against=',')
+ if isinstance(column, _ScalarSelect):
+ column = column.self_group(against=ColumnOperators.comma_op)
self._raw_columns.append(column)
-
- if self.is_scalar and not hasattr(self, 'type'):
- self.type = column.type
+
+ def append_prefix(self, clause):
+ clause = _literal_as_text(clause)
+ self._prefixes.append(clause)
- # if the column is a Select statement itself,
- # accept visitor
- self.__correlator.traverse(column)
-
- # visit the FROM objects of the column looking for more Selects
- for f in column._get_from_objects():
- if f is not self:
- self.__correlator.traverse(f)
- self._process_froms(column, False)
-
- def _make_proxy(self, selectable, name):
- if self.is_scalar:
- return self._raw_columns[0]._make_proxy(selectable, name)
+ def append_whereclause(self, whereclause):
+ if self._whereclause is not None:
+ self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
else:
- raise exceptions.InvalidRequestError("Not a scalar select statement")
-
- def label(self, name):
- if not self.is_scalar:
- raise exceptions.InvalidRequestError("Not a scalar select statement")
+ self._whereclause = _literal_as_text(whereclause)
+
+ def append_having(self, having):
+ if self._having is not None:
+ self._having = and_(self._having, _literal_as_text(having))
else:
- return label(name, self)
+ self._having = _literal_as_text(having)
+
+ def append_from(self, fromclause):
+ if _is_literal(fromclause):
+ fromclause = FromClause(fromclause)
+ self._froms.add(fromclause)
def _exportable_columns(self):
return [c for c in self._raw_columns if isinstance(c, Selectable)]
@@ -3019,53 +3248,13 @@ class Select(_SelectBaseMixin, FromClause):
else:
return column._make_proxy(self)
- def _process_froms(self, elem, asfrom):
- for f in elem._get_from_objects():
- self.__fromvisitor.traverse(f)
- self.__froms.add(f)
- if asfrom:
- self.__froms.add(elem)
- for f in elem._hide_froms():
- self.__hide_froms.add(f)
-
def self_group(self, against=None):
if isinstance(against, CompoundSelect):
return self
return _Grouping(self)
-
- def append_whereclause(self, whereclause):
- self._append_condition('whereclause', whereclause)
-
- def append_having(self, having):
- self._append_condition('having', having)
-
- def _append_condition(self, attribute, condition):
- if isinstance(condition, basestring):
- condition = _TextClause(condition)
- self.__wherecorrelator.traverse(condition)
- self._process_froms(condition, False)
- if getattr(self, attribute) is not None:
- setattr(self, attribute, and_(getattr(self, attribute), condition))
- else:
- setattr(self, attribute, condition)
-
- def correlate(self, from_obj):
- """Given a ``FROM`` object, correlate this ``SELECT`` statement to it.
-
- This basically means the given from object will not come out
- in this select statement's ``FROM`` clause when printed.
- """
-
- self.__correlated[from_obj] = from_obj
-
- def append_from(self, fromclause):
- if isinstance(fromclause, basestring):
- fromclause = FromClause(fromclause)
- self.__correlator.traverse(fromclause)
- self._process_froms(fromclause, True)
def _locate_oid_column(self):
- for f in self.__froms:
+ for f in self.locate_all_froms():
if f is self:
# we might be in our own _froms list if a column with us as the parent is attached,
# which includes textual columns.
@@ -3076,25 +3265,6 @@ class Select(_SelectBaseMixin, FromClause):
else:
return None
- def _calc_froms(self):
- f = self.__froms.difference(self.__hide_froms)
- if (len(f) > 1):
- return f.difference(self.__correlated)
- else:
- return f
-
- froms = property(_calc_froms,
- doc="""A collection containing all elements
- of the ``FROM`` clause.""")
-
- def get_children(self, column_collections=True, **kwargs):
- return (column_collections and list(self.columns) or []) + \
- list(self.froms) + \
- [x for x in (self.whereclause, self.having, self.order_by_clause, self.group_by_clause) if x is not None]
-
- def accept_visitor(self, visitor):
- visitor.visit_select(self)
-
def union(self, other, **kwargs):
return union(self, other, **kwargs)
@@ -3108,7 +3278,7 @@ class Select(_SelectBaseMixin, FromClause):
if self._bind is not None:
return self._bind
- for f in self.__froms:
+ for f in self._froms:
if f is self:
continue
e = f.bind
@@ -3133,20 +3303,24 @@ class _UpdateBase(ClauseElement):
def supports_execution(self):
return True
- class _SelectCorrelator(NoColumnVisitor):
- def __init__(self, table):
- NoColumnVisitor.__init__(self)
- self.table = table
-
- def visit_select(self, select):
- if select.should_correlate:
- select.correlate(self.table)
-
- def _process_whereclause(self, whereclause):
- if whereclause is not None:
- _UpdateBase._SelectCorrelator(self.table).traverse(whereclause)
- return whereclause
+ def _calculate_correlations(self, correlate_state):
+ class SelectCorrelator(NoColumnVisitor):
+ def visit_select(s, select):
+ if select._should_correlate:
+ select_state = correlate_state.setdefault(select, {})
+ corr = select_state.setdefault('correlate', util.Set())
+ corr.add(self.table)
+
+ vis = SelectCorrelator()
+ if self._whereclause is not None:
+ vis.traverse(self._whereclause)
+
+ if getattr(self, 'parameters', None) is not None:
+ for key, value in self.parameters.items():
+ if isinstance(value, ClauseElement):
+ vis.traverse(value)
+
def _process_colparams(self, parameters):
"""Receive the *values* of an ``INSERT`` or ``UPDATE``
statement and construct appropriate bind parameters.
@@ -3155,7 +3329,7 @@ class _UpdateBase(ClauseElement):
if parameters is None:
return None
- if isinstance(parameters, list) or isinstance(parameters, tuple):
+ if isinstance(parameters, (list, tuple)):
pp = {}
i = 0
for c in self.table.c:
@@ -3163,11 +3337,10 @@ class _UpdateBase(ClauseElement):
            i += 1
parameters = pp
- correlator = _UpdateBase._SelectCorrelator(self.table)
for key in parameters.keys():
value = parameters[key]
if isinstance(value, ClauseElement):
- correlator.traverse(value)
+ parameters[key] = value.self_group()
elif _is_literal(value):
if _is_literal(key):
col = self.table.c[key]
@@ -3182,7 +3355,7 @@ class _UpdateBase(ClauseElement):
def _find_engine(self):
return self.table.bind
-class _Insert(_UpdateBase):
+class Insert(_UpdateBase):
def __init__(self, table, values=None):
self.table = table
self.select = None
@@ -3193,32 +3366,41 @@ class _Insert(_UpdateBase):
return self.select,
else:
return ()
- def accept_visitor(self, visitor):
- visitor.visit_insert(self)
-class _Update(_UpdateBase):
+class Update(_UpdateBase):
def __init__(self, table, whereclause, values=None):
self.table = table
- self.whereclause = self._process_whereclause(whereclause)
+ self._whereclause = whereclause
self.parameters = self._process_colparams(values)
def get_children(self, **kwargs):
- if self.whereclause is not None:
- return self.whereclause,
+ if self._whereclause is not None:
+ return self._whereclause,
else:
return ()
- def accept_visitor(self, visitor):
- visitor.visit_update(self)
-class _Delete(_UpdateBase):
+class Delete(_UpdateBase):
def __init__(self, table, whereclause):
self.table = table
- self.whereclause = self._process_whereclause(whereclause)
+ self._whereclause = whereclause
def get_children(self, **kwargs):
- if self.whereclause is not None:
- return self.whereclause,
+ if self._whereclause is not None:
+ return self._whereclause,
else:
return ()
- def accept_visitor(self, visitor):
- visitor.visit_delete(self)
+
+class _IdentifiedClause(ClauseElement):
+ def __init__(self, ident):
+ self.ident = ident
+ def supports_execution(self):
+ return True
+
+class SavepointClause(_IdentifiedClause):
+ pass
+
+class RollbackToSavepointClause(_IdentifiedClause):
+ pass
+
+class ReleaseSavepointClause(_IdentifiedClause):
+ pass
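+
+# the identified clauses above back SAVEPOINT support; a rough rendering
+# sketch (names illustrative): SavepointClause('sa_savepoint_1') compiles
+# to "SAVEPOINT sa_savepoint_1", RollbackToSavepointClause to
+# "ROLLBACK TO SAVEPOINT sa_savepoint_1", and ReleaseSavepointClause to
+# "RELEASE SAVEPOINT sa_savepoint_1" on backends that support them.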
diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py
index 9235b9c4e..d91fbe4b5 100644
--- a/lib/sqlalchemy/sql_util.py
+++ b/lib/sqlalchemy/sql_util.py
@@ -53,7 +53,7 @@ class TableCollection(object):
for table in self.tables:
vis.traverse(table)
sorter = topological.QueueDependencySorter( tuples, self.tables )
- head = sorter.sort()
+ head = sorter.sort()
sequence = []
def to_sequence( node, seq=sequence):
seq.append( node.item )
@@ -67,12 +67,12 @@ class TableCollection(object):
class TableFinder(TableCollection, sql.NoColumnVisitor):
"""locate all Tables within a clause."""
- def __init__(self, table, check_columns=False, include_aliases=False):
+ def __init__(self, clause, check_columns=False, include_aliases=False):
TableCollection.__init__(self)
self.check_columns = check_columns
self.include_aliases = include_aliases
- if table is not None:
- self.traverse(table)
+ for clause in util.to_list(clause):
+ self.traverse(clause)
def visit_alias(self, alias):
if self.include_aliases:
@@ -83,7 +83,7 @@ class TableFinder(TableCollection, sql.NoColumnVisitor):
def visit_column(self, column):
if self.check_columns:
- self.traverse(column.table)
+ self.tables.append(column.table)
class ColumnFinder(sql.ClauseVisitor):
def __init__(self):
@@ -125,7 +125,7 @@ class AbstractClauseProcessor(sql.NoColumnVisitor):
process the new list.
"""
- list_ = [o.copy_container() for o in list_]
+ list_ = list(list_)
self.process_list(list_)
return list_
@@ -137,7 +137,7 @@ class AbstractClauseProcessor(sql.NoColumnVisitor):
if elem is not None:
list_[i] = elem
else:
- self.traverse(list_[i])
+ list_[i] = self.traverse(list_[i], clone=True)
def visit_grouping(self, grouping):
elem = self.convert_element(grouping.elem)
@@ -162,8 +162,24 @@ class AbstractClauseProcessor(sql.NoColumnVisitor):
elem = self.convert_element(binary.right)
if elem is not None:
binary.right = elem
-
- # TODO: visit_select().
+
+ def visit_select(self, select):
+ fr = util.OrderedSet()
+ for elem in select._froms:
+ n = self.convert_element(elem)
+ if n is not None:
+ fr.add((elem, n))
+ select._recorrelate_froms(fr)
+
+ col = []
+ for elem in select._raw_columns:
+ n = self.convert_element(elem)
+ if n is None:
+ col.append(elem)
+ else:
+ col.append(n)
+ select._raw_columns = col
class ClauseAdapter(AbstractClauseProcessor):
"""Given a clause (like as in a WHERE criterion), locate columns
@@ -200,6 +216,9 @@ class ClauseAdapter(AbstractClauseProcessor):
self.equivalents = equivalents
def convert_element(self, col):
+ if isinstance(col, sql.FromClause):
+ if self.selectable.is_derived_from(col):
+ return self.selectable
if not isinstance(col, sql.ColumnElement):
return None
if self.include is not None:
@@ -214,4 +233,9 @@ class ClauseAdapter(AbstractClauseProcessor):
newcol = self.selectable.corresponding_column(equiv, raiseerr=False, require_embedded=True, keys_ok=False)
if newcol:
return newcol
+ #if newcol is None:
+ # self.traverse(col)
+ # return col
return newcol
+
+
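+# a ClauseAdapter usage sketch (illustrative; 'users' is an assumed table):
+#
+#     ualias = users.alias()
+#     crit = users.c.id == 5
+#     adapted = ClauseAdapter(ualias).traverse(crit, clone=True)
+#     # 'adapted' now references ualias.c.id rather than users.c.id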
diff --git a/lib/sqlalchemy/topological.py b/lib/sqlalchemy/topological.py
index bad71293e..56c8cb46e 100644
--- a/lib/sqlalchemy/topological.py
+++ b/lib/sqlalchemy/topological.py
@@ -42,7 +42,6 @@ nature - very tricky to reproduce and track down, particularly before
I realized this characteristic of the algorithm.
"""
-import string, StringIO
from sqlalchemy import util
from sqlalchemy.exceptions import CircularDependencyError
@@ -68,7 +67,7 @@ class _Node(object):
str(self.item) + \
(self.cycles is not None and (" (cycles: " + repr([x for x in self.cycles]) + ")") or "") + \
"\n" + \
- string.join([n.safestr(indent + 1) for n in self.children], '')
+ ''.join([n.safestr(indent + 1) for n in self.children])
def __repr__(self):
return "%s" % (str(self.item))
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index 3cceedae6..ec1459852 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -7,34 +7,29 @@
__all__ = [ 'TypeEngine', 'TypeDecorator', 'NullTypeEngine',
'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'FLOAT', 'DECIMAL',
'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB', 'BOOLEAN', 'String', 'Integer', 'SmallInteger','Smallinteger',
- 'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'Binary', 'Boolean', 'Unicode', 'PickleType', 'NULLTYPE',
+ 'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'Binary', 'Boolean', 'Unicode', 'PickleType', 'NULLTYPE', 'NullType',
'SMALLINT', 'DATE', 'TIME','Interval'
]
-from sqlalchemy import util, exceptions
-import inspect, weakref
+import inspect
import datetime as dt
+from decimal import Decimal
try:
import cPickle as pickle
except:
import pickle
-_impl_cache = weakref.WeakKeyDictionary()
+from sqlalchemy import exceptions
class AbstractType(object):
- def _get_impl_dict(self):
- try:
- return _impl_cache[self]
- except KeyError:
- return _impl_cache.setdefault(self, {})
-
- impl_dict = property(_get_impl_dict)
-
+ def __init__(self, *args, **kwargs):
+ pass
+
def copy_value(self, value):
return value
def compare_values(self, x, y):
- return x is y
+ return x == y
def is_mutable(self):
return False
@@ -51,15 +46,20 @@ class AbstractType(object):
return "%s(%s)" % (self.__class__.__name__, ",".join(["%s=%s" % (k, getattr(self, k)) for k in inspect.getargspec(self.__init__)[0][1:]]))
class TypeEngine(AbstractType):
- def __init__(self, *args, **params):
- pass
-
def dialect_impl(self, dialect):
try:
- return self.impl_dict[dialect]
+ return self._impl_dict[dialect]
+ except AttributeError:
+ self._impl_dict = {}
+ return self._impl_dict.setdefault(dialect, dialect.type_descriptor(self))
except KeyError:
- return self.impl_dict.setdefault(dialect, dialect.type_descriptor(self))
-
+ return self._impl_dict.setdefault(dialect, dialect.type_descriptor(self))
+
+ def __getstate__(self):
+ d = self.__dict__.copy()
+ d['_impl_dict'] = {}
+ return d
+
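+    # dialect_impl() caches one implementation per dialect in
+    # self._impl_dict: the first lookup for a dialect misses (AttributeError
+    # or KeyError), creates the impl via dialect.type_descriptor() and
+    # stores it; later lookups are plain dict hits.  __getstate__ drops the
+    # cache so pickled types don't carry dialect-specific impls.
+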
def get_col_spec(self):
raise NotImplementedError()
@@ -88,15 +88,19 @@ class TypeDecorator(AbstractType):
def dialect_impl(self, dialect):
try:
- return self.impl_dict[dialect]
- except:
- typedesc = self.load_dialect_impl(dialect)
- tt = self.copy()
- if not isinstance(tt, self.__class__):
- raise exceptions.AssertionError("Type object %s does not properly implement the copy() method, it must return an object of type %s" % (self, self.__class__))
- tt.impl = typedesc
- self.impl_dict[dialect] = tt
- return tt
+ return self._impl_dict[dialect]
+ except AttributeError:
+ self._impl_dict = {}
+ except KeyError:
+ pass
+
+ typedesc = self.load_dialect_impl(dialect)
+ tt = self.copy()
+ if not isinstance(tt, self.__class__):
+ raise exceptions.AssertionError("Type object %s does not properly implement the copy() method, it must return an object of type %s" % (self, self.__class__))
+ tt.impl = typedesc
+ self._impl_dict[dialect] = tt
+ return tt
def load_dialect_impl(self, dialect):
"""loads the dialect-specific implementation of this type.
@@ -179,7 +183,7 @@ def adapt_type(typeobj, colspecs):
return typeobj
return typeobj.adapt(impltype)
-class NullTypeEngine(TypeEngine):
+class NullType(TypeEngine):
def get_col_spec(self):
raise NotImplementedError()
@@ -188,8 +192,13 @@ class NullTypeEngine(TypeEngine):
def convert_result_value(self, value, dialect):
return value
+NullTypeEngine = NullType
-class String(TypeEngine):
+class Concatenable(object):
+ """marks a type as supporting 'concatenation'"""
+ pass
+
+class String(TypeEngine, Concatenable):
def __init__(self, length=None, convert_unicode=False):
self.length = length
self.convert_unicode = convert_unicode
@@ -219,9 +228,6 @@ class String(TypeEngine):
def get_dbapi_type(self, dbapi):
return dbapi.STRING
- def compare_values(self, x, y):
- return x == y
-
class Unicode(String):
def __init__(self, length=None, **kwargs):
kwargs['convert_unicode'] = True
@@ -241,22 +247,36 @@ class SmallInteger(Integer):
Smallinteger = SmallInteger
class Numeric(TypeEngine):
- def __init__(self, precision = 10, length = 2):
+ def __init__(self, precision = 10, length = 2, asdecimal=True):
self.precision = precision
self.length = length
+ self.asdecimal = asdecimal
def adapt(self, impltype):
- return impltype(precision=self.precision, length=self.length)
+ return impltype(precision=self.precision, length=self.length, asdecimal=self.asdecimal)
def get_dbapi_type(self, dbapi):
return dbapi.NUMBER
+ def convert_bind_param(self, value, dialect):
+ if value is not None:
+ return float(value)
+ else:
+ return value
+
+ def convert_result_value(self, value, dialect):
+ if value is not None and self.asdecimal:
+ return Decimal(str(value))
+ else:
+ return value
+
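+# a round-trip sketch (illustrative): with the default asdecimal=True,
+#
+#     Numeric(precision=10, length=2)
+#
+# binds Decimal("19.99") as float(19.99) on the way in, and converts the
+# fetched value back via Decimal(str(value)) on the way out; pass
+# asdecimal=False to receive plain floats instead.
+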
class Float(Numeric):
- def __init__(self, precision = 10):
+ def __init__(self, precision = 10, asdecimal=False, **kwargs):
self.precision = precision
-
+ self.asdecimal = asdecimal
+
def adapt(self, impltype):
- return impltype(precision=self.precision)
+ return impltype(precision=self.precision, asdecimal=self.asdecimal)
class DateTime(TypeEngine):
"""Implement a type for ``datetime.datetime()`` objects."""
@@ -416,4 +436,4 @@ class NCHAR(Unicode):pass
class BLOB(Binary): pass
class BOOLEAN(Boolean): pass
-NULLTYPE = NullTypeEngine()
+NULLTYPE = NullType()
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index b47822d61..e711de3a3 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -14,7 +14,6 @@ from sqlalchemy import exceptions
import md5
import sys
import warnings
-
import __builtin__
try:
@@ -33,10 +32,35 @@ except:
i -= 1
raise StopIteration()
-def to_list(x):
+if sys.version_info >= (2, 5):
+ class PopulateDict(dict):
+ """a dict which populates missing values via a creation function.
+
+ note the creation function takes a key, unlike collections.defaultdict.
+ """
+
+ def __init__(self, creator):
+ self.creator = creator
+ def __missing__(self, key):
+ self[key] = val = self.creator(key)
+ return val
+else:
+ class PopulateDict(dict):
+ """a dict which populates missing values via a creation function."""
+
+ def __init__(self, creator):
+ self.creator = creator
+ def __getitem__(self, key):
+ try:
+ return dict.__getitem__(self, key)
+ except KeyError:
+ self[key] = value = self.creator(key)
+ return value
+
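+# a PopulateDict sketch (illustrative):
+#
+#     lengths = PopulateDict(len)
+#     lengths['abc']   # -> 3, computed by the creator and cached
+#     lengths['abc']   # -> 3, now an ordinary dict hit
+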
+def to_list(x, default=None):
if x is None:
- return None
- if not isinstance(x, list) and not isinstance(x, tuple):
+ return default
+ if not isinstance(x, (list, tuple)):
return [x]
else:
return x
@@ -113,19 +137,25 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True):
else:
kw[key] = type_(kw[key])
-def duck_type_collection(col, default=None):
+def duck_type_collection(specimen, default=None):
"""Given an instance or class, guess if it is or is acting as one of
the basic collection types: list, set and dict. If the __emulates__
property is present, return that preferentially.
"""
- if hasattr(col, '__emulates__'):
- return getattr(col, '__emulates__')
- elif hasattr(col, 'append'):
+ if hasattr(specimen, '__emulates__'):
+ return specimen.__emulates__
+
+ isa = isinstance(specimen, type) and issubclass or isinstance
+ if isa(specimen, list): return list
+ if isa(specimen, Set): return Set
+ if isa(specimen, dict): return dict
+
+ if hasattr(specimen, 'append'):
return list
- elif hasattr(col, 'add'):
+ elif hasattr(specimen, 'add'):
return Set
- elif hasattr(col, 'set'):
+ elif hasattr(specimen, 'set'):
return dict
else:
return default
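+
+# a duck_type_collection sketch (illustrative):
+#
+#     duck_type_collection([])             # -> list
+#     duck_type_collection(Set)            # -> Set
+#     class DictLike(object):
+#         __emulates__ = dict
+#     duck_type_collection(DictLike())     # -> dict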
@@ -138,11 +168,11 @@ def assert_arg_type(arg, argtype, name):
raise exceptions.ArgumentError("Argument '%s' is expected to be one of type %s, got '%s'" % (name, ' or '.join(["'%s'" % str(a) for a in argtype]), str(type(arg))))
else:
raise exceptions.ArgumentError("Argument '%s' is expected to be of type '%s', got '%s'" % (name, str(argtype), str(type(arg))))
-
-def warn_exception(func):
+
+def warn_exception(func, *args, **kwargs):
"""executes the given function, catches all exceptions and converts to a warning."""
try:
- return func()
+ return func(*args, **kwargs)
except:
warnings.warn(RuntimeWarning("%s('%s') ignored" % sys.exc_info()[0:2]))
@@ -246,12 +276,12 @@ class OrderedProperties(object):
class OrderedDict(dict):
"""A Dictionary that returns keys/values/items in the order they were added."""
- def __init__(self, d=None, **kwargs):
+ def __init__(self, ____sequence=None, **kwargs):
self._list = []
- if d is None:
+ if ____sequence is None:
self.update(**kwargs)
else:
- self.update(d, **kwargs)
+ self.update(____sequence, **kwargs)
def clear(self):
self._list = []
@@ -347,7 +377,13 @@ class DictDecorator(dict):
return dict.__getitem__(self, key)
except KeyError:
return self.decorate[key]
-
+
+ def __contains__(self, key):
+ return dict.__contains__(self, key) or key in self.decorate
+
+ def has_key(self, key):
+ return key in self
+
def __repr__(self):
return dict.__repr__(self) + repr(self.decorate)
@@ -442,19 +478,28 @@ class OrderedSet(Set):
__isub__ = difference_update
class UniqueAppender(object):
- def __init__(self, data):
+ """appends items to a collection such that only unique items
+ are added."""
+
+ def __init__(self, data, via=None):
self.data = data
- if hasattr(data, 'append'):
+ self._unique = Set()
+ if via:
+ self._data_appender = getattr(data, via)
+ elif hasattr(data, 'append'):
self._data_appender = data.append
elif hasattr(data, 'add'):
+            # TODO: we think it's a set here; bypass the unneeded uniquing logic?
self._data_appender = data.add
- self.set = Set()
-
+
def append(self, item):
- if item not in self.set:
- self.set.add(item)
+ if item not in self._unique:
self._data_appender(item)
-
+ self._unique.add(item)
+
+ def __iter__(self):
+ return iter(self.data)
+
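+# a UniqueAppender sketch (illustrative):
+#
+#     seen = UniqueAppender([])
+#     for item in ('a', 'a', 'b'):
+#         seen.append(item)
+#     list(seen)   # -> ['a', 'b']
+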
class ScopedRegistry(object):
"""A Registry that can store one or multiple instances of a single
class on a per-thread scoped basis, or on a customized scope.
diff --git a/setup.py b/setup.py
index 092d0d508..735f3d723 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ use_setuptools()
from setuptools import setup, find_packages
setup(name = "SQLAlchemy",
- version = "0.3.11",
+ version = "0.4.0",
description = "Database Abstraction Library",
author = "Mike Bayer",
author_email = "mike_mp@zzzcomputing.com",
diff --git a/test/base/alltests.py b/test/base/alltests.py
index 70ff83ab8..44fa9b2ec 100644
--- a/test/base/alltests.py
+++ b/test/base/alltests.py
@@ -5,6 +5,7 @@ def suite():
modules_to_test = (
# core utilities
'base.dependency',
+ 'base.utils',
)
alltests = unittest.TestSuite()
for name in modules_to_test:
diff --git a/test/base/dependency.py b/test/base/dependency.py
index c5e54fc9f..ddadd1b31 100644
--- a/test/base/dependency.py
+++ b/test/base/dependency.py
@@ -1,7 +1,8 @@
-from testbase import PersistTest
+import testbase
import sqlalchemy.topological as topological
-import unittest, sys, os
from sqlalchemy import util
+from testlib import *
+
# TODO: need assertion conditions in this suite
@@ -190,4 +191,4 @@ class DependencySortTest(PersistTest):
if __name__ == "__main__":
- unittest.main()
+ testbase.main()
diff --git a/test/base/utils.py b/test/base/utils.py
new file mode 100644
index 000000000..97f3db06f
--- /dev/null
+++ b/test/base/utils.py
@@ -0,0 +1,67 @@
+import testbase
+from sqlalchemy import util, column, sql, exceptions
+from testlib import *
+
+
+class OrderedDictTest(PersistTest):
+ def test_odict(self):
+ o = util.OrderedDict()
+ o['a'] = 1
+ o['b'] = 2
+ o['snack'] = 'attack'
+ o['c'] = 3
+
+ self.assert_(o.keys() == ['a', 'b', 'snack', 'c'])
+ self.assert_(o.values() == [1, 2, 'attack', 3])
+
+ o.pop('snack')
+
+ self.assert_(o.keys() == ['a', 'b', 'c'])
+ self.assert_(o.values() == [1, 2, 3])
+
+ o2 = util.OrderedDict(d=4)
+ o2['e'] = 5
+
+ self.assert_(o2.keys() == ['d', 'e'])
+ self.assert_(o2.values() == [4, 5])
+
+ o.update(o2)
+ self.assert_(o.keys() == ['a', 'b', 'c', 'd', 'e'])
+ self.assert_(o.values() == [1, 2, 3, 4, 5])
+
+ o.setdefault('c', 'zzz')
+ o.setdefault('f', 6)
+ self.assert_(o.keys() == ['a', 'b', 'c', 'd', 'e', 'f'])
+ self.assert_(o.values() == [1, 2, 3, 4, 5, 6])
+
+class ColumnCollectionTest(PersistTest):
+ def test_in(self):
+ cc = sql.ColumnCollection()
+ cc.add(column('col1'))
+ cc.add(column('col2'))
+ cc.add(column('col3'))
+ assert 'col1' in cc
+ assert 'col2' in cc
+
+ try:
+ cc['col1'] in cc
+ assert False
+ except exceptions.ArgumentError, e:
+ assert str(e) == "__contains__ requires a string argument"
+
+ def test_compare(self):
+ cc1 = sql.ColumnCollection()
+ cc2 = sql.ColumnCollection()
+ cc3 = sql.ColumnCollection()
+ c1 = column('col1')
+ c2 = c1.label('col2')
+ c3 = column('col3')
+ cc1.add(c1)
+ cc2.add(c2)
+ cc3.add(c3)
+ assert (cc1==cc2).compare(c1 == c2)
+ assert not (cc1==cc3).compare(c2 == c3)
+
+
+if __name__ == "__main__":
+ testbase.main()
diff --git a/test/dialect/alltests.py b/test/dialect/alltests.py
index f4b39dd6f..890073625 100644
--- a/test/dialect/alltests.py
+++ b/test/dialect/alltests.py
@@ -5,6 +5,7 @@ def suite():
modules_to_test = (
'dialect.mysql',
'dialect.postgres',
+ 'dialect.oracle',
)
alltests = unittest.TestSuite()
for name in modules_to_test:
diff --git a/test/dialect/mysql.py b/test/dialect/mysql.py
index d9227383f..dbba78893 100644
--- a/test/dialect/mysql.py
+++ b/test/dialect/mysql.py
@@ -1,15 +1,13 @@
-from testbase import PersistTest, AssertMixin
import testbase
from sqlalchemy import *
from sqlalchemy.databases import mysql
-import sys, StringIO
+from testlib import *
-db = testbase.db
class TypesTest(AssertMixin):
"Test MySQL column types"
- @testbase.supported('mysql')
+ @testing.supported('mysql')
def test_numeric(self):
"Exercise type specification and options for numeric types."
@@ -104,13 +102,13 @@ class TypesTest(AssertMixin):
'SMALLINT(4) UNSIGNED ZEROFILL'),
]
- table_args = ['test_mysql_numeric', db]
+ table_args = ['test_mysql_numeric', MetaData(testbase.db)]
for index, spec in enumerate(columns):
type_, args, kw, res = spec
table_args.append(Column('c%s' % index, type_(*args, **kw)))
numeric_table = Table(*table_args)
- gen = db.dialect.schemagenerator(db, None, None)
+ gen = testbase.db.dialect.schemagenerator(testbase.db, None, None)
for col in numeric_table.c:
index = int(col.name[1:])
@@ -124,7 +122,7 @@ class TypesTest(AssertMixin):
raise
numeric_table.drop()
- @testbase.supported('mysql')
+ @testing.supported('mysql')
def test_charset(self):
"""Exercise CHARACTER SET and COLLATE-related options on string-type
columns."""
@@ -188,13 +186,13 @@ class TypesTest(AssertMixin):
'''ENUM('foo','bar') UNICODE''')
]
- table_args = ['test_mysql_charset', db]
+ table_args = ['test_mysql_charset', MetaData(testbase.db)]
for index, spec in enumerate(columns):
type_, args, kw, res = spec
table_args.append(Column('c%s' % index, type_(*args, **kw)))
charset_table = Table(*table_args)
- gen = db.dialect.schemagenerator(db, None, None)
+ gen = testbase.db.dialect.schemagenerator(testbase.db, None, None)
for col in charset_table.c:
index = int(col.name[1:])
@@ -208,11 +206,12 @@ class TypesTest(AssertMixin):
raise
charset_table.drop()
- @testbase.supported('mysql')
+ @testing.supported('mysql')
def test_enum(self):
"Exercise the ENUM type"
-
- enum_table = Table('mysql_enum', db,
+
+ db = testbase.db
+ enum_table = Table('mysql_enum', MetaData(testbase.db),
Column('e1', mysql.MSEnum('"a"', "'b'")),
Column('e2', mysql.MSEnum('"a"', "'b'"), nullable=False),
Column('e3', mysql.MSEnum('"a"', "'b'", strict=True)),
@@ -242,38 +241,17 @@ class TypesTest(AssertMixin):
enum_table.insert().execute(e1='a', e2='a', e3='a', e4='a')
enum_table.insert().execute(e1='b', e2='b', e3='b', e4='b')
- # Insert out of range enums, push stderr aside to avoid expected
- # warnings cluttering test output
- con = db.connect()
- if not hasattr(con.connection, 'show_warnings'):
- con.execute(insert(enum_table, {'e1':'c', 'e2':'c',
- 'e3':'a', 'e4':'a'}))
- else:
- try:
- aside = sys.stderr
- sys.stderr = StringIO.StringIO()
-
- self.assert_(not con.connection.show_warnings())
-
- con.execute(insert(enum_table, {'e1':'c', 'e2':'c',
- 'e3':'a', 'e4':'a'}))
-
- self.assert_(con.connection.show_warnings())
- finally:
- sys.stderr = aside
-
res = enum_table.select().execute().fetchall()
expected = [(None, 'a', None, 'a'),
('a', 'a', 'a', 'a'),
- ('b', 'b', 'b', 'b'),
- ('', '', 'a', 'a')]
+ ('b', 'b', 'b', 'b')]
# This is known to fail with MySQLDB 1.2.2 beta versions
# which return these as sets.Set(['a']), sets.Set(['b'])
# (even on Pythons with __builtin__.set)
- if db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and \
- db.dialect.dbapi.version_info >= (1, 2, 2):
+ if testbase.db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and \
+ testbase.db.dialect.dbapi.version_info >= (1, 2, 2):
            # these MySQLdb versions seem to always use 'sets', even on later Pythons
import sets
def convert(value):
@@ -292,10 +270,10 @@ class TypesTest(AssertMixin):
self.assertEqual(res, expected)
enum_table.drop()
- @testbase.supported('mysql')
+ @testing.supported('mysql')
def test_type_reflection(self):
# FIXME: older versions need their own test
- if db.dialect.get_version_info(db) < (5, 0):
+ if testbase.db.dialect.get_version_info(testbase.db) < (5, 0):
return
# (ask_for, roundtripped_as_if_different)
@@ -325,12 +303,12 @@ class TypesTest(AssertMixin):
columns = [Column('c%i' % (i + 1), t[0]) for i, t in enumerate(specs)]
- m = MetaData(db)
+ m = MetaData(testbase.db)
t_table = Table('mysql_types', m, *columns)
m.drop_all()
m.create_all()
- m2 = MetaData(db)
+ m2 = MetaData(testbase.db)
rt = Table('mysql_types', m2, autoload=True)
#print
diff --git a/test/dialect/oracle.py b/test/dialect/oracle.py
new file mode 100644
index 000000000..14de8960b
--- /dev/null
+++ b/test/dialect/oracle.py
@@ -0,0 +1,32 @@
+import testbase
+from sqlalchemy import *
+from sqlalchemy.databases import mysql
+
+from testlib import *
+
+
+class OutParamTest(AssertMixin):
+ @testing.supported('oracle')
+ def setUpAll(self):
+ testbase.db.execute("""
+create or replace procedure foo(x_in IN number, x_out OUT number, y_out OUT number) IS
+ retval number;
+ begin
+ retval := 6;
+ x_out := 10;
+ y_out := x_in * 15;
+ end;
+ """)
+
+ @testing.supported('oracle')
+ def test_out_params(self):
+ result = testbase.db.execute(text("begin foo(:x, :y, :z); end;", bindparams=[bindparam('x', Numeric), outparam('y', Numeric), outparam('z', Numeric)]), x=5)
+ assert result.out_parameters == {'y':10, 'z':75}, result.out_parameters
+ print result.out_parameters
+
+ @testing.supported('oracle')
+ def tearDownAll(self):
+ testbase.db.execute("DROP PROCEDURE foo")
+
+if __name__ == '__main__':
+ testbase.main()
diff --git a/test/dialect/postgres.py b/test/dialect/postgres.py
index 0507b7c5b..f80ddcadd 100644
--- a/test/dialect/postgres.py
+++ b/test/dialect/postgres.py
@@ -1,68 +1,68 @@
-from testbase import AssertMixin
import testbase
+import datetime
from sqlalchemy import *
from sqlalchemy.databases import postgres
-import datetime
+from testlib import *
-db = testbase.db
class DomainReflectionTest(AssertMixin):
"Test PostgreSQL domains"
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def setUpAll(self):
- self.con = db.connect()
- self.con.execute('CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42')
- self.con.execute('CREATE DOMAIN alt_schema.testdomain INTEGER DEFAULT 0')
- self.con.execute('CREATE TABLE testtable (question integer, answer testdomain)')
- self.con.execute('CREATE TABLE alt_schema.testtable(question integer, answer alt_schema.testdomain, anything integer)')
- self.con.execute('CREATE TABLE crosschema (question integer, answer alt_schema.testdomain)')
+ con = testbase.db.connect()
+ con.execute('CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42')
+ con.execute('CREATE DOMAIN alt_schema.testdomain INTEGER DEFAULT 0')
+ con.execute('CREATE TABLE testtable (question integer, answer testdomain)')
+ con.execute('CREATE TABLE alt_schema.testtable(question integer, answer alt_schema.testdomain, anything integer)')
+ con.execute('CREATE TABLE crosschema (question integer, answer alt_schema.testdomain)')
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def tearDownAll(self):
- self.con.execute('DROP TABLE testtable')
- self.con.execute('DROP TABLE alt_schema.testtable')
- self.con.execute('DROP TABLE crosschema')
- self.con.execute('DROP DOMAIN testdomain')
- self.con.execute('DROP DOMAIN alt_schema.testdomain')
+ con = testbase.db.connect()
+ con.execute('DROP TABLE testtable')
+ con.execute('DROP TABLE alt_schema.testtable')
+ con.execute('DROP TABLE crosschema')
+ con.execute('DROP DOMAIN testdomain')
+ con.execute('DROP DOMAIN alt_schema.testdomain')
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_table_is_reflected(self):
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
table = Table('testtable', metadata, autoload=True)
self.assertEquals(set(table.columns.keys()), set(['question', 'answer']), "Columns of reflected table didn't equal expected columns")
self.assertEquals(table.c.answer.type.__class__, postgres.PGInteger)
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_domain_is_reflected(self):
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
table = Table('testtable', metadata, autoload=True)
self.assertEquals(str(table.columns.answer.default.arg), '42', "Reflected default value didn't equal expected value")
self.assertFalse(table.columns.answer.nullable, "Expected reflected column to not be nullable.")
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_table_is_reflected_alt_schema(self):
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
table = Table('testtable', metadata, autoload=True, schema='alt_schema')
self.assertEquals(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns")
self.assertEquals(table.c.anything.type.__class__, postgres.PGInteger)
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_schema_domain_is_reflected(self):
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
table = Table('testtable', metadata, autoload=True, schema='alt_schema')
self.assertEquals(str(table.columns.answer.default.arg), '0', "Reflected default value didn't equal expected value")
self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.")
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_crosschema_domain_is_reflected(self):
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
table = Table('crosschema', metadata, autoload=True)
self.assertEquals(str(table.columns.answer.default.arg), '0', "Reflected default value didn't equal expected value")
self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.")
class MiscTest(AssertMixin):
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_date_reflection(self):
m1 = MetaData(testbase.db)
t1 = Table('pgdate', m1,
@@ -78,7 +78,7 @@ class MiscTest(AssertMixin):
finally:
m1.drop_all()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_pg_weirdchar_reflection(self):
meta1 = MetaData(testbase.db)
subject = Table("subject", meta1,
@@ -99,18 +99,18 @@ class MiscTest(AssertMixin):
finally:
meta1.drop_all()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_checksfor_sequence(self):
meta1 = MetaData(testbase.db)
t = Table('mytable', meta1,
Column('col1', Integer, Sequence('fooseq')))
try:
testbase.db.execute("CREATE SEQUENCE fooseq")
- t.create()
+ t.create(checkfirst=True)
finally:
- t.drop()
+ t.drop(checkfirst=True)
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_schema_reflection(self):
"""note: this test requires that the 'alt_schema' schema be separate and accessible by the test user"""
@@ -141,7 +141,7 @@ class MiscTest(AssertMixin):
finally:
meta1.drop_all()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_schema_reflection_2(self):
meta1 = MetaData(testbase.db)
subject = Table("subject", meta1,
@@ -162,7 +162,7 @@ class MiscTest(AssertMixin):
finally:
meta1.drop_all()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_schema_reflection_3(self):
meta1 = MetaData(testbase.db)
subject = Table("subject", meta1,
@@ -185,7 +185,7 @@ class MiscTest(AssertMixin):
finally:
meta1.drop_all()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_preexecute_passivedefault(self):
"""test that when we get a primary key column back
from reflecting a table which has a default value on it, we pre-execute
@@ -216,7 +216,7 @@ class TimezoneTest(AssertMixin):
if postgres returns it. python then will not let you compare a datetime with a tzinfo to a datetime
    that doesn't have one. this test illustrates two ways to have datetime types with and without timezone
info. """
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def setUpAll(self):
global tztable, notztable, metadata
metadata = MetaData(testbase.db)
@@ -233,11 +233,11 @@ class TimezoneTest(AssertMixin):
Column("name", String(20)),
)
metadata.create_all()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def tearDownAll(self):
metadata.drop_all()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_with_timezone(self):
# get a date with a tzinfo
somedate = testbase.db.connect().scalar(func.current_timestamp().select())
@@ -246,7 +246,7 @@ class TimezoneTest(AssertMixin):
x = c.last_updated_params()
print x['date'] == somedate
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_without_timezone(self):
# get a date without a tzinfo
somedate = datetime.datetime(2005, 10,20, 11, 52, 00)
@@ -255,6 +255,56 @@ class TimezoneTest(AssertMixin):
x = c.last_updated_params()
print x['date'] == somedate
+class ArrayTest(AssertMixin):
+ @testing.supported('postgres')
+ def setUpAll(self):
+ global metadata, arrtable
+ metadata = MetaData(testbase.db)
+
+ arrtable = Table('arrtable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('intarr', postgres.PGArray(Integer)),
+ Column('strarr', postgres.PGArray(String), nullable=False)
+ )
+ metadata.create_all()
+ @testing.supported('postgres')
+ def tearDownAll(self):
+ metadata.drop_all()
+
+ @testing.supported('postgres')
+ def test_reflect_array_column(self):
+ metadata2 = MetaData(testbase.db)
+ tbl = Table('arrtable', metadata2, autoload=True)
+ self.assertTrue(isinstance(tbl.c.intarr.type, postgres.PGArray))
+ self.assertTrue(isinstance(tbl.c.strarr.type, postgres.PGArray))
+ self.assertTrue(isinstance(tbl.c.intarr.type.item_type, Integer))
+ self.assertTrue(isinstance(tbl.c.strarr.type.item_type, String))
+
+ @testing.supported('postgres')
+ def test_insert_array(self):
+ arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
+ results = arrtable.select().execute().fetchall()
+ self.assertEquals(len(results), 1)
+ self.assertEquals(results[0]['intarr'], [1,2,3])
+ self.assertEquals(results[0]['strarr'], ['abc','def'])
+ arrtable.delete().execute()
+
+ @testing.supported('postgres')
+ def test_array_where(self):
+ arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
+ arrtable.insert().execute(intarr=[4,5,6], strarr='ABC')
+ results = arrtable.select().where(arrtable.c.intarr == [1,2,3]).execute().fetchall()
+ self.assertEquals(len(results), 1)
+ self.assertEquals(results[0]['intarr'], [1,2,3])
+ arrtable.delete().execute()
+ @testing.supported('postgres')
+ def test_array_concat(self):
+ arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
+ results = select([arrtable.c.intarr + [4,5,6]]).execute().fetchall()
+ self.assertEquals(len(results), 1)
+ self.assertEquals(results[0][0], [1,2,3,4,5,6])
+ arrtable.delete().execute()
+
if __name__ == "__main__":
testbase.main()
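
The new ArrayTest above exercises the PGArray type introduced in 0.4. A minimal sketch of the same round-trip pattern outside the test harness (the connection URL is hypothetical, and a reachable Postgres database is assumed):

```python
# Sketch only: assumes SQLAlchemy 0.4 and a live Postgres database
# at this hypothetical URL.
from sqlalchemy import *
from sqlalchemy.databases import postgres

db = create_engine('postgres://scott:tiger@localhost/test')
meta = MetaData(db)
arrtable = Table('arrtable', meta,
                 Column('id', Integer, primary_key=True),
                 Column('intarr', postgres.PGArray(Integer)))
meta.create_all()

# array columns round-trip as Python lists and participate in
# SQL expressions (comparison, concatenation)
arrtable.insert().execute(id=1, intarr=[1, 2, 3])
print arrtable.select().where(arrtable.c.intarr == [1, 2, 3]).execute().fetchall()
print select([arrtable.c.intarr + [4, 5, 6]]).execute().fetchall()
meta.drop_all()
```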
diff --git a/test/engine/alltests.py b/test/engine/alltests.py
index ec8a47390..a34a82ed7 100644
--- a/test/engine/alltests.py
+++ b/test/engine/alltests.py
@@ -10,12 +10,12 @@ def suite():
'engine.bind',
'engine.reconnect',
'engine.execute',
+ 'engine.metadata',
'engine.transaction',
# schema/tables
'engine.reflection',
- 'engine.proxy_engine'
)
alltests = unittest.TestSuite()
for name in modules_to_test:
diff --git a/test/engine/autoconnect_engine.py b/test/engine/autoconnect_engine.py
deleted file mode 100644
index 69c2c33f5..000000000
--- a/test/engine/autoconnect_engine.py
+++ /dev/null
@@ -1,90 +0,0 @@
-from testbase import PersistTest
-import testbase
-from sqlalchemy import *
-from sqlalchemy.ext.proxy import AutoConnectEngine
-
-import os
-
-#
-# Define an engine, table and mapper at the module level, to show that the
-# table and mapper can be used with different real engines in multiple threads
-#
-
-
-module_engine = AutoConnectEngine( testbase.db_uri )
-users = Table('users', module_engine,
- Column('user_id', Integer, primary_key=True),
- Column('user_name', String(16)),
- Column('password', String(20))
- )
-
-class User(object):
- pass
-
-
-class AutoConnectEngineTest1(PersistTest):
-
- def setUp(self):
- clear_mappers()
- objectstore.clear()
-
- def test_engine_connect(self):
- users.create()
- assign_mapper(User, users)
- try:
- trans = objectstore.begin()
-
- user = User()
- user.user_name='fred'
- user.password='*'
- trans.commit()
-
- # select
- sqluser = User.select_by(user_name='fred')[0]
- assert sqluser.user_name == 'fred'
-
- # modify
- sqluser.user_name = 'fred jones'
-
- # commit - saves everything that changed
- objectstore.commit()
-
- allusers = [ user.user_name for user in User.select() ]
- assert allusers == [ 'fred jones' ]
- finally:
- users.drop()
-
-
-
-
-if __name__ == "__main__":
- testbase.main()
diff --git a/test/engine/bind.py b/test/engine/bind.py
index b9e53e6b1..6a0c78f57 100644
--- a/test/engine/bind.py
+++ b/test/engine/bind.py
@@ -2,12 +2,10 @@
including the deprecated versions of these arguments"""
import testbase
-import unittest, sys, datetime
-import tables
-db = testbase.db
from sqlalchemy import *
+from testlib import *
-class BindTest(testbase.PersistTest):
+class BindTest(PersistTest):
def test_create_drop_explicit(self):
metadata = MetaData()
table = Table('test_table', metadata,
@@ -17,7 +15,6 @@ class BindTest(testbase.PersistTest):
testbase.db.connect()
):
for args in [
- ([], {'connectable':bind}),
([], {'bind':bind}),
([bind], {})
]:
@@ -57,7 +54,7 @@ class BindTest(testbase.PersistTest):
table = Table('test_table', metadata,
Column('foo', Integer))
metadata.bind = bind
- assert metadata.bind is metadata.engine is table.bind is table.engine is bind
+ assert metadata.bind is table.bind is bind
metadata.create_all()
assert table.exists()
metadata.drop_all()
@@ -70,7 +67,7 @@ class BindTest(testbase.PersistTest):
Column('foo', Integer))
metadata.connect(bind)
- assert metadata.bind is metadata.engine is table.bind is table.engine is bind
+ assert metadata.bind is table.bind is bind
metadata.create_all()
assert table.exists()
metadata.drop_all()
@@ -88,15 +85,12 @@ class BindTest(testbase.PersistTest):
try:
for args in (
([bind], {}),
- ([], {'engine_or_url':bind}),
([], {'bind':bind}),
- ([], {'engine':bind})
):
metadata = MetaData(*args[0], **args[1])
table = Table('test_table', metadata,
- Column('foo', Integer))
-
- assert metadata.bind is metadata.engine is table.bind is table.engine is bind
+ Column('foo', Integer))
+ assert metadata.bind is table.bind is bind
metadata.create_all()
assert table.exists()
metadata.drop_all()
@@ -111,7 +105,8 @@ class BindTest(testbase.PersistTest):
metadata = MetaData()
table = Table('test_table', metadata,
Column('foo', Integer),
- mysql_engine='InnoDB')
+ test_needs_acid=True,
+ )
conn = testbase.db.connect()
metadata.create_all(bind=conn)
try:
@@ -124,7 +119,7 @@ class BindTest(testbase.PersistTest):
table.insert().execute(foo=7)
trans.rollback()
metadata.bind = None
- assert testbase.db.execute("select count(1) from test_table").scalar() == 0
+ assert conn.execute("select count(1) from test_table").scalar() == 0
finally:
metadata.drop_all(bind=conn)
@@ -147,10 +142,7 @@ class BindTest(testbase.PersistTest):
):
try:
e = elem(bind=bind)
- assert e.bind is e.engine is bind
- e.execute()
- e = elem(engine=bind)
- assert e.bind is e.engine is bind
+ assert e.bind is bind
e.execute()
finally:
if isinstance(bind, engine.Connection):
@@ -158,16 +150,19 @@ class BindTest(testbase.PersistTest):
try:
e = elem()
- assert e.bind is e.engine is None
+ assert e.bind is None
e.execute()
assert False
except exceptions.InvalidRequestError, e:
assert str(e) == "This Compiled object is not bound to any Engine or Connection."
-
+
finally:
+ if isinstance(bind, engine.Connection):
+ bind.close()
metadata.drop_all(bind=testbase.db)
def test_session(self):
+ from sqlalchemy.orm import create_session, mapper
metadata = MetaData()
table = Table('test_table', metadata,
Column('foo', Integer, primary_key=True),
@@ -177,11 +172,13 @@ class BindTest(testbase.PersistTest):
mapper(Foo, table)
metadata.create_all(bind=testbase.db)
try:
- for bind in (testbase.db, testbase.db.connect()):
+ for bind in (testbase.db,
+ testbase.db.connect()
+ ):
try:
- for args in ({'bind':bind}, {'bind_to':bind}):
+ for args in ({'bind':bind},):
sess = create_session(**args)
- assert sess.bind is sess.bind_to is bind
+ assert sess.bind is bind
f = Foo()
sess.save(f)
sess.flush()
@@ -189,6 +186,9 @@ class BindTest(testbase.PersistTest):
finally:
if isinstance(bind, engine.Connection):
bind.close()
+
+ if isinstance(bind, engine.Connection):
+ bind.close()
sess = create_session()
f = Foo()
@@ -198,8 +198,9 @@ class BindTest(testbase.PersistTest):
assert False
except exceptions.InvalidRequestError, e:
assert str(e).startswith("Could not locate any Engine or Connection bound to mapper")
-
finally:
+ if isinstance(bind, engine.Connection):
+ bind.close()
metadata.drop_all(bind=testbase.db)
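
The churn above removes the pre-0.4 binding synonyms (connectable=, engine=, engine_or_url=, bind_to=): only the bind keyword remains, and the .engine accessor aliases are gone. A minimal sketch of the surviving idiom:

```python
# Sketch: the single surviving binding spelling in 0.4.
from sqlalchemy import *

engine = create_engine('sqlite://')
metadata = MetaData(bind=engine)      # or: metadata.bind = engine
table = Table('test_table', metadata, Column('foo', Integer))

assert metadata.bind is table.bind is engine   # no more .engine alias
metadata.create_all()
metadata.drop_all()
```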
diff --git a/test/engine/execute.py b/test/engine/execute.py
index 283006cfa..3d3b43f9b 100644
--- a/test/engine/execute.py
+++ b/test/engine/execute.py
@@ -1,19 +1,14 @@
-
import testbase
-import unittest, sys, datetime
-import tables
-db = testbase.db
from sqlalchemy import *
+from testlib import *
-
-class ExecuteTest(testbase.PersistTest):
+class ExecuteTest(PersistTest):
def setUpAll(self):
global users, metadata
metadata = MetaData(testbase.db)
users = Table('users', metadata,
Column('user_id', INT, primary_key = True),
Column('user_name', VARCHAR(20)),
- mysql_engine='InnoDB'
)
metadata.create_all()
@@ -22,7 +17,7 @@ class ExecuteTest(testbase.PersistTest):
def tearDownAll(self):
metadata.drop_all()
- @testbase.supported('sqlite')
+ @testing.supported('sqlite')
def test_raw_qmark(self):
for conn in (testbase.db, testbase.db.connect()):
conn.execute("insert into users (user_id, user_name) values (?, ?)", (1,"jack"))
@@ -34,7 +29,7 @@ class ExecuteTest(testbase.PersistTest):
assert res.fetchall() == [(1, "jack"), (2, "fred"), (3, "ed"), (4, "horse"), (5, "barney"), (6, "donkey"), (7, 'sally')]
conn.execute("delete from users")
- @testbase.supported('mysql', 'postgres')
+ @testing.supported('mysql', 'postgres')
def test_raw_sprintf(self):
for conn in (testbase.db, testbase.db.connect()):
conn.execute("insert into users (user_id, user_name) values (%s, %s)", [1,"jack"])
@@ -47,7 +42,7 @@ class ExecuteTest(testbase.PersistTest):
# pyformat is supported for mysql, but skipping because a few driver
# versions have a bug that bombs out on this test. (1.2.2b3, 1.2.2c1, 1.2.2)
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_raw_python(self):
for conn in (testbase.db, testbase.db.connect()):
conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", {'id':1, 'name':'jack'})
@@ -57,7 +52,7 @@ class ExecuteTest(testbase.PersistTest):
assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally')]
conn.execute("delete from users")
- @testbase.supported('sqlite')
+ @testing.supported('sqlite')
def test_raw_named(self):
for conn in (testbase.db, testbase.db.connect()):
conn.execute("insert into users (user_id, user_name) values (:id, :name)", {'id':1, 'name':'jack'})
diff --git a/test/engine/metadata.py b/test/engine/metadata.py
new file mode 100644
index 000000000..973007fab
--- /dev/null
+++ b/test/engine/metadata.py
@@ -0,0 +1,18 @@
+import testbase
+from sqlalchemy import *
+from testlib import *
+
+class MetaDataTest(PersistTest):
+ def test_metadata_connect(self):
+ metadata = MetaData()
+ t1 = Table('table1', metadata, Column('col1', Integer, primary_key=True),
+ Column('col2', String(20)))
+ metadata.bind = testbase.db
+ metadata.create_all()
+ try:
+ assert t1.count().scalar() == 0
+ finally:
+ metadata.drop_all()
+
+if __name__ == '__main__':
+ testbase.main()
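
This new test is small but shows the implicit-execution side of binding: once metadata.bind is set, expressions built from its tables can execute themselves. A sketch of the same behavior:

```python
# Sketch: implicit execution through a bound MetaData; t1.count()
# executes against metadata.bind without an explicit connection.
from sqlalchemy import *

metadata = MetaData()
t1 = Table('table1', metadata, Column('col1', Integer, primary_key=True))
metadata.bind = create_engine('sqlite://')
metadata.create_all()
print t1.count().scalar()   # 0
metadata.drop_all()
```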
diff --git a/test/engine/parseconnect.py b/test/engine/parseconnect.py
index 967a20ed5..3e186275d 100644
--- a/test/engine/parseconnect.py
+++ b/test/engine/parseconnect.py
@@ -1,8 +1,7 @@
-from testbase import PersistTest
import testbase
-import sqlalchemy.engine.url as url
from sqlalchemy import *
-import unittest
+import sqlalchemy.engine.url as url
+from testlib import *
class ParseConnectTest(PersistTest):
@@ -65,7 +64,7 @@ class CreateEngineTest(PersistTest):
def testrecycle(self):
dbapi = MockDBAPI(foober=12, lala=18, hoho={'this':'dict'}, fooz='somevalue')
e = create_engine('postgres://', pool_recycle=472, module=dbapi)
- assert e.connection_provider._pool._recycle == 472
+ assert e.pool._recycle == 472
def testbadargs(self):
# good arg, use MockDBAPI to prevent oracle import errors
@@ -116,7 +115,6 @@ class CreateEngineTest(PersistTest):
except TypeError:
assert True
- e = create_engine('sqlite://', echo=True)
e = create_engine('mysql://', module=MockDBAPI(), connect_args={'use_unicode':True}, convert_unicode=True)
e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True)
@@ -139,8 +137,8 @@ class CreateEngineTest(PersistTest):
def testpoolargs(self):
"""test that connection pool args make it thru"""
e = create_engine('postgres://', creator=None, pool_recycle=-1, echo_pool=None, auto_close_cursors=False, disallow_open_cursors=True, module=MockDBAPI())
- assert e.connection_provider._pool.auto_close_cursors is False
- assert e.connection_provider._pool.disallow_open_cursors is True
+ assert e.pool.auto_close_cursors is False
+ assert e.pool.disallow_open_cursors is True
# these args work for QueuePool
e = create_engine('postgres://', max_overflow=8, pool_timeout=60, poolclass=pool.QueuePool, module=MockDBAPI())
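
The rewritten assertions above reflect the removal of the connection_provider layer in 0.4: the pool now hangs directly off the Engine. A sketch, assuming the sqlite dialect accepts the standard pool arguments:

```python
# Sketch: pool options are reachable as engine.pool, not
# engine.connection_provider._pool.
from sqlalchemy import create_engine

e = create_engine('sqlite://', pool_recycle=472)
assert e.pool._recycle == 472
```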
diff --git a/test/engine/pool.py b/test/engine/pool.py
index 85e9d59fd..364afa9d7 100644
--- a/test/engine/pool.py
+++ b/test/engine/pool.py
@@ -1,10 +1,9 @@
import testbase
-from testbase import PersistTest
-import unittest, sys, os, time
-import threading, thread
-
+import threading, thread, time
import sqlalchemy.pool as pool
import sqlalchemy.exceptions as exceptions
+from testlib import *
+
mcid = 1
class MockDBAPI(object):
@@ -45,7 +44,7 @@ class PoolTest(PersistTest):
connection2 = manager.connect('foo.db')
connection3 = manager.connect('bar.db')
- self.echo( "connection " + repr(connection))
+ print "connection " + repr(connection)
self.assert_(connection.cursor() is not None)
self.assert_(connection is connection2)
self.assert_(connection2 is not connection3)
@@ -64,7 +63,7 @@ class PoolTest(PersistTest):
connection = manager.connect('foo.db')
connection2 = manager.connect('foo.db')
- self.echo( "connection " + repr(connection))
+ print "connection " + repr(connection)
self.assert_(connection.cursor() is not None)
self.assert_(connection is not connection2)
@@ -80,7 +79,7 @@ class PoolTest(PersistTest):
def status(pool):
tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout())
- self.echo( "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup)
+ print "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup
return tup
c1 = p.connect()
@@ -160,7 +159,7 @@ class PoolTest(PersistTest):
print timeouts
assert len(timeouts) > 0
for t in timeouts:
- assert abs(t - 3) < 1
+ assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts)
def _test_overflow(self, thread_count, max_overflow):
def creator():
@@ -352,6 +351,35 @@ class PoolTest(PersistTest):
c2 = None
c1 = None
self.assert_(p.checkedout() == 0)
+
+ def test_properties(self):
+ dbapi = MockDBAPI()
+ p = pool.QueuePool(creator=lambda: dbapi.connect('foo.db'),
+ pool_size=1, max_overflow=0)
+
+ c = p.connect()
+ self.assert_(not c.properties)
+ self.assert_(c.properties is c._connection_record.properties)
+
+ c.properties['foo'] = 'bar'
+ c.close()
+ del c
+
+ c = p.connect()
+ self.assert_('foo' in c.properties)
+
+ c.invalidate()
+ c = p.connect()
+ self.assert_('foo' not in c.properties)
+
+ c.properties['foo2'] = 'bar2'
+ c.detach()
+ self.assert_('foo2' in c.properties)
+
+ c2 = p.connect()
+ self.assert_(c.connection is not c2.connection)
+ self.assert_(not c2.properties)
+ self.assert_('foo2' in c.properties)
def tearDown(self):
pool.clear_managers()
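
The new test_properties test documents an info dictionary kept on the underlying connection record: it survives checkin and checkout of the same DBAPI connection, is discarded when the connection is invalidated, and travels with the connection on detach(). A sketch, assuming Python 2.5's sqlite3 module as the creator:

```python
# Sketch of the connection record 'properties' dict exercised above.
import sqlite3
import sqlalchemy.pool as pool

p = pool.QueuePool(creator=lambda: sqlite3.connect(':memory:'),
                   pool_size=1, max_overflow=0)

c = p.connect()
c.properties['foo'] = 'bar'    # stored on the connection record
c.close()                      # record goes back to the pool

c = p.connect()
assert 'foo' in c.properties   # same record, same dict

c.invalidate()                 # invalidation discards the record
c = p.connect()
assert 'foo' not in c.properties
```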
diff --git a/test/engine/proxy_engine.py b/test/engine/proxy_engine.py
deleted file mode 100644
index 26b738e41..000000000
--- a/test/engine/proxy_engine.py
+++ /dev/null
@@ -1,204 +0,0 @@
-from testbase import PersistTest
-import testbase
-import os
-
-from sqlalchemy import *
-from sqlalchemy.ext.proxy import ProxyEngine
-
-
-#
-# Define an engine, table and mapper at the module level, to show that the
-# table and mapper can be used with different real engines in multiple threads
-#
-
-
-class ProxyTestBase(PersistTest):
- def setUpAll(self):
-
- global users, User, module_engine, module_metadata
-
- module_engine = ProxyEngine(echo=testbase.echo)
- module_metadata = MetaData()
-
- users = Table('users', module_metadata,
- Column('user_id', Integer, primary_key=True),
- Column('user_name', String(16)),
- Column('password', String(20))
- )
-
- class User(object):
- pass
-
- User.mapper = mapper(User, users)
- def tearDownAll(self):
- clear_mappers()
-
-class ConstructTest(ProxyTestBase):
- """tests that we can build SQL constructs without engine-specific parameters, particulary
- oid_column, being needed, as the proxy engine is usually not connected yet."""
-
- def test_join(self):
- engine = ProxyEngine()
- t = Table('table1', engine,
- Column('col1', Integer, primary_key=True))
- t2 = Table('table2', engine,
- Column('col2', Integer, ForeignKey('table1.col1')))
- j = join(t, t2)
-
-
-class ProxyEngineTest1(ProxyTestBase):
-
- def test_engine_connect(self):
- # connect to a real engine
- module_engine.connect(testbase.db_uri)
- module_metadata.create_all(module_engine)
-
- session = create_session(bind_to=module_engine)
- try:
-
- user = User()
- user.user_name='fred'
- user.password='*'
-
- session.save(user)
- session.flush()
-
- query = session.query(User)
-
- # select
- sqluser = query.select_by(user_name='fred')[0]
- assert sqluser.user_name == 'fred'
-
- # modify
- sqluser.user_name = 'fred jones'
-
- # flush - saves everything that changed
- session.flush()
-
- allusers = [ user.user_name for user in query.select() ]
- assert allusers == ['fred jones']
-
- finally:
- module_metadata.drop_all(module_engine)
-
-
-class ThreadProxyTest(ProxyTestBase):
-
- def tearDownAll(self):
- try:
- os.remove('threadtesta.db')
- except OSError:
- pass
- try:
- os.remove('threadtestb.db')
- except OSError:
- pass
-
- @testbase.supported('sqlite')
- def test_multi_thread(self):
-
- from threading import Thread
- from Queue import Queue
-
- # start 2 threads with different connection params
- # and perform simultaneous operations, showing that the
- # 2 threads don't share a connection
- qa = Queue()
- qb = Queue()
- def run(db_uri, uname, queue):
- def test():
-
- try:
- module_engine.connect(db_uri)
- module_metadata.create_all(module_engine)
- try:
- session = create_session(bind_to=module_engine)
-
- query = session.query(User)
-
- all = list(query.select())
- assert all == []
-
- u = User()
- u.user_name = uname
- u.password = 'whatever'
-
- session.save(u)
- session.flush()
-
- names = [u.user_name for u in query.select()]
- assert names == [uname]
- finally:
- module_metadata.drop_all(module_engine)
- module_engine.get_engine().dispose()
- except Exception, e:
- import traceback
- traceback.print_exc()
- queue.put(e)
- else:
- queue.put(False)
- return test
-
- a = Thread(target=run('sqlite:///threadtesta.db', 'jim', qa))
- b = Thread(target=run('sqlite:///threadtestb.db', 'joe', qb))
-
- a.start()
- b.start()
-
- # block and wait for the threads to push their results
- res = qa.get()
- if res != False:
- raise res
-
- res = qb.get()
- if res != False:
- raise res
-
-
-class ProxyEngineTest2(ProxyTestBase):
-
- def test_table_singleton_a(self):
- """set up for table singleton check
- """
- #
- # For this 'test', create a proxy engine instance, connect it
- # to a real engine, and make it do some work
- #
- engine = ProxyEngine()
- cats = Table('cats', engine,
- Column('cat_id', Integer, primary_key=True),
- Column('cat_name', String))
-
- engine.connect(testbase.db_uri)
-
- cats.create(engine)
- cats.drop(engine)
-
- ProxyEngineTest2.cats_table_a = cats
- assert isinstance(cats, Table)
-
- def test_table_singleton_b(self):
- """check that a table on a 2nd proxy engine instance gets 2nd table
- instance
- """
- #
- # Now create a new proxy engine instance and attach the same
- # table as the first test. This should result in 2 table instances,
- # since different proxy engine instances can't attach to the
- # same table instance
- #
- engine = ProxyEngine()
- cats = Table('cats', engine,
- Column('cat_id', Integer, primary_key=True),
- Column('cat_name', String))
- assert id(cats) != id(ProxyEngineTest2.cats_table_a)
-
- # the real test -- if we're still using the old engine reference,
- # this will fail because the old reference's local storage will
- # not have the default attributes
- engine.connect(testbase.db_uri)
- cats.create(engine)
- cats.drop(engine)
-
-if __name__ == "__main__":
- testbase.main()
diff --git a/test/engine/reconnect.py b/test/engine/reconnect.py
index defc878ab..7c213695f 100644
--- a/test/engine/reconnect.py
+++ b/test/engine/reconnect.py
@@ -1,6 +1,8 @@
import testbase
+import sys, weakref
from sqlalchemy import create_engine, exceptions
-import gc, weakref, sys
+from testlib import *
+
class MockDisconnect(Exception):
pass
@@ -37,7 +39,7 @@ class MockCursor(object):
def close(self):
pass
-class ReconnectTest(testbase.PersistTest):
+class ReconnectTest(PersistTest):
def test_reconnect(self):
"""test that an 'is_disconnect' condition will invalidate the connection, and additionally
dispose the previous connection pool and recreate."""
@@ -50,7 +52,7 @@ class ReconnectTest(testbase.PersistTest):
# monkeypatch disconnect checker
db.dialect.is_disconnect = lambda e: isinstance(e, MockDisconnect)
- pid = id(db.connection_provider._pool)
+ pid = id(db.pool)
# make a connection
conn = db.connect()
@@ -81,7 +83,7 @@ class ReconnectTest(testbase.PersistTest):
# close shouldnt break
conn.close()
- assert id(db.connection_provider._pool) != pid
+ assert id(db.pool) != pid
# ensure all connections closed (pool was recycled)
assert len(dbapi.connections) == 0
@@ -92,4 +94,4 @@ class ReconnectTest(testbase.PersistTest):
assert len(dbapi.connections) == 1
if __name__ == '__main__':
- testbase.main() \ No newline at end of file
+ testbase.main()
diff --git a/test/engine/reflection.py b/test/engine/reflection.py
index 74ae75e2e..00c1276ee 100644
--- a/test/engine/reflection.py
+++ b/test/engine/reflection.py
@@ -1,13 +1,12 @@
-from testbase import PersistTest
import testbase
-import pickle
-import sqlalchemy.ansisql as ansisql
+import pickle, StringIO
from sqlalchemy import *
+import sqlalchemy.ansisql as ansisql
from sqlalchemy.exceptions import NoSuchTableError
import sqlalchemy.databases.mysql as mysql
+from testlib import *
-import unittest, re, StringIO
class ReflectionTest(PersistTest):
def testbasic(self):
@@ -15,6 +14,10 @@ class ReflectionTest(PersistTest):
use_string_defaults = use_function_defaults or testbase.db.engine.__module__.endswith('sqlite')
+ if (testbase.db.engine.name == 'mysql' and
+ testbase.db.dialect.get_version_info(testbase.db) < (4, 1, 1)):
+ return
+
if use_function_defaults:
defval = func.current_date()
deftype = Date
@@ -54,14 +57,14 @@ class ReflectionTest(PersistTest):
Column('test_passivedefault4', deftype3, PassiveDefault(defval3)),
Column('test9', Binary(100)),
Column('test_numeric', Numeric(None, None)),
- mysql_engine='InnoDB'
+ test_needs_fk=True,
)
-
+
addresses = Table('engine_email_addresses', meta,
Column('address_id', Integer, primary_key = True),
Column('remote_user_id', Integer, ForeignKey(users.c.user_id)),
Column('email_address', String(20)),
- mysql_engine='InnoDB'
+ test_needs_fk=True,
)
meta.drop_all()
@@ -106,6 +109,29 @@ class ReflectionTest(PersistTest):
addresses.drop()
users.drop()
+ def test_autoload_partial(self):
+ meta = MetaData(testbase.db)
+ foo = Table('foo', meta,
+ Column('a', String(30)),
+ Column('b', String(30)),
+ Column('c', String(30)),
+ Column('d', String(30)),
+ Column('e', String(30)),
+ Column('f', String(30)),
+ )
+ meta.create_all()
+ try:
+ meta2 = MetaData(testbase.db)
+ foo2 = Table('foo', meta2, autoload=True, include_columns=['b', 'f', 'e'])
+ # test that cols come back in original order
+ assert [c.name for c in foo2.c] == ['b', 'e', 'f']
+ for c in ('b', 'f', 'e'):
+ assert c in foo2.c
+ for c in ('a', 'c', 'd'):
+ assert c not in foo2.c
+ finally:
+ meta.drop_all()
+
def testoverridecolumns(self):
"""test that you can override columns which contain foreign keys to other reflected tables"""
meta = MetaData(testbase.db)
@@ -203,7 +229,7 @@ class ReflectionTest(PersistTest):
finally:
meta.drop_all()
- @testbase.supported('mysql')
+ @testing.supported('mysql')
def testmysqltypes(self):
meta1 = MetaData(testbase.db)
table = Table(
@@ -250,7 +276,7 @@ class ReflectionTest(PersistTest):
PRIMARY KEY(id)
)""")
try:
- metadata = MetaData(engine=testbase.db)
+ metadata = MetaData(bind=testbase.db)
book = Table('book', metadata, autoload=True)
assert book.c.id in book.primary_key
assert book.c.series not in book.primary_key
@@ -271,7 +297,7 @@ class ReflectionTest(PersistTest):
PRIMARY KEY(id, isbn)
)""")
try:
- metadata = MetaData(engine=testbase.db)
+ metadata = MetaData(bind=testbase.db)
book = Table('book', metadata, autoload=True)
assert book.c.id in book.primary_key
assert book.c.isbn in book.primary_key
@@ -280,7 +306,7 @@ class ReflectionTest(PersistTest):
finally:
testbase.db.execute("drop table book")
- @testbase.supported('sqlite')
+ @testing.supported('sqlite')
def test_goofy_sqlite(self):
"""test autoload of table where quotes were used with all the colnames. quirky in sqlite."""
testbase.db.execute("""CREATE TABLE "django_content_type" (
@@ -309,7 +335,12 @@ class ReflectionTest(PersistTest):
def test_composite_fk(self):
"""test reflection of composite foreign keys"""
+
+ if (testbase.db.engine.name == 'mysql' and
+ testbase.db.dialect.get_version_info(testbase.db) < (4, 1, 1)):
+ return
meta = MetaData(testbase.db)
+
table = Table(
'multi', meta,
Column('multi_id', Integer, primary_key=True),
@@ -317,7 +348,7 @@ class ReflectionTest(PersistTest):
Column('multi_hoho', Integer, primary_key=True),
Column('name', String(50), nullable=False),
Column('val', String(100)),
- mysql_engine='InnoDB'
+ test_needs_fk=True,
)
table2 = Table('multi2', meta,
Column('id', Integer, primary_key=True),
@@ -326,7 +357,7 @@ class ReflectionTest(PersistTest):
Column('lala', Integer),
Column('data', String(50)),
ForeignKeyConstraint(['foo', 'bar', 'lala'], ['multi.multi_id', 'multi.multi_rev', 'multi.multi_hoho']),
- mysql_engine='InnoDB'
+ test_needs_fk=True,
)
assert table.c.multi_hoho
meta.create_all()
@@ -345,7 +376,6 @@ class ReflectionTest(PersistTest):
finally:
meta.drop_all()
-
def test_to_metadata(self):
meta = MetaData()
@@ -372,17 +402,17 @@ class ReflectionTest(PersistTest):
def test_pickle():
meta.connect(testbase.db)
meta2 = pickle.loads(pickle.dumps(meta))
- assert meta2.engine is None
+ assert meta2.bind is None
return (meta2.tables['mytable'], meta2.tables['othertable'])
def test_pickle_via_reflect():
# this is the most common use case, pickling the results of a
# database reflection
- meta2 = MetaData(engine=testbase.db)
+ meta2 = MetaData(bind=testbase.db)
t1 = Table('mytable', meta2, autoload=True)
t2 = Table('othertable', meta2, autoload=True)
meta3 = pickle.loads(pickle.dumps(meta2))
- assert meta3.engine is None
+ assert meta3.bind is None
assert meta3.tables['mytable'] is not t1
return (meta3.tables['mytable'], meta3.tables['othertable'])
@@ -392,6 +422,8 @@ class ReflectionTest(PersistTest):
table_c, table2_c = test()
assert table is not table_c
assert table_c.c.myid.primary_key
+ assert isinstance(table_c.c.myid.type, Integer)
+ assert isinstance(table_c.c.name.type, String)
assert not table_c.c.name.nullable
assert table_c.c.description.nullable
assert table.primary_key is not table_c.primary_key
@@ -418,14 +450,10 @@ class ReflectionTest(PersistTest):
finally:
meta.drop_all(testbase.db)
- # mysql throws its own exception for no such table, resulting in
- # a sqlalchemy.SQLError instead of sqlalchemy.NoSuchTableError.
- # this could probably be fixed at some point.
- @testbase.unsupported('mysql')
def test_nonexistent(self):
self.assertRaises(NoSuchTableError, Table,
'fake_table',
- testbase.db, autoload=True)
+ MetaData(testbase.db), autoload=True)
def testoverride(self):
meta = MetaData(testbase.db)
@@ -452,7 +480,7 @@ class ReflectionTest(PersistTest):
finally:
table.drop()
- @testbase.supported('mssql')
+ @testing.supported('mssql')
def testidentity(self):
meta = MetaData(testbase.db)
table = Table(
@@ -505,20 +533,6 @@ class ReflectionTest(PersistTest):
finally:
meta.drop_all()
-
- meta = MetaData(testbase.db)
- table = Table(
- 'select', meta,
- Column('col1', Integer, primary_key=True)
- )
- table.create()
-
- meta2 = MetaData(testbase.db)
- try:
- table2 = Table('select', meta2, autoload=True)
- finally:
- table.drop()
-
class CreateDropTest(PersistTest):
def setUpAll(self):
global metadata, users
@@ -558,33 +572,33 @@ class CreateDropTest(PersistTest):
def testcheckfirst(self):
try:
assert not users.exists(testbase.db)
- users.create(connectable=testbase.db)
+ users.create(bind=testbase.db)
assert users.exists(testbase.db)
- users.create(connectable=testbase.db, checkfirst=True)
- users.drop(connectable=testbase.db)
- users.drop(connectable=testbase.db, checkfirst=True)
- assert not users.exists(connectable=testbase.db)
- users.create(connectable=testbase.db, checkfirst=True)
- users.drop(connectable=testbase.db)
+ users.create(bind=testbase.db, checkfirst=True)
+ users.drop(bind=testbase.db)
+ users.drop(bind=testbase.db, checkfirst=True)
+ assert not users.exists(bind=testbase.db)
+ users.create(bind=testbase.db, checkfirst=True)
+ users.drop(bind=testbase.db)
finally:
- metadata.drop_all(connectable=testbase.db)
+ metadata.drop_all(bind=testbase.db)
def test_createdrop(self):
- metadata.create_all(connectable=testbase.db)
+ metadata.create_all(bind=testbase.db)
self.assertEqual( testbase.db.has_table('items'), True )
self.assertEqual( testbase.db.has_table('email_addresses'), True )
- metadata.create_all(connectable=testbase.db)
+ metadata.create_all(bind=testbase.db)
self.assertEqual( testbase.db.has_table('items'), True )
- metadata.drop_all(connectable=testbase.db)
+ metadata.drop_all(bind=testbase.db)
self.assertEqual( testbase.db.has_table('items'), False )
self.assertEqual( testbase.db.has_table('email_addresses'), False )
- metadata.drop_all(connectable=testbase.db)
+ metadata.drop_all(bind=testbase.db)
self.assertEqual( testbase.db.has_table('items'), False )
class SchemaTest(PersistTest):
# this test should really be in the sql tests somewhere, not engine
- @testbase.unsupported('sqlite')
+ @testing.unsupported('sqlite')
def testiteration(self):
metadata = MetaData()
table1 = Table('table1', metadata,
@@ -607,14 +621,17 @@ class SchemaTest(PersistTest):
print buf
assert buf.index("CREATE TABLE someschema.table1") > -1
assert buf.index("CREATE TABLE someschema.table2") > -1
-
- @testbase.unsupported('sqlite', 'postgres')
- def test_create_with_defaultschema(self):
+
+ @testing.supported('mysql','postgres')
+ def testcreate(self):
engine = testbase.db
schema = engine.dialect.get_default_schema_name(engine)
+ #engine.echo = True
- # test reflection of tables with an explcit schemaname
- # matching the default
+ if testbase.db.name == 'mysql':
+ schema = testbase.db.url.database
+ else:
+ schema = 'public'
metadata = MetaData(testbase.db)
table1 = Table('table1', metadata,
Column('col1', Integer, primary_key=True),
@@ -628,10 +645,7 @@ class SchemaTest(PersistTest):
metadata.clear()
table1 = Table('table1', metadata, autoload=True, schema=schema)
table2 = Table('table2', metadata, autoload=True, schema=schema)
- assert table1.schema == table2.schema == schema
- assert len(metadata.tables) == 2
metadata.drop_all()
-
if __name__ == "__main__":
testbase.main()
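
Among the changes above, test_autoload_partial introduces the include_columns argument to reflection: only the named columns are loaded, and they come back in the table's own order rather than list order. A sketch:

```python
# Sketch: partial reflection with include_columns.
from sqlalchemy import *

db = create_engine('sqlite://')
meta = MetaData(db)
Table('foo', meta,
      Column('a', String(30)),
      Column('b', String(30)),
      Column('c', String(30)))
meta.create_all()

meta2 = MetaData(db)
foo2 = Table('foo', meta2, autoload=True, include_columns=['c', 'a'])
assert [col.name for col in foo2.c] == ['a', 'c']   # table order, not list order
assert 'b' not in foo2.c
meta.drop_all()
```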
diff --git a/test/engine/transaction.py b/test/engine/transaction.py
index c89bf4b14..593a069a9 100644
--- a/test/engine/transaction.py
+++ b/test/engine/transaction.py
@@ -1,19 +1,19 @@
-
import testbase
-import unittest, sys, datetime
-import tables
-db = testbase.db
+import sys, time, threading
+
from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
-class TransactionTest(testbase.PersistTest):
+class TransactionTest(PersistTest):
def setUpAll(self):
global users, metadata
metadata = MetaData()
users = Table('query_users', metadata,
Column('user_id', INT, primary_key = True),
Column('user_name', VARCHAR(20)),
- mysql_engine='InnoDB'
+ test_needs_acid=True,
)
users.create(testbase.db)
@@ -114,8 +114,154 @@ class TransactionTest(testbase.PersistTest):
result = connection.execute("select * from query_users")
assert len(result.fetchall()) == 0
connection.close()
+
+ @testing.unsupported('sqlite')
+ def testnestedsubtransactionrollback(self):
+ connection = testbase.db.connect()
+ transaction = connection.begin()
+ connection.execute(users.insert(), user_id=1, user_name='user1')
+ trans2 = connection.begin_nested()
+ connection.execute(users.insert(), user_id=2, user_name='user2')
+ trans2.rollback()
+ connection.execute(users.insert(), user_id=3, user_name='user3')
+ transaction.commit()
+
+ self.assertEquals(
+ connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+ [(1,),(3,)]
+ )
+ connection.close()
+
+ @testing.unsupported('sqlite')
+ def testnestedsubtransactioncommit(self):
+ connection = testbase.db.connect()
+ transaction = connection.begin()
+ connection.execute(users.insert(), user_id=1, user_name='user1')
+ trans2 = connection.begin_nested()
+ connection.execute(users.insert(), user_id=2, user_name='user2')
+ trans2.commit()
+ connection.execute(users.insert(), user_id=3, user_name='user3')
+ transaction.commit()
+
+ self.assertEquals(
+ connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+ [(1,),(2,),(3,)]
+ )
+ connection.close()
+
+ @testing.unsupported('sqlite')
+ def testrollbacktosubtransaction(self):
+ connection = testbase.db.connect()
+ transaction = connection.begin()
+ connection.execute(users.insert(), user_id=1, user_name='user1')
+ trans2 = connection.begin_nested()
+ connection.execute(users.insert(), user_id=2, user_name='user2')
+ trans3 = connection.begin()
+ connection.execute(users.insert(), user_id=3, user_name='user3')
+ trans3.rollback()
+ connection.execute(users.insert(), user_id=4, user_name='user4')
+ transaction.commit()
+
+ self.assertEquals(
+ connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+ [(1,),(4,)]
+ )
+ connection.close()
+
+ @testing.supported('postgres', 'mysql')
+ def testtwophasetransaction(self):
+ connection = testbase.db.connect()
+
+ transaction = connection.begin_twophase()
+ connection.execute(users.insert(), user_id=1, user_name='user1')
+ transaction.prepare()
+ transaction.commit()
+
+ transaction = connection.begin_twophase()
+ connection.execute(users.insert(), user_id=2, user_name='user2')
+ transaction.commit()
+
+ transaction = connection.begin_twophase()
+ connection.execute(users.insert(), user_id=3, user_name='user3')
+ transaction.rollback()
+
+ transaction = connection.begin_twophase()
+ connection.execute(users.insert(), user_id=4, user_name='user4')
+ transaction.prepare()
+ transaction.rollback()
+
+ self.assertEquals(
+ connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+ [(1,),(2,)]
+ )
+ connection.close()
+
+ @testing.supported('postgres', 'mysql')
+ def testmixedtransaction(self):
+ connection = testbase.db.connect()
+
+ transaction = connection.begin_twophase()
+ connection.execute(users.insert(), user_id=1, user_name='user1')
+
+ transaction2 = connection.begin()
+ connection.execute(users.insert(), user_id=2, user_name='user2')
+
+ transaction3 = connection.begin_nested()
+ connection.execute(users.insert(), user_id=3, user_name='user3')
+
+ transaction4 = connection.begin()
+ connection.execute(users.insert(), user_id=4, user_name='user4')
+ transaction4.commit()
+
+ transaction3.rollback()
+
+ connection.execute(users.insert(), user_id=5, user_name='user5')
+
+ transaction2.commit()
+
+ transaction.prepare()
+
+ transaction.commit()
+
+ self.assertEquals(
+ connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+ [(1,),(2,),(5,)]
+ )
+ connection.close()
-class AutoRollbackTest(testbase.PersistTest):
+ @testing.supported('postgres')
+ def testtwophaserecover(self):
+ # MySQL recovery doesn't currently seem to work correctly
+        # Prepared transactions disappear when connections are closed, and even
+        # when they aren't, it doesn't seem possible to use the recovery id.
+ connection = testbase.db.connect()
+
+ transaction = connection.begin_twophase()
+ connection.execute(users.insert(), user_id=1, user_name='user1')
+ transaction.prepare()
+
+ connection.close()
+ connection2 = testbase.db.connect()
+
+ self.assertEquals(
+ connection2.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+ []
+ )
+
+ recoverables = connection2.recover_twophase()
+ self.assertTrue(
+ transaction.xid in recoverables
+ )
+
+ connection2.commit_prepared(transaction.xid, recover=True)
+
+ self.assertEquals(
+ connection2.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+ [(1,)]
+ )
+ connection2.close()
+
+class AutoRollbackTest(PersistTest):
def setUpAll(self):
global metadata
metadata = MetaData()
@@ -123,7 +269,7 @@ class AutoRollbackTest(testbase.PersistTest):
def tearDownAll(self):
metadata.drop_all(testbase.db)
- @testbase.unsupported('sqlite')
+ @testing.unsupported('sqlite')
def testrollback_deadlock(self):
"""test that returning connections to the pool clears any object locks."""
conn1 = testbase.db.connect()
@@ -131,6 +277,7 @@ class AutoRollbackTest(testbase.PersistTest):
users = Table('deadlock_users', metadata,
Column('user_id', INT, primary_key = True),
Column('user_name', VARCHAR(20)),
+ test_needs_acid=True,
)
users.create(conn1)
conn1.execute("select * from deadlock_users")
@@ -141,15 +288,15 @@ class AutoRollbackTest(testbase.PersistTest):
users.drop(conn2)
conn2.close()
-class TLTransactionTest(testbase.PersistTest):
+class TLTransactionTest(PersistTest):
def setUpAll(self):
global users, metadata, tlengine
- tlengine = create_engine(testbase.db_uri, strategy='threadlocal')
+ tlengine = create_engine(testbase.db.url, strategy='threadlocal')
metadata = MetaData()
users = Table('query_users', metadata,
Column('user_id', INT, primary_key = True),
Column('user_name', VARCHAR(20)),
- mysql_engine='InnoDB'
+ test_needs_acid=True,
)
users.create(tlengine)
def tearDown(self):
@@ -254,7 +401,7 @@ class TLTransactionTest(testbase.PersistTest):
finally:
external_connection.close()
- @testbase.unsupported('sqlite')
+ @testing.unsupported('sqlite')
def testnesting(self):
"""tests nesting of tranacstions"""
external_connection = tlengine.connect()
@@ -330,7 +477,7 @@ class TLTransactionTest(testbase.PersistTest):
try:
mapper(User, users)
- sess = create_session(bind_to=tlengine)
+ sess = create_session(bind=tlengine)
tlengine.begin()
u = User()
sess.save(u)
@@ -347,6 +494,127 @@ class TLTransactionTest(testbase.PersistTest):
assert c1.connection is c2.connection
c2.close()
assert c1.connection.connection is not None
+
+class ForUpdateTest(PersistTest):
+ def setUpAll(self):
+ global counters, metadata
+ metadata = MetaData()
+ counters = Table('forupdate_counters', metadata,
+ Column('counter_id', INT, primary_key = True),
+ Column('counter_value', INT),
+ test_needs_acid=True,
+ )
+ counters.create(testbase.db)
+ def tearDown(self):
+ testbase.db.connect().execute(counters.delete())
+ def tearDownAll(self):
+ counters.drop(testbase.db)
+
+ def increment(self, count, errors, update_style=True, delay=0.005):
+ con = testbase.db.connect()
+ sel = counters.select(for_update=update_style,
+ whereclause=counters.c.counter_id==1)
+
+ for i in xrange(count):
+ trans = con.begin()
+ try:
+ existing = con.execute(sel).fetchone()
+ incr = existing['counter_value'] + 1
+
+ time.sleep(delay)
+ con.execute(counters.update(counters.c.counter_id==1,
+ values={'counter_value':incr}))
+ time.sleep(delay)
+
+ readback = con.execute(sel).fetchone()
+ if (readback['counter_value'] != incr):
+ raise AssertionError("Got %s post-update, expected %s" %
+ (readback['counter_value'], incr))
+ trans.commit()
+ except Exception, e:
+ trans.rollback()
+ errors.append(e)
+ break
+
+ con.close()
+
+ @testing.supported('mysql', 'oracle', 'postgres')
+ def testqueued_update(self):
+ """Test SELECT FOR UPDATE with concurrent modifications.
+
+        Runs concurrent modifications on a single row in the counters table,
+        with each mutator trying to increment the stored counter_value.
+ """
+
+ db = testbase.db
+ db.execute(counters.insert(), counter_id=1, counter_value=0)
+
+ iterations, thread_count = 10, 5
+ threads, errors = [], []
+ for i in xrange(thread_count):
+ thread = threading.Thread(target=self.increment,
+ args=(iterations,),
+ kwargs={'errors': errors,
+ 'update_style': True})
+ thread.start()
+ threads.append(thread)
+ for thread in threads:
+ thread.join()
+
+ for e in errors:
+ sys.stderr.write("Failure: %s\n" % e)
+
+ self.assert_(len(errors) == 0)
+
+ sel = counters.select(whereclause=counters.c.counter_id==1)
+ final = db.execute(sel).fetchone()
+ self.assert_(final['counter_value'] == iterations * thread_count)
+
+ def overlap(self, ids, errors, update_style):
+ sel = counters.select(for_update=update_style,
+ whereclause=counters.c.counter_id.in_(*ids))
+ con = testbase.db.connect()
+ trans = con.begin()
+ try:
+ rows = con.execute(sel).fetchall()
+ time.sleep(0.25)
+ trans.commit()
+ except Exception, e:
+ trans.rollback()
+ errors.append(e)
+
+ def _threaded_overlap(self, thread_count, groups, update_style=True, pool=5):
+ db = testbase.db
+ for cid in range(pool - 1):
+ db.execute(counters.insert(), counter_id=cid + 1, counter_value=0)
+
+ errors, threads = [], []
+ for i in xrange(thread_count):
+ thread = threading.Thread(target=self.overlap,
+ args=(groups.pop(0), errors, update_style))
+ thread.start()
+ threads.append(thread)
+ for thread in threads:
+ thread.join()
+
+ return errors
+
+ @testing.supported('mysql', 'oracle', 'postgres')
+ def testqueued_select(self):
+ """Simple SELECT FOR UPDATE conflict test"""
+
+ errors = self._threaded_overlap(2, [(1,2,3),(3,4,5)])
+ for e in errors:
+ sys.stderr.write("Failure: %s\n" % e)
+ self.assert_(len(errors) == 0)
+
+ @testing.supported('oracle', 'postgres')
+ def testnowait_select(self):
+ """Simple SELECT FOR UPDATE NOWAIT conflict test"""
+
+ errors = self._threaded_overlap(2, [(1,2,3),(3,4,5)],
+ update_style='nowait')
+ self.assert_(len(errors) != 0)
if __name__ == "__main__":
testbase.main()
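
The additions above cover three 0.4 transaction features: begin_nested() for SAVEPOINTs, two-phase commit via begin_twophase()/prepare() with recover_twophase()/commit_prepared() for recovery, and SELECT FOR UPDATE via select(for_update=True) or for_update='nowait'. A compressed sketch of the savepoint case; the URL is hypothetical, and a backend with savepoint support is assumed (the tests skip sqlite):

```python
# Sketch of begin_nested() semantics, per the tests above.
from sqlalchemy import *

db = create_engine('postgres://scott:tiger@localhost/test')  # hypothetical URL
meta = MetaData(db)
users = Table('query_users', meta,
              Column('user_id', INT, primary_key=True),
              Column('user_name', VARCHAR(20)))
meta.create_all()

conn = db.connect()
trans = conn.begin()
conn.execute(users.insert(), user_id=1, user_name='user1')
trans2 = conn.begin_nested()            # SAVEPOINT
conn.execute(users.insert(), user_id=2, user_name='user2')
trans2.rollback()                       # undoes only work since the savepoint
conn.execute(users.insert(), user_id=3, user_name='user3')
trans.commit()                          # rows 1 and 3 survive
conn.close()
meta.drop_all()
```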
diff --git a/test/ext/activemapper.py b/test/ext/activemapper.py
index ebb832fdc..e28c72cd7 100644
--- a/test/ext/activemapper.py
+++ b/test/ext/activemapper.py
@@ -1,16 +1,18 @@
import testbase
+from datetime import datetime
+
from sqlalchemy.ext.activemapper import ActiveMapper, column, one_to_many, one_to_one, many_to_many, objectstore
-from sqlalchemy import and_, or_, clear_mappers, backref, create_session, exceptions
+from sqlalchemy import and_, or_, exceptions
from sqlalchemy import ForeignKey, String, Integer, DateTime, Table, Column
-from datetime import datetime
-import sqlalchemy
-
+from sqlalchemy.orm import clear_mappers, backref, create_session, class_mapper
import sqlalchemy.ext.activemapper as activemapper
+import sqlalchemy
+from testlib import *
-class testcase(testbase.PersistTest):
+class testcase(PersistTest):
def setUpAll(self):
- sqlalchemy.clear_mappers()
+ clear_mappers()
objectstore.clear()
global Person, Preferences, Address
@@ -133,7 +135,7 @@ class testcase(testbase.PersistTest):
objectstore.flush()
objectstore.clear()
- results = Person.select()
+ results = Person.query.select()
self.assertEquals(len(results), 1)
@@ -142,30 +144,30 @@ class testcase(testbase.PersistTest):
self.assertEquals(len(person.addresses), 2)
self.assertEquals(person.addresses[0].postal_code, '30338')
- @testbase.unsupported('mysql')
+ @testing.unsupported('mysql')
def test_update(self):
p1 = self.create_person_one()
objectstore.flush()
objectstore.clear()
- person = Person.select()[0]
+ person = Person.query.select()[0]
person.gender = 'F'
objectstore.flush()
objectstore.clear()
self.assertEquals(person.row_version, 2)
- person = Person.select()[0]
+ person = Person.query.select()[0]
person.gender = 'M'
objectstore.flush()
objectstore.clear()
self.assertEquals(person.row_version, 3)
#TODO: check that a concurrent modification raises exception
- p1 = Person.select()[0]
+ p1 = Person.query.select()[0]
s1 = objectstore.session
s2 = create_session()
objectstore.context.current = s2
- p2 = Person.select()[0]
+ p2 = Person.query.select()[0]
p1.first_name = "jack"
p2.first_name = "ed"
objectstore.flush()
@@ -185,14 +187,14 @@ class testcase(testbase.PersistTest):
objectstore.flush()
objectstore.clear()
- results = Person.select()
+ results = Person.query.select()
self.assertEquals(len(results), 1)
results[0].delete()
objectstore.flush()
objectstore.clear()
- results = Person.select()
+ results = Person.query.select()
self.assertEquals(len(results), 0)
@@ -204,7 +206,7 @@ class testcase(testbase.PersistTest):
objectstore.clear()
# select and make sure we get back two results
- people = Person.select()
+ people = Person.query.select()
self.assertEquals(len(people), 2)
# make sure that our backwards relationships work
@@ -212,7 +214,7 @@ class testcase(testbase.PersistTest):
self.assertEquals(people[1].addresses[0].person.id, p2.id)
# try a more complex select
- results = Person.select(
+ results = Person.query.select(
or_(
and_(
Address.c.person_id == Person.c.id,
@@ -253,17 +255,16 @@ class testcase(testbase.PersistTest):
objectstore.flush()
objectstore.clear()
- results = Person.select(
- Address.c.postal_code.like('30075') &
- Person.join_to('addresses')
+ results = Person.query.join('addresses').select(
+ Address.c.postal_code.like('30075')
)
self.assertEquals(len(results), 1)
- self.assertEquals(Person.count(), 2)
+ self.assertEquals(Person.query.count(), 2)
-class testmanytomany(testbase.PersistTest):
+class testmanytomany(PersistTest):
def setUpAll(self):
- sqlalchemy.clear_mappers()
+ clear_mappers()
objectstore.clear()
global secondarytable, foo, baz
secondarytable = Table("secondarytable",
@@ -299,8 +300,8 @@ class testmanytomany(testbase.PersistTest):
objectstore.flush()
objectstore.clear()
- foo1 = foo.get_by(name='foo1')
- baz1 = baz.get_by(name='baz1')
+ foo1 = foo.query.get_by(name='foo1')
+ baz1 = baz.query.get_by(name='baz1')
# Just checking ...
assert (foo1.name == 'foo1')
@@ -313,14 +314,12 @@ class testmanytomany(testbase.PersistTest):
# Optimistically based on activemapper one_to_many test, try to append
# baz1 to foo1.bazrel - (AttributeError: 'foo' object has no attribute 'bazrel')
- print sqlalchemy.class_mapper(foo).props
- print sqlalchemy.class_mapper(baz).props
foo1.bazrel.append(baz1)
assert (foo1.bazrel == [baz1])
-class testselfreferential(testbase.PersistTest):
+class testselfreferential(PersistTest):
def setUpAll(self):
- sqlalchemy.clear_mappers()
+ clear_mappers()
objectstore.clear()
global TreeNode
class TreeNode(activemapper.ActiveMapper):
@@ -343,15 +342,15 @@ class testselfreferential(testbase.PersistTest):
objectstore.flush()
objectstore.clear()
- t = TreeNode.get_by(name='node1')
+ t = TreeNode.query.get_by(name='node1')
assert (t.name == 'node1')
assert (t.children[0].name == 'node2')
assert (t.children[1].name == 'node3')
assert (t.children[1].parent is t)
objectstore.clear()
- t = TreeNode.get_by(name='node3')
- assert (t.parent is TreeNode.get_by(name='node1'))
+ t = TreeNode.query.get_by(name='node3')
+ assert (t.parent is TreeNode.query.get_by(name='node1'))
if __name__ == '__main__':
testbase.main()
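
The recurring edit in this file is the move from class-level selection methods to the query attribute on ActiveMapper classes. A sketch of the new spellings, reusing the Person and Address classes defined in the test above (so this fragment is not standalone):

```python
# Sketch: 0.4 activemapper selection goes through .query.
# Assumes the Person/Address ActiveMapper classes from the test above.
results = Person.query.select()                    # was Person.select()
total = Person.query.count()                       # was Person.count()
matches = Person.query.join('addresses').select(   # was Person.join_to('addresses')
    Address.c.postal_code.like('30075'))
```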
diff --git a/test/ext/alltests.py b/test/ext/alltests.py
index 713601c3b..589f0f68f 100644
--- a/test/ext/alltests.py
+++ b/test/ext/alltests.py
@@ -3,7 +3,6 @@ import unittest, doctest
def suite():
unittest_modules = ['ext.activemapper',
- 'ext.selectresults',
'ext.assignmapper',
'ext.orderinglist',
'ext.associationproxy']
diff --git a/test/ext/assignmapper.py b/test/ext/assignmapper.py
index 650994987..31b3dd576 100644
--- a/test/ext/assignmapper.py
+++ b/test/ext/assignmapper.py
@@ -1,12 +1,13 @@
-from testbase import PersistTest, AssertMixin
import testbase
from sqlalchemy import *
-
+from sqlalchemy.orm import create_session, clear_mappers, relation, class_mapper
from sqlalchemy.ext.assignmapper import assign_mapper
from sqlalchemy.ext.sessioncontext import SessionContext
+from testlib import *
+
-class OverrideAttributesTest(PersistTest):
+class AssignMapperTest(PersistTest):
def setUpAll(self):
global metadata, table, table2
metadata = MetaData(testbase.db)
@@ -18,25 +19,18 @@ class OverrideAttributesTest(PersistTest):
Column('someid', None, ForeignKey('sometable.id'))
)
metadata.create_all()
- def tearDownAll(self):
- metadata.drop_all()
- def tearDown(self):
- clear_mappers()
+
def setUp(self):
- pass
- def test_override_attributes(self):
+ global SomeObject, SomeOtherObject, ctx
class SomeObject(object):pass
class SomeOtherObject(object):pass
ctx = SessionContext(create_session)
assign_mapper(ctx, SomeObject, table, properties={
- # this is the current workaround for class attribute name/collection collision: specify collection_class
- # explicitly. when we do away with class attributes specifying collection classes, this wont be
- # needed anymore.
- 'options':relation(SomeOtherObject, collection_class=list)
+ 'options':relation(SomeOtherObject)
})
assign_mapper(ctx, SomeOtherObject, table2)
- class_mapper(SomeObject)
+
s = SomeObject()
s.id = 1
s.data = 'hello'
@@ -44,8 +38,42 @@ class OverrideAttributesTest(PersistTest):
s.options.append(sso)
ctx.current.flush()
ctx.current.clear()
+
+ def tearDownAll(self):
+ metadata.drop_all()
+ def tearDown(self):
+ for table in metadata.table_iterator(reverse=True):
+ table.delete().execute()
+ clear_mappers()
+
+ def test_override_attributes(self):
+
+ sso = SomeOtherObject.query().first()
- assert SomeObject.get_by(id=s.id).options[0].id == sso.id
+ assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id
+
+ s2 = SomeObject(someid=12)
+ s3 = SomeOtherObject(someid=123, bogus=345)
+
+ class ValidatedOtherObject(object):pass
+ assign_mapper(ctx, ValidatedOtherObject, table2, validate=True)
+
+ v1 = ValidatedOtherObject(someid=12)
+ try:
+ v2 = ValidatedOtherObject(someid=12, bogus=345)
+ assert False
+ except exceptions.ArgumentError:
+ pass
+
+ def test_dont_clobber_methods(self):
+ class MyClass(object):
+ def expunge(self):
+ return "an expunge !"
+
+ assign_mapper(ctx, MyClass, table2)
+
+ assert MyClass().expunge() == "an expunge !"
+
if __name__ == '__main__':
testbase.main()
diff --git a/test/ext/associationproxy.py b/test/ext/associationproxy.py
index 3b18581bb..f602871c2 100644
--- a/test/ext/associationproxy.py
+++ b/test/ext/associationproxy.py
@@ -1,17 +1,19 @@
-from testbase import PersistTest
-import sqlalchemy.util as util
-import unittest
import testbase
+
from sqlalchemy import *
+from sqlalchemy.orm import *
+from sqlalchemy.orm.collections import collection
from sqlalchemy.ext.associationproxy import *
+from testlib import *
-db = testbase.db
class DictCollection(dict):
+ @collection.appender
def append(self, obj):
self[obj.foo] = obj
- def __iter__(self):
- return self.itervalues()
+ @collection.remover
+ def remove(self, obj):
+ del self[obj.foo]
class SetCollection(set):
pass
@@ -22,18 +24,20 @@ class ListCollection(list):
class ObjectCollection(object):
def __init__(self):
self.values = list()
+ @collection.appender
def append(self, obj):
self.values.append(obj)
+ @collection.remover
+ def remove(self, obj):
+ self.values.remove(obj)
def __iter__(self):
return iter(self.values)
- def clear(self):
- self.values.clear()
class _CollectionOperations(PersistTest):
def setUp(self):
collection_class = self.collection_class
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
parents_table = Table('Parent', metadata,
Column('id', Integer, primary_key=True),
@@ -131,9 +135,57 @@ class _CollectionOperations(PersistTest):
self.assert_(len(p1._children) == 3)
self.assert_(len(p1.children) == 3)
+ popped = p1.children.pop()
+ self.assert_(len(p1.children) == 2)
+ self.assert_(popped not in p1.children)
+ p1 = self.roundtrip(p1)
+ self.assert_(len(p1.children) == 2)
+ self.assert_(popped not in p1.children)
+
+ p1.children[1] = 'changed-in-place'
+ self.assert_(p1.children[1] == 'changed-in-place')
+ inplace_id = p1._children[1].id
+ p1 = self.roundtrip(p1)
+ self.assert_(p1.children[1] == 'changed-in-place')
+ assert p1._children[1].id == inplace_id
+
+ p1.children.append('changed-in-place')
+ self.assert_(p1.children.count('changed-in-place') == 2)
+
+ p1.children.remove('changed-in-place')
+ self.assert_(p1.children.count('changed-in-place') == 1)
+
+ p1 = self.roundtrip(p1)
+ self.assert_(p1.children.count('changed-in-place') == 1)
+
p1._children = []
self.assert_(len(p1.children) == 0)
+ after = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
+ p1.children = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
+ self.assert_(len(p1.children) == 10)
+ self.assert_([c.name for c in p1._children] == after)
+
+ p1.children[2:6] = ['x'] * 4
+ after = ['a', 'b', 'x', 'x', 'x', 'x', 'g', 'h', 'i', 'j']
+ self.assert_(p1.children == after)
+ self.assert_([c.name for c in p1._children] == after)
+
+ p1.children[2:6] = ['y']
+ after = ['a', 'b', 'y', 'g', 'h', 'i', 'j']
+ self.assert_(p1.children == after)
+ self.assert_([c.name for c in p1._children] == after)
+
+ p1.children[2:3] = ['z'] * 4
+ after = ['a', 'b', 'z', 'z', 'z', 'z', 'g', 'h', 'i', 'j']
+ self.assert_(p1.children == after)
+ self.assert_([c.name for c in p1._children] == after)
+
+ p1.children[2::2] = ['O'] * 4
+ after = ['a', 'b', 'O', 'z', 'O', 'z', 'O', 'h', 'O', 'j']
+ self.assert_(p1.children == after)
+ self.assert_([c.name for c in p1._children] == after)
+
class DefaultTest(_CollectionOperations):
def __init__(self, *args, **kw):
super(DefaultTest, self).__init__(*args, **kw)
@@ -218,12 +270,27 @@ class CustomDictTest(DictTest):
self.assert_(len(p1._children) == 3)
self.assert_(len(p1.children) == 3)
- p1.children['d'] = 'new d'
- assert p1.children['d'] == 'new d'
+ p1.children['e'] = 'changed-in-place'
+ self.assert_(p1.children['e'] == 'changed-in-place')
+ inplace_id = p1._children['e'].id
+ p1 = self.roundtrip(p1)
+ self.assert_(p1.children['e'] == 'changed-in-place')
+ self.assert_(p1._children['e'].id == inplace_id)
p1._children = {}
self.assert_(len(p1.children) == 0)
+ try:
+ p1._children = []
+ self.assert_(False)
+ except exceptions.ArgumentError:
+ self.assert_(True)
+
+ try:
+ p1._children = None
+ self.assert_(False)
+ except exceptions.ArgumentError:
+ self.assert_(True)
class SetTest(_CollectionOperations):
def __init__(self, *args, **kw):
@@ -239,7 +306,7 @@ class SetTest(_CollectionOperations):
self.assert_(not p1.children)
ch1 = Child('regular')
- p1._children.append(ch1)
+ p1._children.add(ch1)
self.assert_(ch1 in p1._children)
self.assert_(len(p1._children) == 1)
@@ -256,7 +323,8 @@ class SetTest(_CollectionOperations):
self.assert_(len(p1.children) == 2)
self.assert_(len(p1._children) == 2)
- self.assert_(set([o.name for o in p1._children]) == set(['regular', 'proxied']))
+ self.assert_(set([o.name for o in p1._children]) ==
+ set(['regular', 'proxied']))
ch2 = None
for o in p1._children:
@@ -322,9 +390,22 @@ class SetTest(_CollectionOperations):
p1 = self.roundtrip(p1)
self.assert_(p1.children == set(['c']))
- p1._children = []
+ p1._children = set()
self.assert_(len(p1.children) == 0)
+ try:
+ p1._children = []
+ self.assert_(False)
+ except exceptions.ArgumentError:
+ self.assert_(True)
+
+ try:
+ p1._children = None
+ self.assert_(False)
+ except exceptions.ArgumentError:
+ self.assert_(True)
+
+
def test_set_comparisons(self):
Parent, Child = self.Parent, self.Child
@@ -393,14 +474,7 @@ class SetTest(_CollectionOperations):
print 'want', repr(control)
print 'got', repr(p.children)
raise
-
- # workaround for bug #548
- def test_set_pop(self):
- Parent, Child = self.Parent, self.Child
- p = Parent('p1')
- p.children.add('a')
- p.children.pop()
- self.assert_(True)
+
class CustomSetTest(SetTest):
def __init__(self, *args, **kw):
@@ -434,7 +508,7 @@ class CustomObjectTest(_CollectionOperations):
class ScalarTest(PersistTest):
def test_scalar_proxy(self):
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
parents_table = Table('Parent', metadata,
Column('id', Integer, primary_key=True),
@@ -550,7 +624,7 @@ class ScalarTest(PersistTest):
class LazyLoadTest(PersistTest):
def setUp(self):
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
parents_table = Table('Parent', metadata,
Column('id', Integer, primary_key=True),
@@ -606,7 +680,7 @@ class LazyLoadTest(PersistTest):
# Is there a better way to ensure that the association_proxy
# didn't convert a lazy load to an eager load? This does work though.
self.assert_('_children' not in p.__dict__)
- self.assert_(len(p._children.data) == 3)
+ self.assert_(len(p._children) == 3)
self.assert_('_children' in p.__dict__)
def test_eager_list(self):
@@ -622,7 +696,7 @@ class LazyLoadTest(PersistTest):
p = self.roundtrip(p)
self.assert_('_children' in p.__dict__)
- self.assert_(len(p._children.data) == 3)
+ self.assert_(len(p._children) == 3)
def test_lazy_scalar(self):
Parent, Child = self.Parent, self.Child
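
These assertions check len(p._children) directly because the 0.4 attribute system hands back the instrumented collection itself; there is no longer a separate .data list hanging off a wrapper. For orientation, a minimal sketch of the kind of mapping these proxy tests assume, with class and table names taken from the tests and children_table assumed:

    from sqlalchemy.ext.associationproxy import association_proxy

    class Parent(object):
        # '_children' is the relation-backed collection; 'children' proxies
        # the 'name' attribute of each Child through it.
        children = association_proxy('_children', 'name')

    mapper(Parent, parents_table, properties={
        '_children': relation(Child, lazy=True),
    })
    mapper(Child, children_table)
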
diff --git a/test/ext/legacy_objectstore.py b/test/ext/legacy_objectstore.py
deleted file mode 100644
index 3aa99a1ae..000000000
--- a/test/ext/legacy_objectstore.py
+++ /dev/null
@@ -1,113 +0,0 @@
-from testbase import PersistTest, AssertMixin
-import unittest, sys, os
-from sqlalchemy import *
-import StringIO
-import testbase
-
-from tables import *
-import tables
-
-install_mods('legacy_session')
-
-
-class LegacySessionTest(AssertMixin):
- def setUpAll(self):
- db.echo = False
- users.create()
- db.echo = testbase.echo
- def tearDownAll(self):
- db.echo = False
- users.drop()
- db.echo = testbase.echo
- def setUp(self):
- objectstore.get_session().clear()
- clear_mappers()
- tables.user_data()
- #db.echo = "debug"
- def tearDown(self):
- tables.delete_user_data()
-
- def test_nested_begin_commit(self):
- """tests that nesting objectstore transactions with multiple commits
- affects only the outermost transaction"""
- class User(object):pass
- m = mapper(User, users)
- def name_of(id):
- return users.select(users.c.user_id == id).execute().fetchone().user_name
- name1 = "Oliver Twist"
- name2 = 'Mr. Bumble'
- self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
- self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
- s = objectstore.get_session()
- trans = s.begin()
- trans2 = s.begin()
- m.get(7).user_name = name1
- trans3 = s.begin()
- m.get(8).user_name = name2
- trans3.commit()
- s.commit() # should do nothing
- self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
- self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
- trans2.commit()
- s.commit() # should do nothing
- self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
- self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
- trans.commit()
- self.assert_(name_of(7) == name1, msg="user_name should be %s" % name1)
- self.assert_(name_of(8) == name2, msg="user_name should be %s" % name2)
-
- def test_nested_rollback(self):
- """tests that nesting objectstore transactions with a rollback inside
- affects only the outermost transaction"""
- class User(object):pass
- m = mapper(User, users)
- def name_of(id):
- return users.select(users.c.user_id == id).execute().fetchone().user_name
- name1 = "Oliver Twist"
- name2 = 'Mr. Bumble'
- self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
- self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
- s = objectstore.get_session()
- trans = s.begin()
- trans2 = s.begin()
- m.get(7).user_name = name1
- trans3 = s.begin()
- m.get(8).user_name = name2
- trans3.rollback()
- self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
- self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
- trans2.commit()
- self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
- self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
- trans.commit()
- self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
- self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
-
- def test_true_nested(self):
- """tests creating a new Session inside a database transaction, in
- conjunction with an engine-level nested transaction, which uses
- a second connection in order to achieve a nested transaction that commits, inside
- of another engine session that rolls back."""
-# testbase.db.echo='debug'
- class User(object):
- pass
- testbase.db.begin()
- try:
- m = mapper(User, users)
- name1 = "Oliver Twist"
- name2 = 'Mr. Bumble'
- m.get(7).user_name = name1
- s = objectstore.Session(nest_on=testbase.db)
- m.using(s).get(8).user_name = name2
- s.commit()
- objectstore.commit()
- testbase.db.rollback()
- except:
- testbase.db.rollback()
- raise
- objectstore.clear()
- self.assert_(m.get(8).user_name == name2)
- self.assert_(m.get(7).user_name != name1)
-
-if __name__ == "__main__":
- testbase.main()
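
The deleted suite pinned down one rule of the legacy objectstore: with begin() calls nested, inner commits are no-ops and only the outermost commit makes the work permanent. An illustrative, SQLAlchemy-free sketch of that rule:

    class NestedTransaction(object):
        def __init__(self):
            self.depth = 0
            self.committed = False
        def begin(self):
            self.depth += 1
            return self
        def commit(self):
            # inner commits only unwind the nesting; the outermost one commits
            self.depth -= 1
            if self.depth == 0:
                self.committed = True

    t = NestedTransaction()
    outer = t.begin()
    inner = t.begin()
    inner.commit()
    assert not t.committed    # inner commit did nothing
    outer.commit()
    assert t.committed        # outermost commit took effect
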
diff --git a/test/ext/orderinglist.py b/test/ext/orderinglist.py
index 6dcf057d4..d16e20da7 100644
--- a/test/ext/orderinglist.py
+++ b/test/ext/orderinglist.py
@@ -1,11 +1,10 @@
-from testbase import PersistTest
-import sqlalchemy.util as util
-import unittest, sys, os
import testbase
+
from sqlalchemy import *
+from sqlalchemy.orm import *
from sqlalchemy.ext.orderinglist import *
+from testlib import *
-db = testbase.db
metadata = None
# order in whole steps
@@ -52,7 +51,7 @@ class OrderingListTest(PersistTest):
global metadata, slides_table, bullets_table, Slide, Bullet
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
slides_table = Table('test_Slides', metadata,
Column('id', Integer, primary_key=True),
Column('name', String))
@@ -297,43 +296,7 @@ class OrderingListTest(PersistTest):
self.assert_(srt.bullets[i].position == i)
self.assert_(srt.bullets[i].text == text)
- def test_replace1(self):
- self._setup(ordering_list('position'))
-
- s1 = Slide('Slide #1')
- s1.bullets = [ Bullet('1'), Bullet('2'), Bullet('3') ]
-
- self.assert_(len(s1.bullets) == 3)
- self.assert_(s1.bullets[2].position == 2)
-
- session = create_session()
- session.save(s1)
- session.flush()
-
- new_bullet = Bullet('new 2')
- self.assert_(new_bullet.position is None)
-
- # naive replacement, no database deletion should occur
- # with current InstrumentedList __setitem__ semantics
- s1.bullets[1] = new_bullet
-
- self.assert_(new_bullet.position == 1)
- self.assert_(len(s1.bullets) == 3)
-
- id = s1.id
-
- session.flush()
- session.clear()
-
- srt = session.query(Slide).get(id)
-
- self.assert_(srt.bullets)
- self.assert_(len(srt.bullets) == 4)
-
- self.assert_(srt.bullets[1].text == '2')
- self.assert_(srt.bullets[2].text == 'new 2')
-
- def test_replace2(self):
+ def test_replace(self):
self._setup(ordering_list('position'))
s1 = Slide('Slide #1')
@@ -350,7 +313,7 @@ class OrderingListTest(PersistTest):
self.assert_(new_bullet.position is None)
# mark existing bullet as db-deleted before replacement.
- session.delete(s1.bullets[1])
+ #session.delete(s1.bullets[1])
s1.bullets[1] = new_bullet
self.assert_(new_bullet.position == 1)
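
The surviving test_replace relies on ordering_list('position') renumbering entries when an element is assigned in place. A sketch of the mapping that _setup presumably builds, using names from this test file:

    from sqlalchemy.ext.orderinglist import ordering_list

    mapper(Slide, slides_table, properties={
        'bullets': relation(Bullet,
                            collection_class=ordering_list('position'),
                            order_by=[bullets_table.c.position]),
    })
    mapper(Bullet, bullets_table)

    s1 = Slide('Slide #1')
    s1.bullets = [Bullet('1'), Bullet('2'), Bullet('3')]
    s1.bullets[1] = Bullet('new 2')   # picks up position == 1 automatically
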
diff --git a/test/ext/selectresults.py b/test/ext/selectresults.py
deleted file mode 100644
index 1ec724c3d..000000000
--- a/test/ext/selectresults.py
+++ /dev/null
@@ -1,239 +0,0 @@
-from testbase import PersistTest, AssertMixin
-import testbase
-import tables
-
-from sqlalchemy import *
-
-from sqlalchemy.ext.selectresults import SelectResultsExt, SelectResults
-
-class Foo(object):
- pass
-
-class SelectResultsTest(PersistTest):
- def setUpAll(self):
- self.install_threadlocal()
- global foo, metadata
- metadata = MetaData(testbase.db)
- foo = Table('foo', metadata,
- Column('id', Integer, Sequence('foo_id_seq'), primary_key=True),
- Column('bar', Integer),
- Column('range', Integer))
-
- assign_mapper(Foo, foo, extension=SelectResultsExt())
- metadata.create_all()
- for i in range(100):
- Foo(bar=i, range=i%10)
- objectstore.flush()
-
- def setUp(self):
- self.query = Query(Foo)
- self.orig = self.query.select_whereclause()
- self.res = self.query.select()
-
- def tearDownAll(self):
- metadata.drop_all()
- self.uninstall_threadlocal()
- clear_mappers()
-
- def test_selectby(self):
- res = self.query.select_by(range=5)
- assert res.order_by([Foo.c.bar])[0].bar == 5
- assert res.order_by([desc(Foo.c.bar)])[0].bar == 95
-
- @testbase.unsupported('mssql')
- def test_slice(self):
- assert self.res[1] == self.orig[1]
- assert list(self.res[10:20]) == self.orig[10:20]
- assert list(self.res[10:]) == self.orig[10:]
- assert list(self.res[:10]) == self.orig[:10]
- assert list(self.res[:10]) == self.orig[:10]
- assert list(self.res[10:40:3]) == self.orig[10:40:3]
- assert list(self.res[-5:]) == self.orig[-5:]
- assert self.res[10:20][5] == self.orig[10:20][5]
-
- @testbase.supported('mssql')
- def test_slice_mssql(self):
- assert list(self.res[:10]) == self.orig[:10]
- assert list(self.res[:10]) == self.orig[:10]
-
- def test_aggregate(self):
- assert self.res.count() == 100
- assert self.res.filter(foo.c.bar<30).min(foo.c.bar) == 0
- assert self.res.filter(foo.c.bar<30).max(foo.c.bar) == 29
-
- @testbase.unsupported('mysql')
- def test_aggregate_1(self):
- # this one fails in mysql as the result comes back as a string
- assert self.res.filter(foo.c.bar<30).sum(foo.c.bar) == 435
-
- @testbase.unsupported('postgres', 'mysql', 'firebird', 'mssql')
- def test_aggregate_2(self):
- assert self.res.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5
-
- @testbase.supported('postgres', 'mysql', 'firebird', 'mssql')
- def test_aggregate_2_int(self):
- assert int(self.res.filter(foo.c.bar<30).avg(foo.c.bar)) == 14
-
- def test_filter(self):
- assert self.res.count() == 100
- assert self.res.filter(Foo.c.bar < 30).count() == 30
- res2 = self.res.filter(Foo.c.bar < 30).filter(Foo.c.bar > 10)
- assert res2.count() == 19
-
- def test_options(self):
- class ext1(MapperExtension):
- def populate_instance(self, mapper, selectcontext, row, instance, identitykey, isnew):
- instance.TEST = "hello world"
- return EXT_PASS
- objectstore.clear()
- assert self.res.options(extension(ext1()))[0].TEST == "hello world"
-
- def test_order_by(self):
- assert self.res.order_by([Foo.c.bar])[0].bar == 0
- assert self.res.order_by([desc(Foo.c.bar)])[0].bar == 99
-
- def test_offset(self):
- assert list(self.res.order_by([Foo.c.bar]).offset(10))[0].bar == 10
-
- def test_offset(self):
- assert len(list(self.res.limit(10))) == 10
-
-class Obj1(object):
- pass
-class Obj2(object):
- pass
-
-class SelectResultsTest2(PersistTest):
- def setUpAll(self):
- self.install_threadlocal()
- global metadata, table1, table2
- metadata = MetaData(testbase.db)
- table1 = Table('Table1', metadata,
- Column('id', Integer, primary_key=True),
- )
- table2 = Table('Table2', metadata,
- Column('t1id', Integer, ForeignKey("Table1.id"), primary_key=True),
- Column('num', Integer, primary_key=True),
- )
- assign_mapper(Obj1, table1, extension=SelectResultsExt())
- assign_mapper(Obj2, table2, extension=SelectResultsExt())
- metadata.create_all()
- table1.insert().execute({'id':1},{'id':2},{'id':3},{'id':4})
- table2.insert().execute({'num':1,'t1id':1},{'num':2,'t1id':1},{'num':3,'t1id':1},\
-{'num':4,'t1id':2},{'num':5,'t1id':2},{'num':6,'t1id':3})
-
- def setUp(self):
- self.query = Query(Obj1)
- #self.orig = self.query.select_whereclause()
- #self.res = self.query.select()
-
- def tearDownAll(self):
- metadata.drop_all()
- self.uninstall_threadlocal()
- clear_mappers()
-
- def test_distinctcount(self):
- res = self.query.select()
- assert res.count() == 4
- res = self.query.select(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1))
- assert res.count() == 3
- res = self.query.select(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1), distinct=True)
- self.assertEqual(res.count(), 1)
-
-class RelationsTest(AssertMixin):
- def setUpAll(self):
- tables.create()
- tables.data()
- def tearDownAll(self):
- tables.drop()
- def tearDown(self):
- clear_mappers()
- def test_jointo(self):
- """test the join_to and outerjoin_to functions on SelectResults"""
- mapper(tables.User, tables.users, properties={
- 'orders':relation(mapper(tables.Order, tables.orders, properties={
- 'items':relation(mapper(tables.Item, tables.orderitems))
- }))
- })
- session = create_session()
- query = SelectResults(session.query(tables.User))
- x = query.join_to('orders').join_to('items').select(tables.Item.c.item_id==2)
- print x.compile()
- self.assert_result(list(x), tables.User, tables.user_result[2])
- def test_outerjointo(self):
- """test the join_to and outerjoin_to functions on SelectResults"""
- mapper(tables.User, tables.users, properties={
- 'orders':relation(mapper(tables.Order, tables.orders, properties={
- 'items':relation(mapper(tables.Item, tables.orderitems))
- }))
- })
- session = create_session()
- query = SelectResults(session.query(tables.User))
- x = query.outerjoin_to('orders').outerjoin_to('items').select(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
- print x.compile()
- self.assert_result(list(x), tables.User, *tables.user_result[1:3])
- def test_outerjointo_count(self):
- """test the join_to and outerjoin_to functions on SelectResults"""
- mapper(tables.User, tables.users, properties={
- 'orders':relation(mapper(tables.Order, tables.orders, properties={
- 'items':relation(mapper(tables.Item, tables.orderitems))
- }))
- })
- session = create_session()
- query = SelectResults(session.query(tables.User))
- x = query.outerjoin_to('orders').outerjoin_to('items').select(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count()
- assert x==2
- def test_from(self):
- mapper(tables.User, tables.users, properties={
- 'orders':relation(mapper(tables.Order, tables.orders, properties={
- 'items':relation(mapper(tables.Item, tables.orderitems))
- }))
- })
- session = create_session()
- query = SelectResults(session.query(tables.User))
- x = query.select_from([tables.users.outerjoin(tables.orders).outerjoin(tables.orderitems)]).\
- filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
- print x.compile()
- self.assert_result(list(x), tables.User, *tables.user_result[1:3])
-
-
-class CaseSensitiveTest(PersistTest):
- def setUpAll(self):
- self.install_threadlocal()
- global metadata, table1, table2
- metadata = MetaData(testbase.db)
- table1 = Table('Table1', metadata,
- Column('ID', Integer, primary_key=True),
- )
- table2 = Table('Table2', metadata,
- Column('T1ID', Integer, ForeignKey("Table1.ID"), primary_key=True),
- Column('NUM', Integer, primary_key=True),
- )
- assign_mapper(Obj1, table1, extension=SelectResultsExt())
- assign_mapper(Obj2, table2, extension=SelectResultsExt())
- metadata.create_all()
- table1.insert().execute({'ID':1},{'ID':2},{'ID':3},{'ID':4})
- table2.insert().execute({'NUM':1,'T1ID':1},{'NUM':2,'T1ID':1},{'NUM':3,'T1ID':1},\
-{'NUM':4,'T1ID':2},{'NUM':5,'T1ID':2},{'NUM':6,'T1ID':3})
-
- def setUp(self):
- self.query = Query(Obj1)
- #self.orig = self.query.select_whereclause()
- #self.res = self.query.select()
-
- def tearDownAll(self):
- metadata.drop_all()
- self.uninstall_threadlocal()
- clear_mappers()
-
- def test_distinctcount(self):
- res = self.query.select()
- assert res.count() == 4
- res = self.query.select(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1))
- assert res.count() == 3
- res = self.query.select(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1), distinct=True)
- self.assertEqual(res.count(), 1)
-
-
-if __name__ == "__main__":
- testbase.main()
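
With SelectResults gone, its chained operations live on Query itself in 0.4. A rough before/after mapping, using the generative calls that appear in the rewritten tests elsewhere in this commit (entity names hypothetical):

    # 0.3, via the removed extension:
    #   SelectResults(session.query(User)).join_to('orders') \
    #       .filter(...).distinct().order_by(...).count()
    # 0.4, directly on Query:
    q = session.query(User)
    q = q.join('orders').filter(Order.c.order_id == 2)
    q = q.distinct().order_by([desc(User.c.user_name)])
    print q.count()
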
diff --git a/test/ext/wsgi_test.py b/test/ext/wsgi_test.py
deleted file mode 100644
index 1330f88b6..000000000
--- a/test/ext/wsgi_test.py
+++ /dev/null
@@ -1,122 +0,0 @@
-"""Interactive wsgi test
-
-Small WSGI application that uses a table and mapper defined at the module
-level, with per-application uris enabled by the ProxyEngine.
-
-Requires the wsgiutils package from:
-
-http://www.owlfish.com/software/wsgiutils/
-
-Run the script with python wsgi_test.py, then visit http://localhost:8080/a
-and http://localhost:8080/b with a browser. You should see two distinct lists
-of colors.
-"""
-
-from sqlalchemy import *
-from sqlalchemy.ext.proxy import ProxyEngine
-from wsgiutils import wsgiServer
-
-engine = ProxyEngine()
-
-colors = Table('colors', engine,
- Column('id', Integer, primary_key=True),
- Column('name', String(32)),
- Column('hex', String(6)))
-
-class Color(object):
- pass
-
-assign_mapper(Color, colors)
-
-data = { 'a': (('fff','white'), ('aaa','gray'), ('000','black'),
- ('f00', 'red'), ('0f0', 'green')),
- 'b': (('00f','blue'), ('ff0', 'yellow'), ('0ff','purple')) }
-
-db_uri = { 'a': 'sqlite://filename=wsgi_db_a.db',
- 'b': 'sqlite://filename=wsgi_db_b.db' }
-
-def app(dataset):
- print '... connecting to database %s: %s' % (dataset, db_uri[dataset])
- engine.connect(db_uri[dataset], echo=True, echo_pool=True)
- colors.create()
-
- print '... populating data into %s' % db_uri[dataset]
- for hex, name in data[dataset]:
- c = Color()
- c.hex = hex
- c.name = name
- objectstore.commit()
- objectstore.clear()
-
- def call(environ, start_response):
- engine.connect(db_uri[dataset], echo=True, echo_pool=True)
-
- # NOTE: must clear objectstore on each request, or you'll see
- # objects from another thread here
- objectstore.clear()
- objectstore.begin()
-
- c = Color.select()
-
- start_response('200 OK', [('content-type','text/html')])
- yield '<html><head><title>Test dataset %s</title></head>' % dataset
- yield '<body>'
- yield '<p>uri: %s</p>' % db_uri[dataset]
- yield '<p>engine: <xmp>%s</xmp></p>' % engine.engine
- yield '<p>Colors!</p>'
- for color in c:
- yield '<div style="background: #%s">%s</div>' % (color.hex,
- color.name)
- yield '</body></html>'
- return call
-
-def cleanup():
- for uri in db_uri.values():
- print "Cleaning db %s" % uri
- engine.connect(uri)
- colors.drop()
-
-def run_server(apps, host='localhost', port=8080):
- print "Serving test app at http://%s:%s/" % (host, port)
- print "Visit http://%(host)s:%(port)s/a and " \
- "http://%(host)s:%(port)s/b to test apps" % {'host': host,
- 'port': port}
-
- server = wsgiServer.WSGIServer((host, port), apps, serveFiles=False)
- try:
- server.serve_forever()
- except:
- cleanup()
- raise
-
-if __name__ == '__main__':
- run_server({'/a':app('a'), '/b':app('b')})
-
-
diff --git a/test/orm/alltests.py b/test/orm/alltests.py
index 35650c136..4f8f4b6b7 100644
--- a/test/orm/alltests.py
+++ b/test/orm/alltests.py
@@ -1,18 +1,20 @@
import testbase
import unittest
+import inheritance.alltests as inheritance
+import sharding.alltests as sharding
+
def suite():
modules_to_test = (
- 'orm.attributes',
- 'orm.mapper',
+ 'orm.attributes',
'orm.query',
'orm.lazy_relations',
'orm.eager_relations',
+ 'orm.mapper',
+ 'orm.collection',
'orm.generative',
'orm.lazytest1',
- 'orm.eagertest1',
- 'orm.eagertest2',
- 'orm.eagertest3',
+ 'orm.assorted_eager',
'orm.sessioncontext',
'orm.unitofwork',
@@ -24,20 +26,11 @@ def suite():
'orm.memusage',
'orm.cycles',
- 'orm.poly_linked_list',
'orm.entity',
'orm.compile',
'orm.manytomany',
'orm.onetoone',
- 'orm.inheritance',
- 'orm.inheritance2',
- 'orm.inheritance3',
- 'orm.inheritance4',
- 'orm.inheritance5',
- 'orm.abc_inheritance',
- 'orm.single',
- 'orm.polymorph'
)
alltests = unittest.TestSuite()
for name in modules_to_test:
@@ -45,6 +38,8 @@ def suite():
for token in name.split('.')[1:]:
mod = getattr(mod, token)
alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
+ alltests.addTest(inheritance.suite())
+ alltests.addTest(sharding.suite())
return alltests
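
The suite now mixes two composition styles: unittest.findTestCases() for the flat modules named by dotted path, and nested suite() calls for the inheritance and sharding packages. A condensed sketch of the pattern (module names hypothetical):

    import unittest

    alltests = unittest.TestSuite()
    alltests.addTest(unittest.findTestCases(some_module))   # one module's tests
    alltests.addTest(subpackage.suite())                    # an aggregated sub-suite
    unittest.TextTestRunner().run(alltests)
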
diff --git a/test/orm/association.py b/test/orm/association.py
index 416cfabbb..a2b899418 100644
--- a/test/orm/association.py
+++ b/test/orm/association.py
@@ -1,9 +1,10 @@
import testbase
from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
-
-class AssociationTest(testbase.PersistTest):
+class AssociationTest(PersistTest):
def setUpAll(self):
global items, item_keywords, keywords, metadata, Item, Keyword, KeywordAssociation
metadata = MetaData(testbase.db)
@@ -138,7 +139,7 @@ class AssociationTest(testbase.PersistTest):
sess.flush()
self.assert_(item_keywords.count().scalar() == 0)
-class AssociationTest2(testbase.PersistTest):
+class AssociationTest2(PersistTest):
def setUpAll(self):
global table_originals, table_people, table_isauthor, metadata, Originals, People, IsAuthor
metadata = MetaData(testbase.db)
diff --git a/test/orm/eagertest3.py b/test/orm/assorted_eager.py
index 8e7735812..652186b8e 100644
--- a/test/orm/eagertest3.py
+++ b/test/orm/assorted_eager.py
@@ -1,8 +1,11 @@
-from testbase import PersistTest, AssertMixin
+"""eager loading unittests derived from mailing list-reported problems and trac tickets."""
+
import testbase
+import random, datetime
from sqlalchemy import *
-from sqlalchemy.ext.selectresults import SelectResults
-import random
+from sqlalchemy.orm import *
+from sqlalchemy.ext.sessioncontext import SessionContext
+from testlib import *
class EagerTest(AssertMixin):
def setUpAll(self):
@@ -119,12 +122,12 @@ class EagerTest(AssertMixin):
assert result == [u'1 Some Category', u'3 Some Category']
def test_dslish(self):
- """test the same as witheagerload except building the query via SelectResults"""
+ """test the same as witheagerload except using generative"""
s = create_session()
- q=SelectResults(s.query(Test).options(eagerload('category')))
- l=q.select (
+ q=s.query(Test).options(eagerload('category'))
+ l=q.filter (
and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False))
- ).outerjoin_to('owner_option')
+ ).outerjoin('owner_option')
result = ["%d %s" % ( t.id,t.category.name ) for t in l]
print result
@@ -170,6 +173,7 @@ class EagerTest2(AssertMixin):
def tearDown(self):
for t in metadata.table_iterator(reverse=True):
t.delete().execute()
+
def testeagerterminate(self):
"""test that eager query generation does not include the same mapper's table twice.
@@ -189,7 +193,7 @@ class EagerTest2(AssertMixin):
'right': relation(Right, lazy=False, backref=backref('middle', lazy=False)),
}
)
- session = create_session(bind_to=testbase.db)
+ session = create_session(bind=testbase.db)
p = Middle('test1')
p.left.append(Left('tag1'))
p.right.append(Right('tag2'))
@@ -199,7 +203,7 @@ class EagerTest2(AssertMixin):
obj = session.query(Left).get_by(tag='tag1')
print obj.middle.right[0]
-class EagerTest3(testbase.ORMTest):
+class EagerTest3(ORMTest):
"""test eager loading combined with nested SELECT statements, functions, and aggregates"""
def define_tables(self, metadata):
global datas, foo, stats
@@ -267,7 +271,7 @@ class EagerTest3(testbase.ORMTest):
# algorithms and there are repeated 'somedata' values in the list)
assert verify_result == arb_result
-class EagerTest4(testbase.ORMTest):
+class EagerTest4(ORMTest):
def define_tables(self, metadata):
global departments, employees
departments = Table('departments', metadata,
@@ -315,17 +319,11 @@ class EagerTest4(testbase.ORMTest):
sess.flush()
q = sess.query(Department)
- filters = [q.join_to('employees'),
- Employee.c.name.startswith('J')]
+ q = q.join('employees').filter(Employee.c.name.startswith('J')).distinct().order_by([desc(Department.c.name)])
+ assert q.count() == 2
+ assert q[0] is d2
- d = SelectResults(q)
- d = d.join_to('employees').filter(Employee.c.name.startswith('J'))
- d = d.distinct()
- d = d.order_by([desc(Department.c.name)])
- assert d.count() == 2
- assert d[0] is d2
-
-class EagerTest5(testbase.ORMTest):
+class EagerTest5(ORMTest):
"""test the construction of AliasedClauses for the same eager load property but different
parent mappers, due to inheritance"""
def define_tables(self, metadata):
@@ -416,10 +414,270 @@ class EagerTest5(testbase.ORMTest):
# eager load had to succeed
assert len([c for c in d2.comments]) == 1
-class EagerTest6(testbase.ORMTest):
+class EagerTest6(ORMTest):
+ def define_tables(self, metadata):
+ global designType, design, part, inheritedPart
+ designType = Table('design_types', metadata,
+ Column('design_type_id', Integer, primary_key=True),
+ )
+
+ design =Table('design', metadata,
+ Column('design_id', Integer, primary_key=True),
+ Column('design_type_id', Integer, ForeignKey('design_types.design_type_id')))
+
+ part = Table('parts', metadata,
+ Column('part_id', Integer, primary_key=True),
+ Column('design_id', Integer, ForeignKey('design.design_id')),
+ Column('design_type_id', Integer, ForeignKey('design_types.design_type_id')))
+
+ inheritedPart = Table('inherited_part', metadata,
+ Column('ip_id', Integer, primary_key=True),
+ Column('part_id', Integer, ForeignKey('parts.part_id')),
+ Column('design_id', Integer, ForeignKey('design.design_id')),
+ )
+
+ def testone(self):
+ class Part(object):pass
+ class Design(object):pass
+ class DesignType(object):pass
+ class InheritedPart(object):pass
+
+ mapper(Part, part)
+
+ mapper(InheritedPart, inheritedPart, properties=dict(
+ part=relation(Part, lazy=False)
+ ))
+
+ mapper(Design, design, properties=dict(
+ parts=relation(Part, private=True, backref="design"),
+ inheritedParts=relation(InheritedPart, private=True, backref="design"),
+ ))
+
+ mapper(DesignType, designType, properties=dict(
+ # designs=relation(Design, private=True, backref="type"),
+ ))
+
+ class_mapper(Design).add_property("type", relation(DesignType, lazy=False, backref="designs"))
+ class_mapper(Part).add_property("design", relation(Design, lazy=False, backref="parts"))
+ #Part.mapper.add_property("designType", relation(DesignType))
+
+ d = Design()
+ sess = create_session()
+ sess.save(d)
+ sess.flush()
+ sess.clear()
+ x = sess.query(Design).get(1)
+ x.inheritedParts
+
+class EagerTest7(ORMTest):
+ def define_tables(self, metadata):
+ global companies_table, addresses_table, invoice_table, phones_table, items_table, ctx
+ global Company, Address, Phone, Item,Invoice
+
+ ctx = SessionContext(create_session)
+
+ companies_table = Table('companies', metadata,
+ Column('company_id', Integer, Sequence('company_id_seq', optional=True), primary_key = True),
+ Column('company_name', String(40)),
+
+ )
+
+ addresses_table = Table('addresses', metadata,
+ Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True),
+ Column('company_id', Integer, ForeignKey("companies.company_id")),
+ Column('address', String(40)),
+ )
+
+ phones_table = Table('phone_numbers', metadata,
+ Column('phone_id', Integer, Sequence('phone_id_seq', optional=True), primary_key = True),
+ Column('address_id', Integer, ForeignKey('addresses.address_id')),
+ Column('type', String(20)),
+ Column('number', String(10)),
+ )
+
+ invoice_table = Table('invoices', metadata,
+ Column('invoice_id', Integer, Sequence('invoice_id_seq', optional=True), primary_key = True),
+ Column('company_id', Integer, ForeignKey("companies.company_id")),
+ Column('date', DateTime),
+ )
+
+ items_table = Table('items', metadata,
+ Column('item_id', Integer, Sequence('item_id_seq', optional=True), primary_key = True),
+ Column('invoice_id', Integer, ForeignKey('invoices.invoice_id')),
+ Column('code', String(20)),
+ Column('qty', Integer),
+ )
+
+ class Company(object):
+ def __init__(self):
+ self.company_id = None
+ def __repr__(self):
+ return "Company:" + repr(getattr(self, 'company_id', None)) + " " + repr(getattr(self, 'company_name', None)) + " " + str([repr(addr) for addr in self.addresses])
+
+ class Address(object):
+ def __repr__(self):
+ return "Address: " + repr(getattr(self, 'address_id', None)) + " " + repr(getattr(self, 'company_id', None)) + " " + repr(self.address) + str([repr(ph) for ph in getattr(self, 'phones', [])])
+
+ class Phone(object):
+ def __repr__(self):
+ return "Phone: " + repr(getattr(self, 'phone_id', None)) + " " + repr(getattr(self, 'address_id', None)) + " " + repr(self.type) + " " + repr(self.number)
+
+ class Invoice(object):
+ def __init__(self):
+ self.invoice_id = None
+ def __repr__(self):
+ return "Invoice:" + repr(getattr(self, 'invoice_id', None)) + " " + repr(getattr(self, 'date', None)) + " " + repr(self.company) + " " + str([repr(item) for item in self.items])
+
+ class Item(object):
+ def __repr__(self):
+ return "Item: " + repr(getattr(self, 'item_id', None)) + " " + repr(getattr(self, 'invoice_id', None)) + " " + repr(self.code) + " " + repr(self.qty)
+
+ def testone(self):
+ """tests eager load of a many-to-one attached to a one-to-many. this testcase illustrated
+ the bug, which is that when the single Company is loaded, no further processing of the rows
+ occurred in order to load the Company's second Address object."""
+
+ mapper(Address, addresses_table, properties={
+ }, extension=ctx.mapper_extension)
+ mapper(Company, companies_table, properties={
+ 'addresses' : relation(Address, lazy=False),
+ }, extension=ctx.mapper_extension)
+ mapper(Invoice, invoice_table, properties={
+ 'company': relation(Company, lazy=False, )
+ }, extension=ctx.mapper_extension)
+
+ c1 = Company()
+ c1.company_name = 'company 1'
+ a1 = Address()
+ a1.address = 'a1 address'
+ c1.addresses.append(a1)
+ a2 = Address()
+ a2.address = 'a2 address'
+ c1.addresses.append(a2)
+ i1 = Invoice()
+ i1.date = datetime.datetime.now()
+ i1.company = c1
+
+ ctx.current.flush()
+
+ company_id = c1.company_id
+ invoice_id = i1.invoice_id
+
+ ctx.current.clear()
+
+ c = ctx.current.query(Company).get(company_id)
+
+ ctx.current.clear()
+
+ i = ctx.current.query(Invoice).get(invoice_id)
+
+ print repr(c)
+ print repr(i.company)
+ self.assert_(repr(c) == repr(i.company))
+
+ def testtwo(self):
+ """this is the original testcase that includes various complicating factors"""
+
+ mapper(Phone, phones_table, extension=ctx.mapper_extension)
+
+ mapper(Address, addresses_table, properties={
+ 'phones': relation(Phone, lazy=False, backref='address')
+ }, extension=ctx.mapper_extension)
+
+ mapper(Company, companies_table, properties={
+ 'addresses' : relation(Address, lazy=False, backref='company'),
+ }, extension=ctx.mapper_extension)
+
+ mapper(Item, items_table, extension=ctx.mapper_extension)
+
+ mapper(Invoice, invoice_table, properties={
+ 'items': relation(Item, lazy=False, backref='invoice'),
+ 'company': relation(Company, lazy=False, backref='invoices')
+ }, extension=ctx.mapper_extension)
+
+ ctx.current.clear()
+ c1 = Company()
+ c1.company_name = 'company 1'
+
+ a1 = Address()
+ a1.address = 'a1 address'
+
+ p1 = Phone()
+ p1.type = 'home'
+ p1.number = '1111'
+
+ a1.phones.append(p1)
+
+ p2 = Phone()
+ p2.type = 'work'
+ p2.number = '22222'
+ a1.phones.append(p2)
+
+ c1.addresses.append(a1)
+
+ a2 = Address()
+ a2.address = 'a2 address'
+
+ p3 = Phone()
+ p3.type = 'home'
+ p3.number = '3333'
+ a2.phones.append(p3)
+
+ p4 = Phone()
+ p4.type = 'work'
+ p4.number = '44444'
+ a2.phones.append(p4)
+
+ c1.addresses.append(a2)
+
+ ctx.current.flush()
+
+ company_id = c1.company_id
+
+ ctx.current.clear()
+
+ a = ctx.current.query(Company).get(company_id)
+ print repr(a)
+
+ # set up an invoice
+ i1 = Invoice()
+ i1.date = datetime.datetime.now()
+ i1.company = c1
+
+ item1 = Item()
+ item1.code = 'aaaa'
+ item1.qty = 1
+ item1.invoice = i1
+
+ item2 = Item()
+ item2.code = 'bbbb'
+ item2.qty = 2
+ item2.invoice = i1
+
+ item3 = Item()
+ item3.code = 'cccc'
+ item3.qty = 3
+ item3.invoice = i1
+
+ ctx.current.flush()
+
+ invoice_id = i1.invoice_id
+
+ ctx.current.clear()
+
+ c = ctx.current.query(Company).get(company_id)
+ print repr(c)
+
+ ctx.current.clear()
+
+ i = ctx.current.query(Invoice).get(invoice_id)
+
+ assert repr(i.company) == repr(c), repr(i.company) + " does not match " + repr(c)
+
+class EagerTest8(ORMTest):
def define_tables(self, metadata):
global project_t, task_t, task_status_t, task_type_t, message_t, message_type_t
-
+
project_t = Table('prj', metadata,
Column('id', Integer, primary_key=True),
Column('created', DateTime , ),
@@ -460,12 +718,12 @@ class EagerTest6(testbase.ORMTest):
testbase.db.execute("INSERT INTO task_status (id) values(1);")
testbase.db.execute("INSERT INTO task_type(id) values(1);")
testbase.db.execute("INSERT INTO task (title, task_type_id, status_id, prj_id) values('task 1',1,1,1);")
-
+
def test_nested_joins(self):
# this is testing some subtle column resolution stuff,
# concerning corresponding_column() being extremely accurate
# as well as how mapper sets up its column properties
-
+
class Task(object):pass
class Task_Type(object):pass
class Message(object):pass
@@ -510,6 +768,7 @@ class EagerTest6(testbase.ORMTest):
for t in session.query(cls.mapper).limit(10).offset(0).list():
print t.id, t.title, t.props_cnt
-
+
+
if __name__ == "__main__":
testbase.main()
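
EagerTest7's testone pins down the failure its docstring describes: a many-to-one eager load (Invoice to Company) chained to a one-to-many eager load (Company to addresses) must keep consuming result rows until every Address is assembled, even though only one Invoice row was requested. The shape of the check, condensed from the test above:

    # both hops are eager; fetching one Invoice must still populate the
    # full addresses collection on its Company
    i = ctx.current.query(Invoice).get(invoice_id)
    c = ctx.current.query(Company).get(company_id)
    assert repr(i.company) == repr(c)   # the second Address shows up on both
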
diff --git a/test/orm/attributes.py b/test/orm/attributes.py
index 7e0a22aff..9b5f738bf 100644
--- a/test/orm/attributes.py
+++ b/test/orm/attributes.py
@@ -1,10 +1,9 @@
-from testbase import PersistTest
-import sqlalchemy.util as util
+import testbase
+import pickle
import sqlalchemy.orm.attributes as attributes
+from sqlalchemy.orm.collections import collection
from sqlalchemy import exceptions
-import unittest, sys, os
-import pickle
-import testbase
+from testlib import *
class MyTest(object):pass
class MyTest2(object):pass
@@ -50,14 +49,53 @@ class AttributesTest(PersistTest):
        # shouldn't be pickling callables at the class level
def somecallable(*args):
return None
- manager.register_attribute(MyTest, 'mt2', uselist = True, trackparent=True, callable_=somecallable)
- x = MyTest()
- x.mt2.append(MyTest2())
-
- x.user_id=7
- s = pickle.dumps(x)
- x2 = pickle.loads(s)
- assert s == pickle.dumps(x2)
+ attr_name = 'mt2'
+ manager.register_attribute(MyTest, attr_name, uselist = True, trackparent=True, callable_=somecallable)
+
+ o = MyTest()
+ o.mt2.append(MyTest2())
+ o.user_id=7
+ o.mt2[0].a = 'abcde'
+ pk_o = pickle.dumps(o)
+
+ o2 = pickle.loads(pk_o)
+
+ # so... pickle is creating a new 'mt2' string after a roundtrip here,
+ # so we'll brute-force set it to be id-equal to the original string
+ o_mt2_str = [ k for k in o.__dict__ if k == 'mt2'][0]
+ o2_mt2_str = [ k for k in o2.__dict__ if k == 'mt2'][0]
+ self.assert_(o_mt2_str == o2_mt2_str)
+ self.assert_(o_mt2_str is not o2_mt2_str)
+ # change the id of o2.__dict__['mt2']
+ former = o2.__dict__['mt2']
+ del o2.__dict__['mt2']
+ o2.__dict__[o_mt2_str] = former
+
+ pk_o2 = pickle.dumps(o2)
+
+ self.assert_(pk_o == pk_o2)
+
+        # the above is kind of disturbing, so let's do it again a little
+ # differently. the string-id in serialization thing is just an
+ # artifact of pickling that comes up in the first round-trip.
+ # a -> b differs in pickle memoization of 'mt2', but b -> c will
+ # serialize identically.
+
+ o3 = pickle.loads(pk_o2)
+ pk_o3 = pickle.dumps(o3)
+ o4 = pickle.loads(pk_o3)
+ pk_o4 = pickle.dumps(o4)
+
+ self.assert_(pk_o3 == pk_o4)
+
+ # and lastly make sure we still have our data after all that.
+        # identical serialization is great, *if* it's complete :)
+ self.assert_(o4.user_id == 7)
+ self.assert_(o4.user_name is None)
+ self.assert_(o4.email_address is None)
+ self.assert_(len(o4.mt2) == 1)
+ self.assert_(o4.mt2[0].a == 'abcde')
+ self.assert_(o4.mt2[0].b is None)
def testlist(self):
class User(object):pass
@@ -110,13 +148,12 @@ class AttributesTest(PersistTest):
s = Student()
c = Course()
s.courses.append(c)
- print c.students
- print [s]
self.assert_(c.students == [s])
s.courses.remove(c)
self.assert_(c.students == [])
(s1, s2, s3) = (Student(), Student(), Student())
+
c.students = [s1, s2, s3]
self.assert_(s2.courses == [c])
self.assert_(s1.courses == [c])
@@ -126,9 +163,7 @@ class AttributesTest(PersistTest):
print c
print c.students
s1.courses.remove(c)
- self.assert_(c.students == [s2,s3])
-
-
+ self.assert_(c.students == [s2,s3])
class Post(object):pass
class Blog(object):pass
@@ -334,44 +369,47 @@ class AttributesTest(PersistTest):
manager = attributes.AttributeManager()
class Foo(object):pass
manager.register_attribute(Foo, "collection", uselist=True, typecallable=set)
- assert isinstance(Foo().collection.data, set)
+ assert isinstance(Foo().collection, set)
- manager.register_attribute(Foo, "collection", uselist=True, typecallable=dict)
try:
- Foo().collection
+ manager.register_attribute(Foo, "collection", uselist=True, typecallable=dict)
assert False
except exceptions.ArgumentError, e:
- assert str(e) == "Dictionary collection class 'dict' must implement an append() method"
-
+ assert str(e) == "Type InstrumentedDict must elect an appender method to be a collection class"
+
class MyDict(dict):
+ @collection.appender
def append(self, item):
self[item.foo] = item
+ @collection.remover
+ def remove(self, item):
+ del self[item.foo]
manager.register_attribute(Foo, "collection", uselist=True, typecallable=MyDict)
- assert isinstance(Foo().collection.data, MyDict)
+ assert isinstance(Foo().collection, MyDict)
class MyColl(object):pass
- manager.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl)
try:
- Foo().collection
+ manager.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl)
assert False
except exceptions.ArgumentError, e:
- assert str(e) == "Collection class 'MyColl' is not of type 'list', 'set', or 'dict' and has no append() or add() method"
+ assert str(e) == "Type MyColl must elect an appender method to be a collection class"
class MyColl(object):
+ @collection.iterator
def __iter__(self):
return iter([])
+ @collection.appender
def append(self, item):
pass
+ @collection.remover
+ def remove(self, item):
+ pass
manager.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl)
try:
Foo().collection
- assert False
+ assert True
except exceptions.ArgumentError, e:
- assert str(e) == "Collection class 'MyColl' is not of type 'list', 'set', or 'dict' and has no clear() method"
-
- def foo(self):pass
- MyColl.clear = foo
- assert isinstance(Foo().collection.data, MyColl)
+ assert False
if __name__ == "__main__":
testbase.main()
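
The rewritten assertions encode the 0.4 collection contract: a class becomes usable as a collection only once it elects an appender (plus a remover and iterator for full support) via the collection decorators, and un-decorated classes or bare dicts now fail at register_attribute() time rather than on first access. A minimal conforming class in the style of the test (TagBag is a hypothetical name; manager and Foo as in the test above):

    from sqlalchemy.orm.collections import collection

    class TagBag(object):
        def __init__(self):
            self.data = []
        @collection.appender
        def append(self, item):
            self.data.append(item)
        @collection.remover
        def remove(self, item):
            self.data.remove(item)
        @collection.iterator
        def __iter__(self):
            return iter(self.data)

    manager.register_attribute(Foo, 'tags', uselist=True, typecallable=TagBag)
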
diff --git a/test/orm/cascade.py b/test/orm/cascade.py
index 16c4db40f..b832c427e 100644
--- a/test/orm/cascade.py
+++ b/test/orm/cascade.py
@@ -1,10 +1,12 @@
-import testbase, tables
-import unittest, sys, datetime
+import testbase
-from sqlalchemy.ext.sessioncontext import SessionContext
from sqlalchemy import *
+from sqlalchemy.orm import *
+from sqlalchemy.ext.sessioncontext import SessionContext
+from testlib import *
+import testlib.tables as tables
-class O2MCascadeTest(testbase.AssertMixin):
+class O2MCascadeTest(AssertMixin):
def tearDown(self):
tables.delete()
@@ -112,7 +114,7 @@ class O2MCascadeTest(testbase.AssertMixin):
sess = create_session()
l = sess.query(tables.User).select()
for u in l:
- self.echo( repr(u.orders))
+ print repr(u.orders)
self.assert_result(l, data[0], *data[1:])
ids = (l[0].user_id, l[2].user_id)
@@ -172,7 +174,7 @@ class O2MCascadeTest(testbase.AssertMixin):
self.assert_(tables.orderitems.count(tables.orders.c.user_id.in_(*ids) &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 0)
-class M2OCascadeTest(testbase.AssertMixin):
+class M2OCascadeTest(AssertMixin):
def tearDown(self):
ctx.current.clear()
for t in metadata.table_iterator(reverse=True):
@@ -260,7 +262,7 @@ class M2OCascadeTest(testbase.AssertMixin):
-class M2MCascadeTest(testbase.AssertMixin):
+class M2MCascadeTest(AssertMixin):
def setUpAll(self):
global metadata, a, b, atob
metadata = MetaData(testbase.db)
@@ -335,7 +337,7 @@ class M2MCascadeTest(testbase.AssertMixin):
assert b.count().scalar() == 0
assert a.count().scalar() == 0
-class UnsavedOrphansTest(testbase.ORMTest):
+class UnsavedOrphansTest(ORMTest):
"""tests regarding pending entities that are orphans"""
def define_tables(self, metadata):
@@ -395,7 +397,7 @@ class UnsavedOrphansTest(testbase.ORMTest):
assert a.address_id is None, "Error: address should not be persistent"
-class UnsavedOrphansTest2(testbase.ORMTest):
+class UnsavedOrphansTest2(ORMTest):
"""same test as UnsavedOrphans only three levels deep"""
def define_tables(self, meta):
@@ -455,7 +457,7 @@ class UnsavedOrphansTest2(testbase.ORMTest):
assert item.id is None
assert attr.id is None
-class DoubleParentOrphanTest(testbase.AssertMixin):
+class DoubleParentOrphanTest(AssertMixin):
"""test orphan detection for an entity with two parent relations"""
def setUpAll(self):
@@ -521,7 +523,7 @@ class DoubleParentOrphanTest(testbase.AssertMixin):
assert False
except exceptions.FlushError, e:
assert True
-
-
+
+
if __name__ == "__main__":
testbase.main()
diff --git a/test/orm/collection.py b/test/orm/collection.py
new file mode 100644
index 000000000..1f4f64928
--- /dev/null
+++ b/test/orm/collection.py
@@ -0,0 +1,1140 @@
+import testbase
+from sqlalchemy import *
+import sqlalchemy.exceptions as exceptions
+from sqlalchemy.orm import create_session, mapper, relation, \
+ interfaces, attributes
+import sqlalchemy.orm.collections as collections
+from sqlalchemy.orm.collections import collection
+from sqlalchemy import util
+from operator import and_
+from testlib import *
+
+class Canary(interfaces.AttributeExtension):
+ def __init__(self):
+ self.data = set()
+ self.added = set()
+ self.removed = set()
+ def append(self, obj, value, initiator):
+ assert value not in self.added
+ self.data.add(value)
+ self.added.add(value)
+ def remove(self, obj, value, initiator):
+ assert value not in self.removed
+ self.data.remove(value)
+ self.removed.add(value)
+ def set(self, obj, value, oldvalue, initiator):
+ if oldvalue is not None:
+ self.remove(obj, oldvalue, None)
+ self.append(obj, value, None)
+
+class Entity(object):
+ def __init__(self, a=None, b=None, c=None):
+ self.a = a
+ self.b = b
+ self.c = c
+ def __repr__(self):
+ return str((id(self), self.a, self.b, self.c))
+
+manager = attributes.AttributeManager()
+
+_id = 1
+def entity_maker():
+ global _id
+ _id += 1
+ return Entity(_id)
+def dictable_entity(a=None, b=None, c=None):
+ global _id
+ _id += 1
+ return Entity(a or str(_id), b or 'value %s' % _id, c)
+
+
+class CollectionsTest(PersistTest):
+ def _test_adapter(self, typecallable, creator=entity_maker,
+ to_set=None):
+ class Foo(object):
+ pass
+
+ canary = Canary()
+ manager.register_attribute(Foo, 'attr', True, extension=canary,
+ typecallable=typecallable)
+
+ obj = Foo()
+ adapter = collections.collection_adapter(obj.attr)
+ direct = obj.attr
+ if to_set is None:
+ to_set = lambda col: set(col)
+
+ def assert_eq():
+ self.assert_(to_set(direct) == canary.data)
+ self.assert_(set(adapter) == canary.data)
+ assert_ne = lambda: self.assert_(to_set(direct) != canary.data)
+
+ e1, e2 = creator(), creator()
+
+ adapter.append_with_event(e1)
+ assert_eq()
+
+ adapter.append_without_event(e2)
+ assert_ne()
+ canary.data.add(e2)
+ assert_eq()
+
+ adapter.remove_without_event(e2)
+ assert_ne()
+ canary.data.remove(e2)
+ assert_eq()
+
+ adapter.remove_with_event(e1)
+ assert_eq()
+
+ def _test_list(self, typecallable, creator=entity_maker):
+ class Foo(object):
+ pass
+
+ canary = Canary()
+ manager.register_attribute(Foo, 'attr', True, extension=canary,
+ typecallable=typecallable)
+
+ obj = Foo()
+ adapter = collections.collection_adapter(obj.attr)
+ direct = obj.attr
+ control = list()
+
+ def assert_eq():
+ self.assert_(set(direct) == canary.data)
+ self.assert_(set(adapter) == canary.data)
+ self.assert_(direct == control)
+
+ # assume append() is available for list tests
+ e = creator()
+ direct.append(e)
+ control.append(e)
+ assert_eq()
+
+ if hasattr(direct, 'pop'):
+ direct.pop()
+ control.pop()
+ assert_eq()
+
+ if hasattr(direct, '__setitem__'):
+ e = creator()
+ direct.append(e)
+ control.append(e)
+
+ e = creator()
+ direct[0] = e
+ control[0] = e
+ assert_eq()
+
+ if reduce(and_, [hasattr(direct, a) for a in
+                         ('__delitem__', 'insert', '__len__')], True):
+ values = [creator(), creator(), creator(), creator()]
+ direct[slice(0,1)] = values
+ control[slice(0,1)] = values
+ assert_eq()
+
+ values = [creator(), creator()]
+ direct[slice(0,-1,2)] = values
+ control[slice(0,-1,2)] = values
+ assert_eq()
+
+ values = [creator()]
+ direct[slice(0,-1)] = values
+ control[slice(0,-1)] = values
+ assert_eq()
+
+ if hasattr(direct, '__delitem__'):
+ e = creator()
+ direct.append(e)
+ control.append(e)
+ del direct[-1]
+ del control[-1]
+ assert_eq()
+
+ if hasattr(direct, '__getslice__'):
+ for e in [creator(), creator(), creator(), creator()]:
+ direct.append(e)
+ control.append(e)
+
+ del direct[:-3]
+ del control[:-3]
+ assert_eq()
+
+ del direct[0:1]
+ del control[0:1]
+ assert_eq()
+
+ del direct[::2]
+ del control[::2]
+ assert_eq()
+
+ if hasattr(direct, 'remove'):
+ e = creator()
+ direct.append(e)
+ control.append(e)
+
+ direct.remove(e)
+ control.remove(e)
+ assert_eq()
+
+ if hasattr(direct, '__setslice__'):
+ values = [creator(), creator()]
+ direct[0:1] = values
+ control[0:1] = values
+ assert_eq()
+
+ values = [creator()]
+ direct[0:] = values
+ control[0:] = values
+ assert_eq()
+
+ if hasattr(direct, '__delslice__'):
+ for i in range(1, 4):
+ e = creator()
+ direct.append(e)
+ control.append(e)
+
+ del direct[-1:]
+ del control[-1:]
+ assert_eq()
+
+ del direct[1:2]
+ del control[1:2]
+ assert_eq()
+
+ del direct[:]
+ del control[:]
+ assert_eq()
+
+ if hasattr(direct, 'extend'):
+ values = [creator(), creator(), creator()]
+
+ direct.extend(values)
+ control.extend(values)
+ assert_eq()
+
+ def _test_list_bulk(self, typecallable, creator=entity_maker):
+ class Foo(object):
+ pass
+
+ canary = Canary()
+ manager.register_attribute(Foo, 'attr', True, extension=canary,
+ typecallable=typecallable)
+
+ obj = Foo()
+ direct = obj.attr
+
+ e1 = creator()
+ obj.attr.append(e1)
+
+ like_me = typecallable()
+ e2 = creator()
+ like_me.append(e2)
+
+ self.assert_(obj.attr is direct)
+ obj.attr = like_me
+ self.assert_(obj.attr is not direct)
+ self.assert_(obj.attr is not like_me)
+ self.assert_(set(obj.attr) == set([e2]))
+ self.assert_(e1 in canary.removed)
+ self.assert_(e2 in canary.added)
+
+ e3 = creator()
+ real_list = [e3]
+ obj.attr = real_list
+ self.assert_(obj.attr is not real_list)
+ self.assert_(set(obj.attr) == set([e3]))
+ self.assert_(e2 in canary.removed)
+ self.assert_(e3 in canary.added)
+
+ e4 = creator()
+ try:
+ obj.attr = set([e4])
+ self.assert_(False)
+ except exceptions.ArgumentError:
+ self.assert_(e4 not in canary.data)
+ self.assert_(e3 in canary.data)
+
+ def test_list(self):
+ self._test_adapter(list)
+ self._test_list(list)
+ self._test_list_bulk(list)
+
+ def test_list_subclass(self):
+ class MyList(list):
+ pass
+ self._test_adapter(MyList)
+ self._test_list(MyList)
+ self._test_list_bulk(MyList)
+ self.assert_(getattr(MyList, '_sa_instrumented') == id(MyList))
+
+ def test_list_duck(self):
+ class ListLike(object):
+ def __init__(self):
+ self.data = list()
+ def append(self, item):
+ self.data.append(item)
+ def remove(self, item):
+ self.data.remove(item)
+ def insert(self, index, item):
+ self.data.insert(index, item)
+ def pop(self, index=-1):
+ return self.data.pop(index)
+ def extend(self):
+ assert False
+ def __iter__(self):
+ return iter(self.data)
+ def __eq__(self, other):
+ return self.data == other
+ def __repr__(self):
+ return 'ListLike(%s)' % repr(self.data)
+
+ self._test_adapter(ListLike)
+ self._test_list(ListLike)
+ self._test_list_bulk(ListLike)
+ self.assert_(getattr(ListLike, '_sa_instrumented') == id(ListLike))
+
+ def test_list_emulates(self):
+ class ListIsh(object):
+ __emulates__ = list
+ def __init__(self):
+ self.data = list()
+ def append(self, item):
+ self.data.append(item)
+ def remove(self, item):
+ self.data.remove(item)
+ def insert(self, index, item):
+ self.data.insert(index, item)
+ def pop(self, index=-1):
+ return self.data.pop(index)
+ def extend(self):
+ assert False
+ def __iter__(self):
+ return iter(self.data)
+ def __eq__(self, other):
+ return self.data == other
+ def __repr__(self):
+ return 'ListIsh(%s)' % repr(self.data)
+
+ self._test_adapter(ListIsh)
+ self._test_list(ListIsh)
+ self._test_list_bulk(ListIsh)
+ self.assert_(getattr(ListIsh, '_sa_instrumented') == id(ListIsh))
+
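
test_list_emulates covers the declarative route into instrumentation: rather than subclassing list or relying on duck-type detection, a class can state __emulates__ = list and have its append/remove/insert methods picked up under the list rules. A trimmed sketch (class name hypothetical):

    class QueueIsh(object):
        __emulates__ = list          # instrument with list semantics
        def __init__(self):
            self.data = []
        def append(self, item):      # recognized as the appender
            self.data.append(item)
        def remove(self, item):      # recognized as the remover
            self.data.remove(item)
        def __iter__(self):
            return iter(self.data)
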
+ def _test_set(self, typecallable, creator=entity_maker):
+ class Foo(object):
+ pass
+
+ canary = Canary()
+ manager.register_attribute(Foo, 'attr', True, extension=canary,
+ typecallable=typecallable)
+
+ obj = Foo()
+ adapter = collections.collection_adapter(obj.attr)
+ direct = obj.attr
+ control = set()
+
+ def assert_eq():
+ self.assert_(set(direct) == canary.data)
+ self.assert_(set(adapter) == canary.data)
+ self.assert_(direct == control)
+
+ def addall(*values):
+ for item in values:
+ direct.add(item)
+ control.add(item)
+ assert_eq()
+ def zap():
+ for item in list(direct):
+ direct.remove(item)
+ control.clear()
+
+ # assume add() is available for list tests
+ addall(creator())
+
+ if hasattr(direct, 'pop'):
+ direct.pop()
+ control.pop()
+ assert_eq()
+
+ if hasattr(direct, 'remove'):
+ e = creator()
+ addall(e)
+
+ direct.remove(e)
+ control.remove(e)
+ assert_eq()
+
+ e = creator()
+ try:
+ direct.remove(e)
+ except KeyError:
+ assert_eq()
+ self.assert_(e not in canary.removed)
+ else:
+ self.assert_(False)
+
+ if hasattr(direct, 'discard'):
+ e = creator()
+ addall(e)
+
+ direct.discard(e)
+ control.discard(e)
+ assert_eq()
+
+ e = creator()
+ direct.discard(e)
+ self.assert_(e not in canary.removed)
+ assert_eq()
+
+ if hasattr(direct, 'update'):
+ e = creator()
+ addall(e)
+
+ values = set([e, creator(), creator()])
+
+ direct.update(values)
+ control.update(values)
+ assert_eq()
+
+ if hasattr(direct, 'clear'):
+ addall(creator(), creator())
+ direct.clear()
+ control.clear()
+ assert_eq()
+
+ if hasattr(direct, 'difference_update'):
+ zap()
+ addall(creator(), creator())
+ values = set([creator()])
+
+ direct.difference_update(values)
+ control.difference_update(values)
+ assert_eq()
+ values.update(set([e, creator()]))
+ direct.difference_update(values)
+ control.difference_update(values)
+ assert_eq()
+
+ if hasattr(direct, 'intersection_update'):
+ zap()
+ e = creator()
+ addall(e, creator(), creator())
+ values = set(control)
+
+ direct.intersection_update(values)
+ control.intersection_update(values)
+ assert_eq()
+
+ values.update(set([e, creator()]))
+ direct.intersection_update(values)
+ control.intersection_update(values)
+ assert_eq()
+
+ if hasattr(direct, 'symmetric_difference_update'):
+ zap()
+ e = creator()
+ addall(e, creator(), creator())
+
+ values = set([e, creator()])
+ direct.symmetric_difference_update(values)
+ control.symmetric_difference_update(values)
+ assert_eq()
+
+ e = creator()
+ addall(e)
+ values = set([e])
+ direct.symmetric_difference_update(values)
+ control.symmetric_difference_update(values)
+ assert_eq()
+
+ values = set()
+ direct.symmetric_difference_update(values)
+ control.symmetric_difference_update(values)
+ assert_eq()
+
+ def _test_set_bulk(self, typecallable, creator=entity_maker):
+ class Foo(object):
+ pass
+
+ canary = Canary()
+ manager.register_attribute(Foo, 'attr', True, extension=canary,
+ typecallable=typecallable)
+
+ obj = Foo()
+ direct = obj.attr
+
+ e1 = creator()
+ obj.attr.add(e1)
+
+ like_me = typecallable()
+ e2 = creator()
+ like_me.add(e2)
+
+ self.assert_(obj.attr is direct)
+ obj.attr = like_me
+ self.assert_(obj.attr is not direct)
+ self.assert_(obj.attr is not like_me)
+ self.assert_(obj.attr == set([e2]))
+ self.assert_(e1 in canary.removed)
+ self.assert_(e2 in canary.added)
+
+ e3 = creator()
+ real_set = set([e3])
+ obj.attr = real_set
+ self.assert_(obj.attr is not real_set)
+ self.assert_(obj.attr == set([e3]))
+ self.assert_(e2 in canary.removed)
+ self.assert_(e3 in canary.added)
+
+ e4 = creator()
+ try:
+ obj.attr = [e4]
+ self.assert_(False)
+ except exceptions.ArgumentError:
+ self.assert_(e4 not in canary.data)
+ self.assert_(e3 in canary.data)
+
+ def test_set(self):
+ self._test_adapter(set)
+ self._test_set(set)
+ self._test_set_bulk(set)
+
+ def test_set_subclass(self):
+ class MySet(set):
+ pass
+ self._test_adapter(MySet)
+ self._test_set(MySet)
+ self._test_set_bulk(MySet)
+ self.assert_(getattr(MySet, '_sa_instrumented') == id(MySet))
+
+ def test_set_duck(self):
+ class SetLike(object):
+ def __init__(self):
+ self.data = set()
+ def add(self, item):
+ self.data.add(item)
+ def remove(self, item):
+ self.data.remove(item)
+ def discard(self, item):
+ self.data.discard(item)
+ def pop(self):
+ return self.data.pop()
+ def update(self, other):
+ self.data.update(other)
+ def __iter__(self):
+ return iter(self.data)
+ def __eq__(self, other):
+ return self.data == other
+
+ self._test_adapter(SetLike)
+ self._test_set(SetLike)
+ self._test_set_bulk(SetLike)
+ self.assert_(getattr(SetLike, '_sa_instrumented') == id(SetLike))
+
+ def test_set_emulates(self):
+ class SetIsh(object):
+ __emulates__ = set
+ def __init__(self):
+ self.data = set()
+ def add(self, item):
+ self.data.add(item)
+ def remove(self, item):
+ self.data.remove(item)
+ def discard(self, item):
+ self.data.discard(item)
+ def pop(self):
+ return self.data.pop()
+ def update(self, other):
+ self.data.update(other)
+ def __iter__(self):
+ return iter(self.data)
+ def __eq__(self, other):
+ return self.data == other
+
+ self._test_adapter(SetIsh)
+ self._test_set(SetIsh)
+ self._test_set_bulk(SetIsh)
+ self.assert_(getattr(SetIsh, '_sa_instrumented') == id(SetIsh))
+
+ def _test_dict(self, typecallable, creator=dictable_entity):
+ class Foo(object):
+ pass
+
+ canary = Canary()
+ manager.register_attribute(Foo, 'attr', True, extension=canary,
+ typecallable=typecallable)
+
+ obj = Foo()
+ adapter = collections.collection_adapter(obj.attr)
+ direct = obj.attr
+ control = dict()
+
+ def assert_eq():
+ self.assert_(set(direct.values()) == canary.data)
+ self.assert_(set(adapter) == canary.data)
+ self.assert_(direct == control)
+
+ def addall(*values):
+ for item in values:
+ direct.set(item)
+ control[item.a] = item
+ assert_eq()
+ def zap():
+ for item in list(adapter):
+ direct.remove(item)
+ control.clear()
+
+ # assume an 'set' method is available for tests
+ addall(creator())
+
+ if hasattr(direct, '__setitem__'):
+ e = creator()
+ direct[e.a] = e
+ control[e.a] = e
+ assert_eq()
+
+ e = creator(e.a, e.b)
+ direct[e.a] = e
+ control[e.a] = e
+ assert_eq()
+
+ if hasattr(direct, '__delitem__'):
+ e = creator()
+ addall(e)
+
+ del direct[e.a]
+ del control[e.a]
+ assert_eq()
+
+ e = creator()
+ try:
+ del direct[e.a]
+ except KeyError:
+ self.assert_(e not in canary.removed)
+
+ if hasattr(direct, 'clear'):
+ addall(creator(), creator(), creator())
+
+ direct.clear()
+ control.clear()
+ assert_eq()
+
+ direct.clear()
+ control.clear()
+ assert_eq()
+
+ if hasattr(direct, 'pop'):
+ e = creator()
+ addall(e)
+
+ direct.pop(e.a)
+ control.pop(e.a)
+ assert_eq()
+
+ e = creator()
+ try:
+ direct.pop(e.a)
+ except KeyError:
+ self.assert_(e not in canary.removed)
+
+ if hasattr(direct, 'popitem'):
+ zap()
+ e = creator()
+ addall(e)
+
+ direct.popitem()
+ control.popitem()
+ assert_eq()
+
+ if hasattr(direct, 'setdefault'):
+ e = creator()
+
+ val_a = direct.setdefault(e.a, e)
+ val_b = control.setdefault(e.a, e)
+ assert_eq()
+ self.assert_(val_a is val_b)
+
+ val_a = direct.setdefault(e.a, e)
+ val_b = control.setdefault(e.a, e)
+ assert_eq()
+ self.assert_(val_a is val_b)
+
+ if hasattr(direct, 'update'):
+ e = creator()
+ d = dict([(ee.a, ee) for ee in [e, creator(), creator()]])
+ addall(e, creator())
+
+ direct.update(d)
+ control.update(d)
+ assert_eq()
+
+ kw = dict([(ee.a, ee) for ee in [e, creator()]])
+ direct.update(**kw)
+ control.update(**kw)
+ assert_eq()
+
+ def _test_dict_bulk(self, typecallable, creator=dictable_entity):
+ class Foo(object):
+ pass
+
+ canary = Canary()
+ manager.register_attribute(Foo, 'attr', True, extension=canary,
+ typecallable=typecallable)
+
+ obj = Foo()
+ direct = obj.attr
+
+ e1 = creator()
+ collections.collection_adapter(direct).append_with_event(e1)
+
+ like_me = typecallable()
+ e2 = creator()
+ like_me.set(e2)
+
+ self.assert_(obj.attr is direct)
+ obj.attr = like_me
+ self.assert_(obj.attr is not direct)
+ self.assert_(obj.attr is not like_me)
+ self.assert_(set(collections.collection_adapter(obj.attr)) == set([e2]))
+ self.assert_(e1 in canary.removed)
+ self.assert_(e2 in canary.added)
+
+ e3 = creator()
+ real_dict = dict(keyignored1=e3)
+ obj.attr = real_dict
+ self.assert_(obj.attr is not real_dict)
+ self.assert_('keyignored1' not in obj.attr)
+ self.assert_(set(collections.collection_adapter(obj.attr)) == set([e3]))
+ self.assert_(e2 in canary.removed)
+ self.assert_(e3 in canary.added)
+
+ e4 = creator()
+ try:
+ obj.attr = [e4]
+ self.assert_(False)
+ except exceptions.ArgumentError:
+ self.assert_(e4 not in canary.data)
+ self.assert_(e3 in canary.data)
+
+ def test_dict(self):
+ try:
+ self._test_adapter(dict, dictable_entity,
+ to_set=lambda c: set(c.values()))
+ self.assert_(False)
+ except exceptions.ArgumentError, e:
+ self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class')
+
+ try:
+ self._test_dict(dict)
+ self.assert_(False)
+ except exceptions.ArgumentError, e:
+ self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class')
+
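# -- illustrative sketch, not part of the patch --
# A bare dict is rejected above because instrumentation cannot pick an
# appender for it: there is no single "add one entity" method on dict.
# The supported route for dict-shaped relations is a keyed collection
# factory, as the DictHelpersTest cases below exercise; Parent, Child and
# parents_table are assumed names here:

from sqlalchemy.orm.collections import mapped_collection

mapper(Parent, parents_table, properties={
    'children': relation(Child,
                         collection_class=mapped_collection(lambda c: c.a))
})
# p.children then behaves like {child.a: child}, and mutations still fire
# the usual attribute events.
# -- end sketch --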
+ def test_dict_subclass(self):
+ class MyDict(dict):
+ @collection.appender
+ @collection.internally_instrumented
+ def set(self, item, _sa_initiator=None):
+ self.__setitem__(item.a, item, _sa_initiator=_sa_initiator)
+ @collection.remover
+ @collection.internally_instrumented
+ def _remove(self, item, _sa_initiator=None):
+ self.__delitem__(item.a, _sa_initiator=_sa_initiator)
+
+ self._test_adapter(MyDict, dictable_entity,
+ to_set=lambda c: set(c.values()))
+ self._test_dict(MyDict)
+ self._test_dict_bulk(MyDict)
+ self.assert_(getattr(MyDict, '_sa_instrumented') == id(MyDict))
+
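# -- note, not part of the patch --
# In the subclass above, @collection.internally_instrumented tells the
# decorator machinery not to wrap set()/_remove() a second time: their
# bodies already call the instrumented __setitem__/__delitem__, and the
# _sa_initiator token is threaded through so each operation fires its
# events exactly once.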
+ def test_dict_subclass2(self):
+ class MyEasyDict(collections.MappedCollection):
+ def __init__(self):
+ super(MyEasyDict, self).__init__(lambda e: e.a)
+
+ self._test_adapter(MyEasyDict, dictable_entity,
+ to_set=lambda c: set(c.values()))
+ self._test_dict(MyEasyDict)
+ self._test_dict_bulk(MyEasyDict)
+ self.assert_(getattr(MyEasyDict, '_sa_instrumented') == id(MyEasyDict))
+
+ def test_dict_subclass3(self):
+ class MyOrdered(util.OrderedDict, collections.MappedCollection):
+ def __init__(self):
+ collections.MappedCollection.__init__(self, lambda e: e.a)
+ util.OrderedDict.__init__(self)
+
+ self._test_adapter(MyOrdered, dictable_entity,
+ to_set=lambda c: set(c.values()))
+ self._test_dict(MyOrdered)
+ self._test_dict_bulk(MyOrdered)
+ self.assert_(getattr(MyOrdered, '_sa_instrumented') == id(MyOrdered))
+
+ def test_dict_duck(self):
+ class DictLike(object):
+ def __init__(self):
+ self.data = dict()
+
+ @collection.appender
+ @collection.replaces(1)
+ def set(self, item):
+ current = self.data.get(item.a, None)
+ self.data[item.a] = item
+ return current
+ @collection.remover
+ def _remove(self, item):
+ del self.data[item.a]
+ def __setitem__(self, key, value):
+ self.data[key] = value
+ def __getitem__(self, key):
+ return self.data[key]
+ def __delitem__(self, key):
+ del self.data[key]
+ def values(self):
+ return self.data.values()
+ def __contains__(self, key):
+ return key in self.data
+ @collection.iterator
+ def itervalues(self):
+ return self.data.itervalues()
+ def __eq__(self, other):
+ return self.data == other
+ def __repr__(self):
+ return 'DictLike(%s)' % repr(self.data)
+
+ self._test_adapter(DictLike, dictable_entity,
+ to_set=lambda c: set(c.itervalues()))
+ self._test_dict(DictLike)
+ self._test_dict_bulk(DictLike)
+ self.assert_(getattr(DictLike, '_sa_instrumented') == id(DictLike))
+
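# -- note, not part of the patch --
# Two decorators above carry the event bookkeeping for the duck type:
# @collection.replaces(1) marks set() as displacing whatever it returns,
# so the replaced entity receives a remove event alongside the add, and
# @collection.iterator designates itervalues() as the canonical way to
# enumerate members, since the class defines no __iter__ of its own.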
+ def test_dict_emulates(self):
+ class DictIsh(object):
+ __emulates__ = dict
+ def __init__(self):
+ self.data = dict()
+
+ @collection.appender
+ @collection.replaces(1)
+ def set(self, item):
+ current = self.data.get(item.a, None)
+ self.data[item.a] = item
+ return current
+ @collection.remover
+ def _remove(self, item):
+ del self.data[item.a]
+ def __setitem__(self, key, value):
+ self.data[key] = value
+ def __getitem__(self, key):
+ return self.data[key]
+ def __delitem__(self, key):
+ del self.data[key]
+ def values(self):
+ return self.data.values()
+ def __contains__(self, key):
+ return key in self.data
+ @collection.iterator
+ def itervalues(self):
+ return self.data.itervalues()
+ def __eq__(self, other):
+ return self.data == other
+ def __repr__(self):
+ return 'DictIsh(%s)' % repr(self.data)
+
+ self._test_adapter(DictIsh, dictable_entity,
+ to_set=lambda c: set(c.itervalues()))
+ self._test_dict(DictIsh)
+ self._test_dict_bulk(DictIsh)
+ self.assert_(getattr(DictIsh, '_sa_instrumented') == id(DictIsh))
+
+ def _test_object(self, typecallable, creator=entity_maker):
+ class Foo(object):
+ pass
+
+ canary = Canary()
+ manager.register_attribute(Foo, 'attr', True, extension=canary,
+ typecallable=typecallable)
+
+ obj = Foo()
+ adapter = collections.collection_adapter(obj.attr)
+ direct = obj.attr
+ control = set()
+
+ def assert_eq():
+ self.assert_(set(direct) == canary.data)
+ self.assert_(set(adapter) == canary.data)
+ self.assert_(direct == control)
+
+ # There is no API for object collections. We'll make one up
+ # for the purposes of the test.
+ e = creator()
+ direct.push(e)
+ control.add(e)
+ assert_eq()
+
+ direct.zark(e)
+ control.remove(e)
+ assert_eq()
+
+ e = creator()
+ direct.maybe_zark(e)
+ control.discard(e)
+ assert_eq()
+
+ e = creator()
+ direct.push(e)
+ control.add(e)
+ assert_eq()
+
+ e = creator()
+ direct.maybe_zark(e)
+ control.discard(e)
+ assert_eq()
+
+ def test_object_duck(self):
+ class MyCollection(object):
+ def __init__(self):
+ self.data = set()
+ @collection.appender
+ def push(self, item):
+ self.data.add(item)
+ @collection.remover
+ def zark(self, item):
+ self.data.remove(item)
+ @collection.removes_return()
+ def maybe_zark(self, item):
+ if item in self.data:
+ self.data.remove(item)
+ return item
+ @collection.iterator
+ def __iter__(self):
+ return iter(self.data)
+ def __eq__(self, other):
+ return self.data == other
+
+ self._test_adapter(MyCollection)
+ self._test_object(MyCollection)
+ self.assert_(getattr(MyCollection, '_sa_instrumented') ==
+ id(MyCollection))
+
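# -- illustrative sketch, not part of the patch --
# As the test shows, a collection need not resemble any builtin at all;
# designating an appender, a remover and an iterator is enough, and the
# method names are free-form.  A hypothetical mapping (Parent, Child and
# parents_table are assumed names):

class Bag(object):
    def __init__(self):
        self.data = set()
    @collection.appender
    def put(self, item):
        self.data.add(item)
    @collection.remover
    def take(self, item):
        self.data.remove(item)
    @collection.iterator
    def __iter__(self):
        return iter(self.data)

mapper(Parent, parents_table, properties={
    'children': relation(Child, collection_class=Bag)
})
# -- end sketch --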
+ def test_object_emulates(self):
+ class MyCollection2(object):
+ __emulates__ = None
+ def __init__(self):
+ self.data = set()
+ # looks like a list
+ def append(self, item):
+ assert False
+ @collection.appender
+ def push(self, item):
+ self.data.add(item)
+ @collection.remover
+ def zark(self, item):
+ self.data.remove(item)
+ @collection.removes_return()
+ def maybe_zark(self, item):
+ if item in self.data:
+ self.data.remove(item)
+ return item
+ @collection.iterator
+ def __iter__(self):
+ return iter(self.data)
+ def __eq__(self, other):
+ return self.data == other
+
+ self._test_adapter(MyCollection2)
+ self._test_object(MyCollection2)
+ self.assert_(getattr(MyCollection2, '_sa_instrumented') ==
+ id(MyCollection2))
+
+ def test_lifecycle(self):
+ class Foo(object):
+ pass
+
+ canary = Canary()
+ creator = entity_maker
+ manager.register_attribute(Foo, 'attr', True, extension=canary)
+
+ obj = Foo()
+ col1 = obj.attr
+
+ e1 = creator()
+ obj.attr.append(e1)
+
+ e2 = creator()
+ bulk1 = [e2]
+ # empty & sever col1 from obj
+ obj.attr = bulk1
+ self.assert_(len(col1) == 0)
+ self.assert_(len(canary.data) == 1)
+ self.assert_(obj.attr is not col1)
+ self.assert_(obj.attr is not bulk1)
+ self.assert_(obj.attr == bulk1)
+
+ e3 = creator()
+ col1.append(e3)
+ self.assert_(e3 not in canary.data)
+ self.assert_(collections.collection_adapter(col1) is None)
+
+ obj.attr[0] = e3
+ self.assert_(e3 in canary.data)
+
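# -- note, not part of the patch --
# test_lifecycle pins down replacement semantics: assigning a new bulk
# value to obj.attr empties the old collection (col1) and severs it from
# instrumentation -- its adapter becomes None and later appends to it fire
# no events -- while obj.attr becomes a freshly instrumented copy of the
# assigned list rather than the list object itself.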
+class DictHelpersTest(ORMTest):
+ def define_tables(self, metadata):
+ global parents, children, Parent, Child
+
+ parents = Table('parents', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('label', String))
+ children = Table('children', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('parent_id', Integer, ForeignKey('parents.id'),
+ nullable=False),
+ Column('a', String),
+ Column('b', String),
+ Column('c', String))
+
+ class Parent(object):
+ def __init__(self, label=None):
+ self.label = label
+ class Child(object):
+ def __init__(self, a=None, b=None, c=None):
+ self.a = a
+ self.b = b
+ self.c = c
+
+ def _test_scalar_mapped(self, collection_class):
+ mapper(Child, children)
+ mapper(Parent, parents, properties={
+ 'children': relation(Child, collection_class=collection_class,
+ cascade="all, delete-orphan")
+ })
+
+ p = Parent()
+ p.children['foo'] = Child('foo', 'value')
+ p.children['bar'] = Child('bar', 'value')
+ session = create_session()
+ session.save(p)
+ session.flush()
+ pid = p.id
+ session.clear()
+
+ p = session.query(Parent).get(pid)
+
+ self.assert_(set(p.children.keys()) == set(['foo', 'bar']))
+ cid = p.children['foo'].id
+
+ collections.collection_adapter(p.children).append_with_event(
+ Child('foo', 'newvalue'))
+
+ session.save(p)
+ session.flush()
+ session.clear()
+
+ p = session.query(Parent).get(pid)
+
+ self.assert_(set(p.children.keys()) == set(['foo', 'bar']))
+ self.assert_(p.children['foo'].id != cid)
+
+ self.assert_(len(list(collections.collection_adapter(p.children))) == 2)
+ session.flush()
+ session.clear()
+
+ p = session.query(Parent).get(pid)
+ self.assert_(len(list(collections.collection_adapter(p.children))) == 2)
+
+ collections.collection_adapter(p.children).remove_with_event(
+ p.children['foo'])
+
+ self.assert_(len(list(collections.collection_adapter(p.children))) == 1)
+ session.flush()
+ session.clear()
+
+ p = session.query(Parent).get(pid)
+ self.assert_(len(list(collections.collection_adapter(p.children))) == 1)
+
+ del p.children['bar']
+ self.assert_(len(list(collections.collection_adapter(p.children))) == 0)
+ session.flush()
+ session.clear()
+
+ p = session.query(Parent).get(pid)
+ self.assert_(len(list(collections.collection_adapter(p.children))) == 0)
+
+
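# -- note, not part of the patch --
# collections.collection_adapter(col) returns the instrumentation shim
# wrapping a live collection; append_with_event()/remove_with_event()
# mutate through that shim so backref and session bookkeeping fire just
# as they would for plain p.children[...] access, which is why the
# flushes in these helpers persist every change.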
+ def _test_composite_mapped(self, collection_class):
+ mapper(Child, children)
+ mapper(Parent, parents, properties={
+ 'children': relation(Child, collection_class=collection_class,
+ cascade="all, delete-orphan")
+ })
+
+ p = Parent()
+ p.children[('foo', '1')] = Child('foo', '1', 'value 1')
+ p.children[('foo', '2')] = Child('foo', '2', 'value 2')
+
+ session = create_session()
+ session.save(p)
+ session.flush()
+ pid = p.id
+ session.clear()
+
+ p = session.query(Parent).get(pid)
+
+ self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')]))
+ cid = p.children[('foo', '1')].id
+
+ collections.collection_adapter(p.children).append_with_event(
+ Child('foo', '1', 'newvalue'))
+
+ session.save(p)
+ session.flush()
+ session.clear()
+
+ p = session.query(Parent).get(pid)
+
+ self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')]))
+ self.assert_(p.children[('foo', '1')].id != cid)
+
+ self.assert_(len(list(collections.collection_adapter(p.children))) == 2)
+
+ def test_mapped_collection(self):
+ collection_class = collections.mapped_collection(lambda c: c.a)
+ self._test_scalar_mapped(collection_class)
+
+ def test_mapped_collection2(self):
+ collection_class = collections.mapped_collection(lambda c: (c.a, c.b))
+ self._test_composite_mapped(collection_class)
+
+ def test_attr_mapped_collection(self):
+ collection_class = collections.attribute_mapped_collection('a')
+ self._test_scalar_mapped(collection_class)
+
+ def test_column_mapped_collection(self):
+ collection_class = collections.column_mapped_collection(children.c.a)
+ self._test_scalar_mapped(collection_class)
+
+ def test_column_mapped_collection2(self):
+ collection_class = collections.column_mapped_collection((children.c.a,
+ children.c.b))
+ self._test_composite_mapped(collection_class)
+
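# -- illustrative sketch, not part of the patch --
# The three factories exercised above differ only in how the dict key is
# derived from each child; children.c.a refers to the table defined in
# define_tables:

from sqlalchemy.orm.collections import mapped_collection, \
     attribute_mapped_collection, column_mapped_collection

by_func = mapped_collection(lambda child: child.a)    # arbitrary callable
by_attr = attribute_mapped_collection('a')            # instance attribute
by_col  = column_mapped_collection(children.c.a)      # mapped column

# A tuple of columns, or a callable returning a tuple, keys the dict on a
# composite value, as the *2 variants above verify.
# -- end sketch --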
+ def test_mixin(self):
+ class Ordered(util.OrderedDict, collections.MappedCollection):
+ def __init__(self):
+ collections.MappedCollection.__init__(self, lambda v: v.a)
+ util.OrderedDict.__init__(self)
+ collection_class = Ordered
+ self._test_scalar_mapped(collection_class)
+
+ def test_mixin2(self):
+ class Ordered2(util.OrderedDict, collections.MappedCollection):
+ def __init__(self, keyfunc):
+ collections.MappedCollection.__init__(self, keyfunc)
+ util.OrderedDict.__init__(self)
+ collection_class = lambda: Ordered2(lambda v: (v.a, v.b))
+ self._test_composite_mapped(collection_class)
+
+if __name__ == "__main__":
+ testbase.main()
diff --git a/test/orm/compile.py b/test/orm/compile.py
index 61107ce8e..23f04db85 100644
--- a/test/orm/compile.py
+++ b/test/orm/compile.py
@@ -1,7 +1,10 @@
import testbase
from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
-class CompileTest(testbase.AssertMixin):
+
+class CompileTest(AssertMixin):
"""test various mapper compilation scenarios"""
def tearDownAll(self):
clear_mappers()
diff --git a/test/orm/cycles.py b/test/orm/cycles.py
index c53e9e846..ce3065f77 100644
--- a/test/orm/cycles.py
+++ b/test/orm/cycles.py
@@ -1,11 +1,8 @@
-from testbase import PersistTest, AssertMixin, ORMTest
-import unittest, sys, os
-from sqlalchemy import *
-import StringIO
import testbase
-
-from tables import *
-import tables
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+from testlib.tables import *
"""test cyclical mapper relationships. Many of the assertions are provided
via running with postgres, which is strict about foreign keys.
@@ -107,22 +104,6 @@ class SelfReferentialTest(AssertMixin):
sess.delete(a)
sess.flush()
- def testeagerassertion(self):
- """test that an eager self-referential relationship raises an error."""
- class C1(Tester):
- pass
- class C2(Tester):
- pass
-
- m1 = mapper(C1, t1, properties = {
- 'c1s' : relation(C1, lazy=False),
- })
-
- try:
- m1.compile()
- assert False
- except exceptions.ArgumentError:
- assert True
class SelfReferentialNoPKTest(AssertMixin):
"""test self-referential relationship that joins on a column other than the primary key column"""
@@ -541,8 +522,6 @@ class OneToManyManyToOneTest(AssertMixin):
)
)
- print str(Person.mapper.props['balls'].primaryjoin)
-
b = Ball('some data')
p = Person('some data')
p.balls.append(b)
@@ -554,7 +533,7 @@ class OneToManyManyToOneTest(AssertMixin):
sess.save(b)
sess.save(p)
- self.assert_sql(db, lambda: sess.flush(), [
+ self.assert_sql(testbase.db, lambda: sess.flush(), [
(
"INSERT INTO person (favorite_ball_id, data) VALUES (:favorite_ball_id, :data)",
{'favorite_ball_id': None, 'data':'some data'}
@@ -608,7 +587,7 @@ class OneToManyManyToOneTest(AssertMixin):
)
])
sess.delete(p)
- self.assert_sql(db, lambda: sess.flush(), [
+ self.assert_sql(testbase.db, lambda: sess.flush(), [
            # here's the post-update (which is a pre-update with deletes)
(
"UPDATE person SET favorite_ball_id=:favorite_ball_id WHERE person.id = :person_id",
@@ -645,8 +624,6 @@ class OneToManyManyToOneTest(AssertMixin):
)
)
- print str(Person.mapper.props['balls'].primaryjoin)
-
b = Ball('some data')
p = Person('some data')
p.balls.append(b)
@@ -660,7 +637,7 @@ class OneToManyManyToOneTest(AssertMixin):
sess = create_session()
[sess.save(x) for x in [b,p,b2,b3,b4]]
- self.assert_sql(db, lambda: sess.flush(), [
+ self.assert_sql(testbase.db, lambda: sess.flush(), [
(
"INSERT INTO ball (person_id, data) VALUES (:person_id, :data)",
{'person_id':None, 'data':'some data'}
@@ -739,7 +716,7 @@ class OneToManyManyToOneTest(AssertMixin):
])
sess.delete(p)
- self.assert_sql(db, lambda: sess.flush(), [
+ self.assert_sql(testbase.db, lambda: sess.flush(), [
(
"UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id",
lambda ctx:{'person_id': None, 'ball_id': b.id}
@@ -851,7 +828,7 @@ class SelfReferentialPostUpdateTest(AssertMixin):
remove_child(root, cats)
# pre-trigger lazy loader on 'cats' to make the test easier
cats.children
- self.assert_sql(db, lambda: session.flush(), [
+ self.assert_sql(testbase.db, lambda: session.flush(), [
(
"UPDATE node SET prev_sibling_id=:prev_sibling_id WHERE node.id = :node_id",
lambda ctx:{'prev_sibling_id':about.id, 'node_id':stories.id}
diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py
index 37b5ecdf7..a109be56f 100644
--- a/test/orm/eager_relations.py
+++ b/test/orm/eager_relations.py
@@ -1,9 +1,9 @@
"""basic tests of eager loaded attributes"""
+import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
-import testbase
-
+from testlib import *
from fixtures import *
from query import QueryTest
@@ -20,7 +20,7 @@ class EagerTest(QueryTest):
sess = create_session()
q = sess.query(User)
- assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(users.c.id == 7).all()
+ assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(User.id==7).all()
assert fixtures.user_address_result == q.all()
def test_no_orphan(self):
@@ -66,7 +66,7 @@ class EagerTest(QueryTest):
))
q = create_session().query(User)
- l = q.filter(users.c.id==addresses.c.user_id).order_by(addresses.c.email_address).all()
+ l = q.filter(User.id==Address.user_id).order_by(Address.email_address).all()
assert [
User(id=8, addresses=[
@@ -148,7 +148,7 @@ class EagerTest(QueryTest):
assert fixtures.user_address_result == sess.query(User).all()
def test_double(self):
- """tests lazy loading with two relations simulatneously, from the same table, using aliases. """
+ """tests eager loading with two relations simulatneously, from the same table, using aliases. """
openorders = alias(orders, 'openorders')
closedorders = alias(orders, 'closedorders')
@@ -185,6 +185,46 @@ class EagerTest(QueryTest):
] == q.all()
self.assert_sql_count(testbase.db, go, 1)
+
+ def test_double_same_mappers(self):
+ """tests eager loading with two relations simulatneously, from the same table, using aliases. """
+
+ mapper(Address, addresses)
+ mapper(Order, orders, properties={
+ 'items':relation(Item, secondary=order_items, lazy=False, order_by=items.c.id),
+ })
+ mapper(Item, items)
+ mapper(User, users, properties = dict(
+ addresses = relation(Address, lazy=False),
+ open_orders = relation(Order, primaryjoin = and_(orders.c.isopen == 1, users.c.id==orders.c.user_id), lazy=False),
+ closed_orders = relation(Order, primaryjoin = and_(orders.c.isopen == 0, users.c.id==orders.c.user_id), lazy=False)
+ ))
+ q = create_session().query(User)
+
+ def go():
+ assert [
+ User(
+ id=7,
+ addresses=[Address(id=1)],
+ open_orders = [Order(id=3, items=[Item(id=3), Item(id=4), Item(id=5)])],
+ closed_orders = [Order(id=1, items=[Item(id=1), Item(id=2), Item(id=3)]), Order(id=5, items=[Item(id=5)])]
+ ),
+ User(
+ id=8,
+ addresses=[Address(id=2), Address(id=3), Address(id=4)],
+ open_orders = [],
+ closed_orders = []
+ ),
+ User(
+ id=9,
+ addresses=[Address(id=5)],
+ open_orders = [Order(id=4, items=[Item(id=1), Item(id=5)])],
+ closed_orders = [Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)])]
+ ),
+ User(id=10)
+
+ ] == q.all()
+ self.assert_sql_count(testbase.db, go, 1)
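# -- note, not part of the patch --
# open_orders and closed_orders both target the orders table with no
# user-created aliases; the eager loader aliases each relation's join
# internally, so the whole nested result above still arrives in a single
# SELECT (hence the assert_sql_count of 1).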
def test_limit(self):
"""test limit operations combined with lazy-load relationships."""
@@ -236,7 +276,7 @@ class EagerTest(QueryTest):
sess = create_session()
q = sess.query(Item)
l = q.filter((Item.c.description=='item 2') | (Item.c.description=='item 5') | (Item.c.description=='item 3')).\
- order_by(Item.c.id).limit(2).all()
+ order_by(Item.id).limit(2).all()
assert fixtures.item_keyword_result[1:3] == l
@@ -259,7 +299,7 @@ class EagerTest(QueryTest):
q = sess.query(User)
if testbase.db.engine.name != 'mssql':
- l = q.join('orders').order_by(desc(orders.c.user_id)).limit(2).offset(1)
+ l = q.join('orders').order_by(desc(Order.user_id)).limit(2).offset(1)
assert [
User(id=9,
orders=[Order(id=2), Order(id=4)],
@@ -271,7 +311,7 @@ class EagerTest(QueryTest):
)
] == l.all()
- l = q.join('addresses').order_by(desc(addresses.c.email_address)).limit(1).offset(0)
+ l = q.join('addresses').order_by(desc(Address.email_address)).limit(1).offset(0)
assert [
User(id=7,
orders=[Order(id=1), Order(id=3), Order(id=5)],
@@ -375,6 +415,7 @@ class EagerTest(QueryTest):
'user':relation(User, lazy=False)
})
mapper(User, users)
+ mapper(Item, items)
q = create_session().query(Order)
assert [
@@ -382,7 +423,7 @@ class EagerTest(QueryTest):
Order(id=4, user=User(id=9))
] == q.all()
- q = q.select_from(s.join(order_items).join(items)).filter(~items.c.id.in_(1, 2, 5))
+ q = q.select_from(s.join(order_items).join(items)).filter(~Item.id.in_(1, 2, 5))
assert [
Order(id=3, user=User(id=7)),
] == q.all()
@@ -394,8 +435,80 @@ class EagerTest(QueryTest):
addresses = relation(mapper(Address, addresses), lazy=False)
))
q = create_session().query(User)
- l = q.filter(addresses.c.email_address == 'ed@lala.com').filter(addresses.c.user_id==users.c.id)
+ l = q.filter(addresses.c.email_address == 'ed@lala.com').filter(Address.user_id==User.id)
assert fixtures.user_address_result[1:2] == l.all()
+class SelfReferentialEagerTest(ORMTest):
+ def define_tables(self, metadata):
+ global nodes
+ nodes = Table('nodes', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('parent_id', Integer, ForeignKey('nodes.id')),
+ Column('data', String(30)))
+
+ def test_basic(self):
+ class Node(Base):
+ def append(self, node):
+ self.children.append(node)
+
+ mapper(Node, nodes, properties={
+ 'children':relation(Node, lazy=False, join_depth=3)
+ })
+ sess = create_session()
+ n1 = Node(data='n1')
+ n1.append(Node(data='n11'))
+ n1.append(Node(data='n12'))
+ n1.append(Node(data='n13'))
+ n1.children[1].append(Node(data='n121'))
+ n1.children[1].append(Node(data='n122'))
+ n1.children[1].append(Node(data='n123'))
+ sess.save(n1)
+ sess.flush()
+ sess.clear()
+ def go():
+ d = sess.query(Node).filter_by(data='n1').first()
+ assert Node(data='n1', children=[
+ Node(data='n11'),
+ Node(data='n12', children=[
+ Node(data='n121'),
+ Node(data='n122'),
+ Node(data='n123')
+ ]),
+ Node(data='n13')
+ ]) == d
+ self.assert_sql_count(testbase.db, go, 1)
+
+ def test_no_depth(self):
+ class Node(Base):
+ def append(self, node):
+ self.children.append(node)
+
+ mapper(Node, nodes, properties={
+ 'children':relation(Node, lazy=False)
+ })
+ sess = create_session()
+ n1 = Node(data='n1')
+ n1.append(Node(data='n11'))
+ n1.append(Node(data='n12'))
+ n1.append(Node(data='n13'))
+ n1.children[1].append(Node(data='n121'))
+ n1.children[1].append(Node(data='n122'))
+ n1.children[1].append(Node(data='n123'))
+ sess.save(n1)
+ sess.flush()
+ sess.clear()
+ def go():
+ d = sess.query(Node).filter_by(data='n1').first()
+ assert Node(data='n1', children=[
+ Node(data='n11'),
+ Node(data='n12', children=[
+ Node(data='n121'),
+ Node(data='n122'),
+ Node(data='n123')
+ ]),
+ Node(data='n13')
+ ]) == d
+ self.assert_sql_count(testbase.db, go, 3)
+
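# -- note, not part of the patch --
# The two tests above contrast self-referential eager loading with and
# without join_depth.  With join_depth=3 the children relation is joined
# against aliased copies of nodes up to three levels deep, so the whole
# tree loads in one query; with the default the eager join does not
# recurse into itself, and the deeper levels come back via follow-up
# loads -- three queries in total for this tree.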
if __name__ == '__main__':
testbase.main()
diff --git a/test/orm/eagertest1.py b/test/orm/eagertest1.py
deleted file mode 100644
index 9765379f4..000000000
--- a/test/orm/eagertest1.py
+++ /dev/null
@@ -1,69 +0,0 @@
-from testbase import PersistTest, AssertMixin
-import testbase
-import unittest, sys, os
-from sqlalchemy import *
-import datetime
-
-class EagerTest(AssertMixin):
- def setUpAll(self):
- global designType, design, part, inheritedPart
- designType = Table('design_types', testbase.metadata,
- Column('design_type_id', Integer, primary_key=True),
- )
-
- design =Table('design', testbase.metadata,
- Column('design_id', Integer, primary_key=True),
- Column('design_type_id', Integer, ForeignKey('design_types.design_type_id')))
-
- part = Table('parts', testbase.metadata,
- Column('part_id', Integer, primary_key=True),
- Column('design_id', Integer, ForeignKey('design.design_id')),
- Column('design_type_id', Integer, ForeignKey('design_types.design_type_id')))
-
- inheritedPart = Table('inherited_part', testbase.metadata,
- Column('ip_id', Integer, primary_key=True),
- Column('part_id', Integer, ForeignKey('parts.part_id')),
- Column('design_id', Integer, ForeignKey('design.design_id')),
- )
-
- testbase.metadata.create_all()
- def tearDownAll(self):
- testbase.metadata.drop_all()
- testbase.metadata.clear()
- def testone(self):
- class Part(object):pass
- class Design(object):pass
- class DesignType(object):pass
- class InheritedPart(object):pass
-
- mapper(Part, part)
-
- mapper(InheritedPart, inheritedPart, properties=dict(
- part=relation(Part, lazy=False)
- ))
-
- mapper(Design, design, properties=dict(
- parts=relation(Part, private=True, backref="design"),
- inheritedParts=relation(InheritedPart, private=True, backref="design"),
- ))
-
- mapper(DesignType, designType, properties=dict(
- # designs=relation(Design, private=True, backref="type"),
- ))
-
- class_mapper(Design).add_property("type", relation(DesignType, lazy=False, backref="designs"))
- class_mapper(Part).add_property("design", relation(Design, lazy=False, backref="parts"))
- #Part.mapper.add_property("designType", relation(DesignType))
-
- d = Design()
- sess = create_session()
- sess.save(d)
- sess.flush()
- sess.clear()
- x = sess.query(Design).get(1)
- x.inheritedParts
-
-if __name__ == "__main__":
- testbase.main()
-
-
diff --git a/test/orm/eagertest2.py b/test/orm/eagertest2.py
deleted file mode 100644
index 04de56f01..000000000
--- a/test/orm/eagertest2.py
+++ /dev/null
@@ -1,239 +0,0 @@
-from testbase import PersistTest, AssertMixin
-import testbase
-import unittest, sys, os
-from sqlalchemy import *
-import datetime
-from sqlalchemy.ext.sessioncontext import SessionContext
-
-class EagerTest(AssertMixin):
- def setUpAll(self):
- global companies_table, addresses_table, invoice_table, phones_table, items_table, ctx, metadata
-
- metadata = MetaData(testbase.db)
- ctx = SessionContext(create_session)
-
- companies_table = Table('companies', metadata,
- Column('company_id', Integer, Sequence('company_id_seq', optional=True), primary_key = True),
- Column('company_name', String(40)),
-
- )
-
- addresses_table = Table('addresses', metadata,
- Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True),
- Column('company_id', Integer, ForeignKey("companies.company_id")),
- Column('address', String(40)),
- )
-
- phones_table = Table('phone_numbers', metadata,
- Column('phone_id', Integer, Sequence('phone_id_seq', optional=True), primary_key = True),
- Column('address_id', Integer, ForeignKey('addresses.address_id')),
- Column('type', String(20)),
- Column('number', String(10)),
- )
-
- invoice_table = Table('invoices', metadata,
- Column('invoice_id', Integer, Sequence('invoice_id_seq', optional=True), primary_key = True),
- Column('company_id', Integer, ForeignKey("companies.company_id")),
- Column('date', DateTime),
- )
-
- items_table = Table('items', metadata,
- Column('item_id', Integer, Sequence('item_id_seq', optional=True), primary_key = True),
- Column('invoice_id', Integer, ForeignKey('invoices.invoice_id')),
- Column('code', String(20)),
- Column('qty', Integer),
- )
-
- metadata.create_all()
-
- def tearDownAll(self):
- metadata.drop_all()
-
- def tearDown(self):
- clear_mappers()
- for t in metadata.table_iterator(reverse=True):
- t.delete().execute()
-
- def testone(self):
- """tests eager load of a many-to-one attached to a one-to-many. this testcase illustrated
- the bug, which is that when the single Company is loaded, no further processing of the rows
- occurred in order to load the Company's second Address object."""
- class Company(object):
- def __init__(self):
- self.company_id = None
- def __repr__(self):
- return "Company:" + repr(getattr(self, 'company_id', None)) + " " + repr(getattr(self, 'company_name', None)) + " " + str([repr(addr) for addr in self.addresses])
-
- class Address(object):
- def __repr__(self):
- return "Address: " + repr(getattr(self, 'address_id', None)) + " " + repr(getattr(self, 'company_id', None)) + " " + repr(self.address)
-
- class Invoice(object):
- def __init__(self):
- self.invoice_id = None
- def __repr__(self):
- return "Invoice:" + repr(getattr(self, 'invoice_id', None)) + " " + repr(getattr(self, 'date', None)) + " " + repr(self.company)
-
- mapper(Address, addresses_table, properties={
- }, extension=ctx.mapper_extension)
- mapper(Company, companies_table, properties={
- 'addresses' : relation(Address, lazy=False),
- }, extension=ctx.mapper_extension)
- mapper(Invoice, invoice_table, properties={
- 'company': relation(Company, lazy=False, )
- }, extension=ctx.mapper_extension)
-
- c1 = Company()
- c1.company_name = 'company 1'
- a1 = Address()
- a1.address = 'a1 address'
- c1.addresses.append(a1)
- a2 = Address()
- a2.address = 'a2 address'
- c1.addresses.append(a2)
- i1 = Invoice()
- i1.date = datetime.datetime.now()
- i1.company = c1
-
- ctx.current.flush()
-
- company_id = c1.company_id
- invoice_id = i1.invoice_id
-
- ctx.current.clear()
-
- c = ctx.current.query(Company).get(company_id)
-
- ctx.current.clear()
-
- i = ctx.current.query(Invoice).get(invoice_id)
-
- self.echo(repr(c))
- self.echo(repr(i.company))
- self.assert_(repr(c) == repr(i.company))
-
- def testtwo(self):
- """this is the original testcase that includes various complicating factors"""
- class Company(object):
- def __init__(self):
- self.company_id = None
- def __repr__(self):
- return "Company:" + repr(getattr(self, 'company_id', None)) + " " + repr(getattr(self, 'company_name', None)) + " " + str([repr(addr) for addr in self.addresses])
-
- class Address(object):
- def __repr__(self):
- return "Address: " + repr(getattr(self, 'address_id', None)) + " " + repr(getattr(self, 'company_id', None)) + " " + repr(self.address) + str([repr(ph) for ph in self.phones])
-
- class Phone(object):
- def __repr__(self):
- return "Phone: " + repr(getattr(self, 'phone_id', None)) + " " + repr(getattr(self, 'address_id', None)) + " " + repr(self.type) + " " + repr(self.number)
-
- class Invoice(object):
- def __init__(self):
- self.invoice_id = None
- def __repr__(self):
- return "Invoice:" + repr(getattr(self, 'invoice_id', None)) + " " + repr(getattr(self, 'date', None)) + " " + repr(self.company) + " " + str([repr(item) for item in self.items])
-
- class Item(object):
- def __repr__(self):
- return "Item: " + repr(getattr(self, 'item_id', None)) + " " + repr(getattr(self, 'invoice_id', None)) + " " + repr(self.code) + " " + repr(self.qty)
-
- mapper(Phone, phones_table, extension=ctx.mapper_extension)
-
- mapper(Address, addresses_table, properties={
- 'phones': relation(Phone, lazy=False, backref='address')
- }, extension=ctx.mapper_extension)
-
- mapper(Company, companies_table, properties={
- 'addresses' : relation(Address, lazy=False, backref='company'),
- }, extension=ctx.mapper_extension)
-
- mapper(Item, items_table, extension=ctx.mapper_extension)
-
- mapper(Invoice, invoice_table, properties={
- 'items': relation(Item, lazy=False, backref='invoice'),
- 'company': relation(Company, lazy=False, backref='invoices')
- }, extension=ctx.mapper_extension)
-
- ctx.current.clear()
- c1 = Company()
- c1.company_name = 'company 1'
-
- a1 = Address()
- a1.address = 'a1 address'
-
- p1 = Phone()
- p1.type = 'home'
- p1.number = '1111'
-
- a1.phones.append(p1)
-
- p2 = Phone()
- p2.type = 'work'
- p2.number = '22222'
- a1.phones.append(p2)
-
- c1.addresses.append(a1)
-
- a2 = Address()
- a2.address = 'a2 address'
-
- p3 = Phone()
- p3.type = 'home'
- p3.number = '3333'
- a2.phones.append(p3)
-
- p4 = Phone()
- p4.type = 'work'
- p4.number = '44444'
- a2.phones.append(p4)
-
- c1.addresses.append(a2)
-
- ctx.current.flush()
-
- company_id = c1.company_id
-
- ctx.current.clear()
-
- a = ctx.current.query(Company).get(company_id)
- self.echo(repr(a))
-
- # set up an invoice
- i1 = Invoice()
- i1.date = datetime.datetime.now()
- i1.company = c1
-
- item1 = Item()
- item1.code = 'aaaa'
- item1.qty = 1
- item1.invoice = i1
-
- item2 = Item()
- item2.code = 'bbbb'
- item2.qty = 2
- item2.invoice = i1
-
- item3 = Item()
- item3.code = 'cccc'
- item3.qty = 3
- item3.invoice = i1
-
- ctx.current.flush()
-
- invoice_id = i1.invoice_id
-
- ctx.current.clear()
-
- c = ctx.current.query(Company).get(company_id)
- self.echo(repr(c))
-
- ctx.current.clear()
-
- i = ctx.current.query(Invoice).get(invoice_id)
- self.echo(repr(i))
-
- self.assert_(repr(i.company) == repr(c))
-
-if __name__ == "__main__":
- testbase.main()
diff --git a/test/orm/entity.py b/test/orm/entity.py
index 86486cafc..da76e8df0 100644
--- a/test/orm/entity.py
+++ b/test/orm/entity.py
@@ -1,11 +1,9 @@
-from testbase import PersistTest, AssertMixin
-import unittest
-from sqlalchemy import *
import testbase
+from sqlalchemy import *
+from sqlalchemy.orm import *
from sqlalchemy.ext.sessioncontext import SessionContext
-
-from tables import *
-import tables
+from testlib import *
+from testlib.tables import *
class EntityTest(AssertMixin):
"""tests mappers that are constructed based on "entity names", which allows the same class
diff --git a/test/orm/fixtures.py b/test/orm/fixtures.py
index aed74c170..4a7d41459 100644
--- a/test/orm/fixtures.py
+++ b/test/orm/fixtures.py
@@ -1,4 +1,6 @@
+import testbase
from sqlalchemy import *
+from testlib import *
_recursion_stack = util.Set()
class Base(object):
@@ -35,7 +37,7 @@ class Base(object):
continue
else:
if value is not None:
- if value != getattr(other, attr):
+ if value != getattr(other, attr, None):
return False
else:
return True
diff --git a/test/orm/generative.py b/test/orm/generative.py
index 75280deed..4a90c13cb 100644
--- a/test/orm/generative.py
+++ b/test/orm/generative.py
@@ -1,16 +1,19 @@
-from testbase import PersistTest, AssertMixin, ORMTest
import testbase
-import tables
-
from sqlalchemy import *
+from sqlalchemy.orm import *
from sqlalchemy import exceptions
+from testlib import *
+import testlib.tables as tables
+
+# TODO: these are more tests that should be updated to be part of test/orm/query.py
class Foo(object):
- pass
+ def __init__(self, **kwargs):
+ for k in kwargs:
+ setattr(self, k, kwargs[k])
class GenerativeQueryTest(PersistTest):
def setUpAll(self):
- self.install_threadlocal()
global foo, metadata
metadata = MetaData(testbase.db)
foo = Table('foo', metadata,
@@ -18,89 +21,103 @@ class GenerativeQueryTest(PersistTest):
Column('bar', Integer),
Column('range', Integer))
- assign_mapper(Foo, foo)
+ mapper(Foo, foo)
metadata.create_all()
+
+ sess = create_session(bind=testbase.db)
for i in range(100):
- Foo(bar=i, range=i%10)
- objectstore.flush()
+ sess.save(Foo(bar=i, range=i%10))
+ sess.flush()
- def setUp(self):
- self.query = Foo.query()
- self.orig = self.query.select_whereclause()
- self.res = self.query
-
def tearDownAll(self):
metadata.drop_all()
- self.uninstall_threadlocal()
clear_mappers()
def test_selectby(self):
- res = self.query.filter_by(range=5)
+ res = create_session(bind=testbase.db).query(Foo).filter_by(range=5)
assert res.order_by([Foo.c.bar])[0].bar == 5
assert res.order_by([desc(Foo.c.bar)])[0].bar == 95
- @testbase.unsupported('mssql')
+ @testing.unsupported('mssql')
def test_slice(self):
- assert self.query[1] == self.orig[1]
- assert list(self.query[10:20]) == self.orig[10:20]
- assert list(self.query[10:]) == self.orig[10:]
- assert list(self.query[:10]) == self.orig[:10]
- assert list(self.query[:10]) == self.orig[:10]
- assert list(self.query[10:40:3]) == self.orig[10:40:3]
- assert list(self.query[-5:]) == self.orig[-5:]
- assert self.query[10:20][5] == self.orig[10:20][5]
+ sess = create_session(bind=testbase.db)
+ query = sess.query(Foo)
+ orig = query.all()
+ assert query[1] == orig[1]
+ assert list(query[10:20]) == orig[10:20]
+ assert list(query[10:]) == orig[10:]
+ assert list(query[:10]) == orig[:10]
+ assert list(query[:10]) == orig[:10]
+ assert list(query[10:40:3]) == orig[10:40:3]
+ assert list(query[-5:]) == orig[-5:]
+ assert query[10:20][5] == orig[10:20][5]
- @testbase.supported('mssql')
+ @testing.supported('mssql')
def test_slice_mssql(self):
- assert list(self.query[:10]) == self.orig[:10]
- assert list(self.query[:10]) == self.orig[:10]
+ sess = create_session(bind=testbase.db)
+ query = sess.query(Foo)
+ orig = query.all()
+ assert list(query[:10]) == orig[:10]
+ assert list(query[:10]) == orig[:10]
def test_aggregate(self):
- assert self.query.count() == 100
- assert self.query.filter(foo.c.bar<30).min(foo.c.bar) == 0
- assert self.query.filter(foo.c.bar<30).max(foo.c.bar) == 29
- assert self.query.filter(foo.c.bar<30).apply_max(foo.c.bar).scalar() == 29
+ sess = create_session(bind=testbase.db)
+ query = sess.query(Foo)
+ assert query.count() == 100
+ assert query.filter(foo.c.bar<30).min(foo.c.bar) == 0
+ assert query.filter(foo.c.bar<30).max(foo.c.bar) == 29
+ assert query.filter(foo.c.bar<30).apply_max(foo.c.bar).first() == 29
+ assert query.filter(foo.c.bar<30).apply_max(foo.c.bar).one() == 29
- @testbase.unsupported('mysql')
+ @testing.unsupported('mysql')
def test_aggregate_1(self):
# this one fails in mysql as the result comes back as a string
- assert self.query.filter(foo.c.bar<30).sum(foo.c.bar) == 435
+ query = create_session(bind=testbase.db).query(Foo)
+ assert query.filter(foo.c.bar<30).sum(foo.c.bar) == 435
- @testbase.unsupported('postgres', 'mysql', 'firebird', 'mssql')
+ @testing.unsupported('postgres', 'mysql', 'firebird', 'mssql')
def test_aggregate_2(self):
- assert self.res.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5
+ query = create_session(bind=testbase.db).query(Foo)
+ assert query.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5
- @testbase.supported('postgres', 'mysql', 'firebird', 'mssql')
+ @testing.supported('postgres', 'mysql', 'firebird', 'mssql')
def test_aggregate_2_int(self):
- assert int(self.res.filter(foo.c.bar<30).avg(foo.c.bar)) == 14
+ query = create_session(bind=testbase.db).query(Foo)
+ assert int(query.filter(foo.c.bar<30).avg(foo.c.bar)) == 14
- @testbase.unsupported('postgres', 'mysql', 'firebird', 'mssql')
+ @testing.unsupported('postgres', 'mysql', 'firebird', 'mssql')
def test_aggregate_3(self):
- assert self.res.filter(foo.c.bar<30).apply_avg(foo.c.bar).scalar() == 14.5
+ query = create_session(bind=testbase.db).query(Foo)
+ assert query.filter(foo.c.bar<30).apply_avg(foo.c.bar).first() == 14.5
+ assert query.filter(foo.c.bar<30).apply_avg(foo.c.bar).one() == 14.5
def test_filter(self):
- assert self.query.count() == 100
- assert self.query.filter(Foo.c.bar < 30).count() == 30
- res2 = self.query.filter(Foo.c.bar < 30).filter(Foo.c.bar > 10)
+ query = create_session(bind=testbase.db).query(Foo)
+ assert query.count() == 100
+ assert query.filter(Foo.c.bar < 30).count() == 30
+ res2 = query.filter(Foo.c.bar < 30).filter(Foo.c.bar > 10)
assert res2.count() == 19
def test_options(self):
+ query = create_session(bind=testbase.db).query(Foo)
class ext1(MapperExtension):
- def populate_instance(self, mapper, selectcontext, row, instance, identitykey, isnew):
+ def populate_instance(self, mapper, selectcontext, row, instance, **flags):
instance.TEST = "hello world"
return EXT_PASS
- objectstore.clear()
- assert self.res.options(extension(ext1()))[0].TEST == "hello world"
+ assert query.options(extension(ext1()))[0].TEST == "hello world"
def test_order_by(self):
- assert self.res.order_by([Foo.c.bar])[0].bar == 0
- assert self.res.order_by([desc(Foo.c.bar)])[0].bar == 99
+ query = create_session(bind=testbase.db).query(Foo)
+ assert query.order_by([Foo.c.bar])[0].bar == 0
+ assert query.order_by([desc(Foo.c.bar)])[0].bar == 99
def test_offset(self):
- assert list(self.res.order_by([Foo.c.bar]).offset(10))[0].bar == 10
+ query = create_session(bind=testbase.db).query(Foo)
+ assert list(query.order_by([Foo.c.bar]).offset(10))[0].bar == 10
    def test_limit(self):
- assert len(list(self.res.limit(10))) == 10
+ query = create_session(bind=testbase.db).query(Foo)
+ assert len(list(query.limit(10))) == 10
class Obj1(object):
pass
@@ -109,9 +126,8 @@ class Obj2(object):
class GenerativeTest2(PersistTest):
def setUpAll(self):
- self.install_threadlocal()
global metadata, table1, table2
- metadata = MetaData(testbase.db)
+ metadata = MetaData()
table1 = Table('Table1', metadata,
Column('id', Integer, primary_key=True),
)
@@ -119,29 +135,23 @@ class GenerativeTest2(PersistTest):
Column('t1id', Integer, ForeignKey("Table1.id"), primary_key=True),
Column('num', Integer, primary_key=True),
)
- assign_mapper(Obj1, table1)
- assign_mapper(Obj2, table2)
- metadata.create_all()
- table1.insert().execute({'id':1},{'id':2},{'id':3},{'id':4})
- table2.insert().execute({'num':1,'t1id':1},{'num':2,'t1id':1},{'num':3,'t1id':1},\
+ mapper(Obj1, table1)
+ mapper(Obj2, table2)
+ metadata.create_all(bind=testbase.db)
+ testbase.db.execute(table1.insert(), {'id':1},{'id':2},{'id':3},{'id':4})
+ testbase.db.execute(table2.insert(), {'num':1,'t1id':1},{'num':2,'t1id':1},{'num':3,'t1id':1},\
{'num':4,'t1id':2},{'num':5,'t1id':2},{'num':6,'t1id':3})
- def setUp(self):
- self.query = Query(Obj1)
- #self.orig = self.query.select_whereclause()
- #self.res = self.query.select()
-
def tearDownAll(self):
- metadata.drop_all()
- self.uninstall_threadlocal()
+ metadata.drop_all(bind=testbase.db)
clear_mappers()
def test_distinctcount(self):
- res = self.query
- assert res.count() == 4
- res = self.query.filter(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1))
+ query = create_session(bind=testbase.db).query(Obj1)
+ assert query.count() == 4
+ res = query.filter(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1))
assert res.count() == 3
- res = self.query.filter(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1)).distinct()
+ res = query.filter(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1)).distinct()
self.assertEqual(res.count(), 1)
class RelationsTest(AssertMixin):
@@ -159,9 +169,9 @@ class RelationsTest(AssertMixin):
'items':relation(mapper(tables.Item, tables.orderitems))
}))
})
- session = create_session()
+ session = create_session(bind=testbase.db)
query = session.query(tables.User)
- x = query.join('orders').join('items').filter(tables.Item.c.item_id==2)
+ x = query.join(['orders', 'items']).filter(tables.Item.c.item_id==2)
print x.compile()
self.assert_result(list(x), tables.User, tables.user_result[2])
def test_outerjointo(self):
@@ -171,9 +181,9 @@ class RelationsTest(AssertMixin):
'items':relation(mapper(tables.Item, tables.orderitems))
}))
})
- session = create_session()
+ session = create_session(bind=testbase.db)
query = session.query(tables.User)
- x = query.outerjoin('orders').outerjoin('items').filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
+ x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
print x.compile()
self.assert_result(list(x), tables.User, *tables.user_result[1:3])
def test_outerjointo_count(self):
@@ -183,9 +193,9 @@ class RelationsTest(AssertMixin):
'items':relation(mapper(tables.Item, tables.orderitems))
}))
})
- session = create_session()
+ session = create_session(bind=testbase.db)
query = session.query(tables.User)
- x = query.outerjoin('orders').outerjoin('items').filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count()
+ x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count()
assert x==2
def test_from(self):
mapper(tables.User, tables.users, properties={
@@ -193,7 +203,7 @@ class RelationsTest(AssertMixin):
'items':relation(mapper(tables.Item, tables.orderitems))
}))
})
- session = create_session()
+ session = create_session(bind=testbase.db)
query = session.query(tables.User)
x = query.select_from([tables.users.outerjoin(tables.orders).outerjoin(tables.orderitems)]).\
filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
@@ -203,7 +213,6 @@ class RelationsTest(AssertMixin):
class CaseSensitiveTest(PersistTest):
def setUpAll(self):
- self.install_threadlocal()
global metadata, table1, table2
metadata = MetaData(testbase.db)
table1 = Table('Table1', metadata,
@@ -213,29 +222,23 @@ class CaseSensitiveTest(PersistTest):
Column('T1ID', Integer, ForeignKey("Table1.ID"), primary_key=True),
Column('NUM', Integer, primary_key=True),
)
- assign_mapper(Obj1, table1)
- assign_mapper(Obj2, table2)
+ mapper(Obj1, table1)
+ mapper(Obj2, table2)
metadata.create_all()
table1.insert().execute({'ID':1},{'ID':2},{'ID':3},{'ID':4})
table2.insert().execute({'NUM':1,'T1ID':1},{'NUM':2,'T1ID':1},{'NUM':3,'T1ID':1},\
{'NUM':4,'T1ID':2},{'NUM':5,'T1ID':2},{'NUM':6,'T1ID':3})
- def setUp(self):
- self.query = Query(Obj1)
- #self.orig = self.query.select_whereclause()
- #self.res = self.query.select()
-
def tearDownAll(self):
metadata.drop_all()
- self.uninstall_threadlocal()
clear_mappers()
def test_distinctcount(self):
- res = self.query
- assert res.count() == 4
- res = self.query.filter(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1))
+ q = create_session(bind=testbase.db).query(Obj1)
+ assert q.count() == 4
+ res = q.filter(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1))
assert res.count() == 3
- res = self.query.filter(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1)).distinct()
+ res = q.filter(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1)).distinct()
self.assertEqual(res.count(), 1)
class SelfRefTest(ORMTest):
@@ -248,18 +251,18 @@ class SelfRefTest(ORMTest):
def test_noautojoin(self):
class T(object):pass
mapper(T, t1, properties={'children':relation(T)})
- sess = create_session()
+ sess = create_session(bind=testbase.db)
try:
sess.query(T).join('children').select_by(id=7)
assert False
except exceptions.InvalidRequestError, e:
- assert str(e) == "Self-referential query on 'T.children (T)' property must be constructed manually using an Alias object for the related table.", str(e)
+ assert str(e) == "Self-referential query on 'T.children (T)' property requires create_aliases=True argument.", str(e)
try:
sess.query(T).join(['children']).select_by(id=7)
assert False
except exceptions.InvalidRequestError, e:
- assert str(e) == "Self-referential query on 'T.children (T)' property must be constructed manually using an Alias object for the related table.", str(e)
+ assert str(e) == "Self-referential query on 'T.children (T)' property requires create_aliases=True argument.", str(e)
diff --git a/test/orm/inheritance/__init__.py b/test/orm/inheritance/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/test/orm/inheritance/__init__.py
diff --git a/test/orm/abc_inheritance.py b/test/orm/inheritance/abc_inheritance.py
index 7689bd543..3b35b3713 100644
--- a/test/orm/abc_inheritance.py
+++ b/test/orm/inheritance/abc_inheritance.py
@@ -1,12 +1,14 @@
+import testbase
from sqlalchemy import *
+from sqlalchemy.orm import *
from sqlalchemy.orm.sync import ONETOMANY, MANYTOONE
-import testbase
+from testlib import *
def produce_test(parent, child, direction):
"""produce a testcase for A->B->C inheritance with a self-referential
relationship between two of the classes, using either one-to-many or
many-to-one."""
- class ABCTest(testbase.ORMTest):
+ class ABCTest(ORMTest):
def define_tables(self, meta):
global ta, tb, tc
ta = ["a", meta]
@@ -53,6 +55,8 @@ def produce_test(parent, child, direction):
parent_table = {"a":ta, "b":tb, "c": tc}[parent]
child_table = {"a":ta, "b":tb, "c": tc}[child]
+ remote_side = None
+
if direction == MANYTOONE:
foreign_keys = [parent_table.c.child_id]
elif direction == ONETOMANY:
@@ -65,6 +69,8 @@ def produce_test(parent, child, direction):
relationjoin = parent_table.c.id==child_table.c.parent_id
elif direction == MANYTOONE:
relationjoin = parent_table.c.child_id==child_table.c.id
+ if parent is child:
+ remote_side = [child_table.c.id]
abcjoin = polymorphic_union(
{"a":ta.select(tb.c.id==None, from_obj=[ta.outerjoin(tb, onclause=atob)]),
@@ -79,19 +85,15 @@ def produce_test(parent, child, direction):
"c":tc.join(tb, onclause=btoc).join(ta, onclause=atob)
},"type", "bcjoin"
)
-
- class A(object):pass
+ class A(object):
+ def __init__(self, name):
+ self.a_data = name
class B(A):pass
class C(B):pass
- mapper(A, ta, polymorphic_on=abcjoin.c.type, select_table=abcjoin, polymorphic_identity="a", )
- mapper(B, tb, polymorphic_on=bcjoin.c.type, select_table=bcjoin, polymorphic_identity="b", inherits=A, inherit_condition=atob,)
- mapper(C, tc, polymorphic_identity="c", inherits=B, inherit_condition=btoc, )
-
- #print "KEYS:"
- #print [c.key for c in class_mapper(A).primary_key]
- #print [c.key for c in class_mapper(B).primary_key]
- #print [c.key for c in class_mapper(C).primary_key]
+ mapper(A, ta, polymorphic_on=abcjoin.c.type, select_table=abcjoin, polymorphic_identity="a")
+ mapper(B, tb, polymorphic_on=bcjoin.c.type, select_table=bcjoin, polymorphic_identity="b", inherits=A, inherit_condition=atob)
+ mapper(C, tc, polymorphic_identity="c", inherits=B, inherit_condition=btoc)
parent_mapper = class_mapper({ta:A, tb:B, tc:C}[parent_table])
child_mapper = class_mapper({ta:A, tb:B, tc:C}[child_table])
@@ -99,24 +101,24 @@ def produce_test(parent, child, direction):
parent_class = parent_mapper.class_
child_class = child_mapper.class_
- parent_mapper.add_property("collection", relation(child_mapper, primaryjoin=relationjoin, foreign_keys=foreign_keys, uselist=True))
+ parent_mapper.add_property("collection", relation(child_mapper, primaryjoin=relationjoin, foreign_keys=foreign_keys, remote_side=remote_side, uselist=True))
sess = create_session()
- parent_obj = parent_class()
- child_obj = child_class()
- somea = A()
- someb = B()
- somec = C()
+ parent_obj = parent_class('parent1')
+ child_obj = child_class('child1')
+ somea = A('somea')
+ someb = B('someb')
+ somec = C('somec')
print "APPENDING", parent.__class__.__name__ , "TO", child.__class__.__name__
sess.save(parent_obj)
parent_obj.collection.append(child_obj)
if direction == ONETOMANY:
- child2 = child_class()
+ child2 = child_class('child2')
parent_obj.collection.append(child2)
sess.save(child2)
elif direction == MANYTOONE:
- parent2 = parent_class()
+ parent2 = parent_class('parent2')
parent2.collection.append(child_obj)
sess.save(parent2)
sess.save(somea)
@@ -155,8 +157,6 @@ def produce_test(parent, child, direction):
# test all combinations of polymorphic a/b/c related to another of a/b/c
for parent in ["a", "b", "c"]:
for child in ["a", "b", "c"]:
- if parent == child:
- continue
for direction in [ONETOMANY, MANYTOONE]:
testclass = produce_test(parent, child, direction)
exec("%s = testclass" % testclass.__name__)
diff --git a/test/orm/inheritance/alltests.py b/test/orm/inheritance/alltests.py
new file mode 100644
index 000000000..1ab10c060
--- /dev/null
+++ b/test/orm/inheritance/alltests.py
@@ -0,0 +1,28 @@
+import testbase
+import unittest
+
+def suite():
+ modules_to_test = (
+ 'orm.inheritance.basic',
+ 'orm.inheritance.manytomany',
+ 'orm.inheritance.single',
+ 'orm.inheritance.concrete',
+ 'orm.inheritance.polymorph',
+ 'orm.inheritance.polymorph2',
+ 'orm.inheritance.poly_linked_list',
+ 'orm.inheritance.abc_inheritance',
+ 'orm.inheritance.productspec',
+ 'orm.inheritance.magazine',
+
+ )
+ alltests = unittest.TestSuite()
+ for name in modules_to_test:
+ mod = __import__(name)
+ for token in name.split('.')[1:]:
+ mod = getattr(mod, token)
+ alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
+ return alltests
+
+
+if __name__ == '__main__':
+ testbase.main(suite())
diff --git a/test/orm/inheritance.py b/test/orm/inheritance/basic.py
index 2281a0597..be623e1b8 100644
--- a/test/orm/inheritance.py
+++ b/test/orm/inheritance/basic.py
@@ -1,159 +1,13 @@
import testbase
from sqlalchemy import *
-import string
-import sys
+from sqlalchemy.orm import *
+from testlib import *
-class Principal( object ):
- def __init__(self, **kwargs):
- for key, value in kwargs.iteritems():
- setattr(self, key, value)
-class User( Principal ):
- pass
-
-class Group( Principal ):
- pass
-
-class InheritTest(testbase.ORMTest):
- """deals with inheritance and many-to-many relationships"""
- def define_tables(self, metadata):
- global principals
- global users
- global groups
- global user_group_map
-
- principals = Table(
- 'principals',
- metadata,
- Column('principal_id', Integer, Sequence('principal_id_seq', optional=False), primary_key=True),
- Column('name', String(50), nullable=False),
- )
-
- users = Table(
- 'prin_users',
- metadata,
- Column('principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True),
- Column('password', String(50), nullable=False),
- Column('email', String(50), nullable=False),
- Column('login_id', String(50), nullable=False),
-
- )
-
- groups = Table(
- 'prin_groups',
- metadata,
- Column( 'principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True),
-
- )
-
- user_group_map = Table(
- 'prin_user_group_map',
- metadata,
- Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), primary_key=True ),
- Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"), primary_key=True ),
- #Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), ),
- #Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"), ),
-
- )
-
- def testbasic(self):
- mapper( Principal, principals )
- mapper(
- User,
- users,
- inherits=Principal
- )
-
- mapper(
- Group,
- groups,
- inherits=Principal,
- properties=dict( users = relation(User, secondary=user_group_map, lazy=True, backref="groups") )
- )
-
- g = Group(name="group1")
- g.users.append(User(name="user1", password="pw", email="foo@bar.com", login_id="lg1"))
- sess = create_session()
- sess.save(g)
- sess.flush()
- # TODO: put an assertion
-
-class InheritTest2(testbase.ORMTest):
- """deals with inheritance and many-to-many relationships"""
- def define_tables(self, metadata):
- global foo, bar, foo_bar
- foo = Table('foo', metadata,
- Column('id', Integer, Sequence('foo_id_seq'), primary_key=True),
- Column('data', String(20)),
- )
-
- bar = Table('bar', metadata,
- Column('bid', Integer, ForeignKey('foo.id'), primary_key=True),
- #Column('fid', Integer, ForeignKey('foo.id'), )
- )
-
- foo_bar = Table('foo_bar', metadata,
- Column('foo_id', Integer, ForeignKey('foo.id')),
- Column('bar_id', Integer, ForeignKey('bar.bid')))
-
- def testget(self):
- class Foo(object):
- def __init__(self, data=None):
- self.data = data
- class Bar(Foo):pass
-
- mapper(Foo, foo)
- mapper(Bar, bar, inherits=Foo)
-
- b = Bar('somedata')
- sess = create_session()
- sess.save(b)
- sess.flush()
- sess.clear()
-
- # test that "bar.bid" does not need to be referenced in a get
- # (ticket 185)
- assert sess.query(Bar).get(b.id).id == b.id
-
- def testbasic(self):
- class Foo(object):
- def __init__(self, data=None):
- self.data = data
-
- mapper(Foo, foo)
- class Bar(Foo):
- pass
-
- mapper(Bar, bar, inherits=Foo, properties={
- 'foos': relation(Foo, secondary=foo_bar, lazy=False)
- })
-
- sess = create_session()
- b = Bar('barfoo')
- sess.save(b)
- sess.flush()
-
- f1 = Foo('subfoo1')
- f2 = Foo('subfoo2')
- b.foos.append(f1)
- b.foos.append(f2)
-
- sess.flush()
- sess.clear()
-
- l = sess.query(Bar).select()
- print l[0]
- print l[0].foos
- self.assert_result(l, Bar,
-# {'id':1, 'data':'barfoo', 'bid':1, 'foos':(Foo, [{'id':2,'data':'subfoo1'}, {'id':3,'data':'subfoo2'}])},
- {'id':b.id, 'data':'barfoo', 'foos':(Foo, [{'id':f1.id,'data':'subfoo1'}, {'id':f2.id,'data':'subfoo2'}])},
- )
-
-class InheritTest3(testbase.ORMTest):
- """deals with inheritance and many-to-many relationships"""
+class O2MTest(ORMTest):
+ """deals with inheritance and one-to-many relationships"""
def define_tables(self, metadata):
- global foo, bar, blub, bar_foo, blub_bar, blub_foo
-
+ global foo, bar, blub
        # the 'data' columns are to appease SQLite, which can't handle a blank INSERT
foo = Table('foo', metadata,
Column('id', Integer, Sequence('foo_seq'), primary_key=True),
@@ -165,20 +19,9 @@ class InheritTest3(testbase.ORMTest):
blub = Table('blub', metadata,
Column('id', Integer, ForeignKey('bar.id'), primary_key=True),
+ Column('foo_id', Integer, ForeignKey('foo.id'), nullable=False),
Column('data', String(20)))
- bar_foo = Table('bar_foo', metadata,
- Column('bar_id', Integer, ForeignKey('bar.id')),
- Column('foo_id', Integer, ForeignKey('foo.id')))
-
- blub_bar = Table('bar_blub', metadata,
- Column('blub_id', Integer, ForeignKey('blub.id')),
- Column('bar_id', Integer, ForeignKey('bar.id')))
-
- blub_foo = Table('blub_foo', metadata,
- Column('blub_id', Integer, ForeignKey('blub.id')),
- Column('foo_id', Integer, ForeignKey('foo.id')))
-
def testbasic(self):
class Foo(object):
def __init__(self, data=None):
@@ -190,71 +33,41 @@ class InheritTest3(testbase.ORMTest):
class Bar(Foo):
def __repr__(self):
return "Bar id %d, data %s" % (self.id, self.data)
-
- mapper(Bar, bar, inherits=Foo, properties={
- 'foos' :relation(Foo, secondary=bar_foo, lazy=True)
- })
-
- sess = create_session()
- b = Bar('bar #1', _sa_session=sess)
- b.foos.append(Foo("foo #1"))
- b.foos.append(Foo("foo #2"))
- sess.flush()
- compare = repr(b) + repr(b.foos)
- sess.clear()
- l = sess.query(Bar).select()
- self.echo(repr(l[0]) + repr(l[0].foos))
- self.assert_(repr(l[0]) + repr(l[0].foos) == compare)
-
- def testadvanced(self):
- class Foo(object):
- def __init__(self, data=None):
- self.data = data
- def __repr__(self):
- return "Foo id %d, data %s" % (self.id, self.data)
- mapper(Foo, foo)
- class Bar(Foo):
- def __repr__(self):
- return "Bar id %d, data %s" % (self.id, self.data)
mapper(Bar, bar, inherits=Foo)
-
+
class Blub(Bar):
def __repr__(self):
- return "Blub id %d, data %s, bars %s, foos %s" % (self.id, self.data, repr([b for b in self.bars]), repr([f for f in self.foos]))
-
+ return "Blub id %d, data %s" % (self.id, self.data)
+
mapper(Blub, blub, inherits=Bar, properties={
- 'bars':relation(Bar, secondary=blub_bar, lazy=False),
- 'foos':relation(Foo, secondary=blub_foo, lazy=False),
+ 'parent_foo':relation(Foo)
})
sess = create_session()
- f1 = Foo("foo #1", _sa_session=sess)
- b1 = Bar("bar #1", _sa_session=sess)
- b2 = Bar("bar #2", _sa_session=sess)
- bl1 = Blub("blub #1", _sa_session=sess)
- bl1.foos.append(f1)
- bl1.bars.append(b2)
+ b1 = Blub("blub #1")
+ b2 = Blub("blub #2")
+ f = Foo("foo #1")
+ sess.save(b1)
+ sess.save(b2)
+ sess.save(f)
+ b1.parent_foo = f
+ b2.parent_foo = f
sess.flush()
- compare = repr(bl1)
- blubid = bl1.id
+ compare = repr(b1) + repr(b2) + repr(b1.parent_foo) + repr(b2.parent_foo)
sess.clear()
-
l = sess.query(Blub).select()
- self.echo(l)
- self.assert_(repr(l[0]) == compare)
- sess.clear()
- x = sess.query(Blub).get_by(id=blubid)
- self.echo(x)
- self.assert_(repr(x) == compare)
-
-class InheritTest4(testbase.ORMTest):
- """deals with inheritance and one-to-many relationships"""
+ result = repr(l[0]) + repr(l[1]) + repr(l[0].parent_foo) + repr(l[1].parent_foo)
+ print result
+ self.assert_(compare == result)
+ self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1')
+
+class GetTest(ORMTest):
def define_tables(self, metadata):
global foo, bar, blub
- # the 'data' columns are to appease SQLite which cant handle a blank INSERT
foo = Table('foo', metadata,
Column('id', Integer, Sequence('foo_seq'), primary_key=True),
+ Column('type', String(30)),
Column('data', String(20)))
bar = Table('bar', metadata,
@@ -262,50 +75,80 @@ class InheritTest4(testbase.ORMTest):
Column('data', String(20)))
blub = Table('blub', metadata,
- Column('id', Integer, ForeignKey('bar.id'), primary_key=True),
- Column('foo_id', Integer, ForeignKey('foo.id'), nullable=False),
+ Column('id', Integer, primary_key=True),
+ Column('foo_id', Integer, ForeignKey('foo.id')),
+ Column('bar_id', Integer, ForeignKey('bar.id')),
Column('data', String(20)))
+
+ def create_test(polymorphic):
+ def test_get(self):
+ class Foo(object):
+ pass
- def testbasic(self):
- class Foo(object):
- def __init__(self, data=None):
- self.data = data
- def __repr__(self):
- return "Foo id %d, data %s" % (self.id, self.data)
- mapper(Foo, foo)
-
- class Bar(Foo):
- def __repr__(self):
- return "Bar id %d, data %s" % (self.id, self.data)
-
- mapper(Bar, bar, inherits=Foo)
+ class Bar(Foo):
+ pass
- class Blub(Bar):
- def __repr__(self):
- return "Blub id %d, data %s" % (self.id, self.data)
-
- mapper(Blub, blub, inherits=Bar, properties={
- 'parent_foo':relation(Foo)
- })
+ class Blub(Bar):
+ pass
+
+ if polymorphic:
+ mapper(Foo, foo, polymorphic_on=foo.c.type, polymorphic_identity='foo')
+ mapper(Bar, bar, inherits=Foo, polymorphic_identity='bar')
+ mapper(Blub, blub, inherits=Bar, polymorphic_identity='blub')
+ else:
+ mapper(Foo, foo)
+ mapper(Bar, bar, inherits=Foo)
+ mapper(Blub, blub, inherits=Bar)
+
+ sess = create_session()
+ f = Foo()
+ b = Bar()
+ bl = Blub()
+ sess.save(f)
+ sess.save(b)
+ sess.save(bl)
+ sess.flush()
+
+ if polymorphic:
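+                # every instance is already in the session's identity map,
+                # so none of these get() calls should emit any SQL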
+ def go():
+ assert sess.query(Foo).get(f.id) == f
+ assert sess.query(Foo).get(b.id) == b
+ assert sess.query(Foo).get(bl.id) == bl
+ assert sess.query(Bar).get(b.id) == b
+ assert sess.query(Bar).get(bl.id) == bl
+ assert sess.query(Blub).get(bl.id) == bl
+
+ self.assert_sql_count(testbase.db, go, 0)
+ else:
+ # this is testing the 'wrong' behavior of using get()
+ # polymorphically with mappers that are not configured to be
+                # polymorphic. the important part is that get() always
+ # returns an instance of the query's type.
+ def go():
+ assert sess.query(Foo).get(f.id) == f
+
+ bb = sess.query(Foo).get(b.id)
+                    assert isinstance(bb, Foo) and bb.id==b.id
+
+ bll = sess.query(Foo).get(bl.id)
+ assert isinstance(bll, Foo) and bll.id==bl.id
+
+ assert sess.query(Bar).get(b.id) == b
+
+ bll = sess.query(Bar).get(bl.id)
+ assert isinstance(bll, Bar) and bll.id == bl.id
+
+ assert sess.query(Blub).get(bl.id) == bl
+
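+                # three of the six lookups miss the identity map because the
+                # instance is keyed under a different class, so three SELECTs
+                # are expected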
+ self.assert_sql_count(testbase.db, go, 3)
+
+ return test_get
+
+ test_get_polymorphic = create_test(True)
+ test_get_nonpolymorphic = create_test(False)
- sess = create_session()
- b1 = Blub("blub #1", _sa_session=sess)
- b2 = Blub("blub #2", _sa_session=sess)
- f = Foo("foo #1", _sa_session=sess)
- b1.parent_foo = f
- b2.parent_foo = f
- sess.flush()
- compare = repr(b1) + repr(b2) + repr(b1.parent_foo) + repr(b2.parent_foo)
- sess.clear()
- l = sess.query(Blub).select()
- result = repr(l[0]) + repr(l[1]) + repr(l[0].parent_foo) + repr(l[1].parent_foo)
- self.echo(result)
- self.assert_(compare == result)
- self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1')
-class InheritTest5(testbase.ORMTest):
- """testing that construction of inheriting mappers works regardless of when extra properties
- are added to the superclass mapper"""
+class ConstructionTest(ORMTest):
def define_tables(self, metadata):
global content_type, content, product
content_type = Table('content_type', metadata,
@@ -313,7 +156,8 @@ class InheritTest5(testbase.ORMTest):
)
content = Table('content', metadata,
Column('id', Integer, primary_key=True),
- Column('content_type_id', Integer, ForeignKey('content_type.id'))
+ Column('content_type_id', Integer, ForeignKey('content_type.id')),
+ Column('type', String(30))
)
product = Table('product', metadata,
Column('id', Integer, ForeignKey('content.id'), primary_key=True)
@@ -327,11 +171,15 @@ class InheritTest5(testbase.ORMTest):
content_types = mapper(ContentType, content_type)
contents = mapper(Content, content, properties={
'content_type':relation(content_types)
- })
- #contents.add_property('content_type', relation(content_types)) #adding this makes the inheritance stop working
- # shouldnt throw exception
- products = mapper(Product, product, inherits=contents)
- # TODO: assertion ??
+ }, polymorphic_identity='contents')
+
+ products = mapper(Product, product, inherits=contents, polymorphic_identity='products')
+
+ try:
+ compile_mappers()
+ assert False
+ except exceptions.ArgumentError, e:
+ assert str(e) == "Mapper 'Mapper|Content|content' specifies a polymorphic_identity of 'contents', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument"
def testbackref(self):
"""tests adding a property to the superclass mapper"""
@@ -339,8 +187,8 @@ class InheritTest5(testbase.ORMTest):
class Content(object): pass
class Product(Content): pass
- contents = mapper(Content, content)
- products = mapper(Product, product, inherits=contents)
+ contents = mapper(Content, content, polymorphic_on=content.c.type, polymorphic_identity='content')
+ products = mapper(Product, product, inherits=contents, polymorphic_identity='product')
content_types = mapper(ContentType, content_type, properties={
'content':relation(contents, backref='contenttype')
})
@@ -348,7 +196,7 @@ class InheritTest5(testbase.ORMTest):
p.contenttype = ContentType()
# TODO: assertion ??
-class InheritTest6(testbase.ORMTest):
+class EagerLazyTest(ORMTest):
"""tests eager load/lazy load of child items off inheritance mappers, tests that
LazyLoader constructs the right query condition."""
def define_tables(self, metadata):
@@ -370,7 +218,6 @@ class InheritTest6(testbase.ORMTest):
foos = mapper(Foo, foo)
bars = mapper(Bar, bar, inherits=foos)
bars.add_property('lazy', relation(foos, bar_foo, lazy=True))
- print bars.props['lazy'].primaryjoin, bars.props['lazy'].secondaryjoin
bars.add_property('eager', relation(foos, bar_foo, lazy=False))
foo.insert().execute(data='foo1')
@@ -391,7 +238,7 @@ class InheritTest6(testbase.ORMTest):
self.assert_(len(q.selectfirst().eager) == 1)
-class InheritTest7(testbase.ORMTest):
+class FlushTest(ORMTest):
"""test dependency sorting among inheriting mappers"""
def define_tables(self, metadata):
global users, roles, user_roles, admins
@@ -412,22 +259,20 @@ class InheritTest7(testbase.ORMTest):
)
admins = Table('admin', metadata,
- Column('id', Integer, primary_key=True),
+ Column('admin_id', Integer, primary_key=True),
Column('user_id', Integer, ForeignKey('users.id'))
)
def testone(self):
class User(object):pass
- class Role(object):
- def __init__(self, description):
- self.description = description
+ class Role(object):pass
class Admin(User):pass
role_mapper = mapper(Role, roles)
user_mapper = mapper(User, users, properties = {
'roles' : relation(Role, secondary=user_roles, lazy=False, private=False)
}
)
- admin_mapper = mapper(Admin, admins, inherits=user_mapper, properties={'aid':admins.c.id})
+ admin_mapper = mapper(Admin, admins, inherits=user_mapper)
sess = create_session()
adminrole = Role('admin')
sess.save(adminrole)
@@ -435,7 +280,7 @@ class InheritTest7(testbase.ORMTest):
# create an Admin, and append a Role. the dependency processors
# corresponding to the "roles" attribute for the Admin mapper and the User mapper
- # have to insure that two dependency processors dont fire off and insert the
+ # have to ensure that two dependency processors dont fire off and insert the
# many to many row twice.
a = Admin()
a.roles.append(adminrole)
@@ -463,7 +308,7 @@ class InheritTest7(testbase.ORMTest):
}
)
- admin_mapper = mapper(Admin, admins, inherits=user_mapper, properties={'aid':admins.c.id})
+ admin_mapper = mapper(Admin, admins, inherits=user_mapper)
# create roles
adminrole = Role('admin')
@@ -482,14 +327,14 @@ class InheritTest7(testbase.ORMTest):
sess.flush()
assert user_roles.count().scalar() == 1
-class InheritTest8(testbase.ORMTest):
+class DistinctPKTest(ORMTest):
"""test the construction of mapper.primary_key when an inheriting relationship
joins on a column other than primary key column."""
keep_data = True
-
+
def define_tables(self, metadata):
global person_table, employee_table, Person, Employee
-
+
person_table = Table("persons", metadata,
Column("id", Integer, primary_key=True),
Column("name", String(80)),
@@ -509,7 +354,7 @@ class InheritTest8(testbase.ORMTest):
import warnings
warnings.filterwarnings("error", r".*On mapper.*distinct primary key")
-
+
def insert_data(self):
person_insert = person_table.insert()
person_insert.execute(id=1, name='alice')
@@ -518,22 +363,17 @@ class InheritTest8(testbase.ORMTest):
employee_insert = employee_table.insert()
employee_insert.execute(id=2, salary=250, person_id=1) # alice
employee_insert.execute(id=3, salary=200, person_id=2) # bob
-
+
def test_implicit(self):
person_mapper = mapper(Person, person_table)
mapper(Employee, employee_table, inherits=person_mapper)
- try:
- print class_mapper(Employee).primary_key
- assert list(class_mapper(Employee).primary_key) == [person_table.c.id, employee_table.c.id]
- assert False
- except RuntimeWarning, e:
- assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. Use explicit properties to give each column its own mapped attribute name."
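+        # 0.4 collapses the implicit primary key down to the base table's
+        # column instead of warning about two distinct primary key columns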
+ assert list(class_mapper(Employee).primary_key) == [person_table.c.id]
def test_explicit_props(self):
person_mapper = mapper(Person, person_table)
mapper(Employee, employee_table, inherits=person_mapper, properties={'pid':person_table.c.id, 'eid':employee_table.c.id})
self._do_test(True)
-
+
def test_explicit_composite_pk(self):
person_mapper = mapper(Person, person_table)
mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id, employee_table.c.id])
@@ -547,17 +387,12 @@ class InheritTest8(testbase.ORMTest):
person_mapper = mapper(Person, person_table)
mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id])
self._do_test(False)
-
+
def _do_test(self, composite):
session = create_session()
query = session.query(Employee)
if composite:
- try:
- query.get(1)
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Could not find enough values to formulate primary key for query.get(); primary key columns are 'persons.id', 'employees.id'"
alice1 = query.get([1,2])
bob = query.get([2,3])
alice2 = query.get([1,2])
@@ -565,11 +400,10 @@ class InheritTest8(testbase.ORMTest):
alice1 = query.get(1)
bob = query.get(2)
alice2 = query.get(1)
-
+
assert alice1.name == alice2.name == 'alice'
assert bob.name == 'bob'
-
-
+
if __name__ == "__main__":
testbase.main()
diff --git a/test/orm/inheritance4.py b/test/orm/inheritance/concrete.py
index 9f4e275ae..d95a96da5 100644
--- a/test/orm/inheritance4.py
+++ b/test/orm/inheritance/concrete.py
@@ -1,7 +1,9 @@
-from sqlalchemy import *
import testbase
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
-class ConcreteTest1(testbase.ORMTest):
+class ConcreteTest1(ORMTest):
def define_tables(self, metadata):
global managers_table, engineers_table
managers_table = Table('managers', metadata,
@@ -52,6 +54,7 @@ class ConcreteTest1(testbase.ORMTest):
session.flush()
session.clear()
+ print set([repr(x) for x in session.query(Employee).select()])
assert set([repr(x) for x in session.query(Employee).select()]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"])
assert set([repr(x) for x in session.query(Manager).select()]) == set(["Manager Tom knows how to manage things"])
assert set([repr(x) for x in session.query(Engineer).select()]) == set(["Engineer Kurt knows how to hack"])
@@ -63,4 +66,4 @@ class ConcreteTest1(testbase.ORMTest):
if __name__ == '__main__':
- testbase.main() \ No newline at end of file
+ testbase.main()
diff --git a/test/orm/inheritance3.py b/test/orm/inheritance/magazine.py
index a9c88ef60..a0bf24148 100644
--- a/test/orm/inheritance3.py
+++ b/test/orm/inheritance/magazine.py
@@ -1,5 +1,8 @@
import testbase
from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+
class BaseObject(object):
def __init__(self, *args, **kwargs):
@@ -20,7 +23,7 @@ class Location(BaseObject):
def _set_name(self, name):
session = create_session()
- s = session.query(LocationName).selectfirst(location_name_table.c.name==name)
+ s = session.query(LocationName).filter(LocationName.name==name).first()
session.clear()
if s is not None:
self._name = s
@@ -64,8 +67,8 @@ class MagazinePage(Page):
class ClassifiedPage(MagazinePage):
pass
-class InheritTest(testbase.ORMTest):
- """tests a large polymorphic relationship"""
+
+class MagazineTest(ORMTest):
def define_tables(self, metadata):
global publication_table, issue_table, location_table, location_name_table, magazine_table, \
page_table, magazine_page_table, classified_page_table, page_size_table
@@ -116,6 +119,8 @@ class InheritTest(testbase.ORMTest):
Column('name', String(45), default=''),
)
+def generate_round_trip_test(use_unions=False, use_joins=False):
+ def test_roundtrip(self):
publication_mapper = mapper(Publication, publication_table)
issue_mapper = mapper(Issue, issue_table, properties = {
@@ -133,33 +138,50 @@ class InheritTest(testbase.ORMTest):
page_size_mapper = mapper(PageSize, page_size_table)
- page_join = polymorphic_union(
- {
- 'm': page_table.join(magazine_page_table),
- 'c': page_table.join(magazine_page_table).join(classified_page_table),
- 'p': page_table.select(page_table.c.type=='p'),
- }, None, 'page_join')
-
- magazine_join = polymorphic_union(
- {
- 'm': page_table.join(magazine_page_table),
- 'c': page_table.join(magazine_page_table).join(classified_page_table),
- }, None, 'page_join')
-
magazine_mapper = mapper(Magazine, magazine_table, properties = {
'location': relation(Location, backref=backref('magazine', uselist=False)),
'size': relation(PageSize),
})
- page_mapper = mapper(Page, page_table, select_table=page_join, polymorphic_on=page_join.c.type, polymorphic_identity='p')
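+        # exercise three configurations for loading Page polymorphically:
+        # a polymorphic_union of per-type joins, a chain of outer joins,
+        # or no select_table at all (plain per-table loads)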
+ if use_unions:
+ page_join = polymorphic_union(
+ {
+ 'm': page_table.join(magazine_page_table),
+ 'c': page_table.join(magazine_page_table).join(classified_page_table),
+ 'p': page_table.select(page_table.c.type=='p'),
+ }, None, 'page_join')
+ page_mapper = mapper(Page, page_table, select_table=page_join, polymorphic_on=page_join.c.type, polymorphic_identity='p')
+ elif use_joins:
+ page_join = page_table.outerjoin(magazine_page_table).outerjoin(classified_page_table)
+ page_mapper = mapper(Page, page_table, select_table=page_join, polymorphic_on=page_table.c.type, polymorphic_identity='p')
+ else:
+ page_mapper = mapper(Page, page_table, polymorphic_on=page_table.c.type, polymorphic_identity='p')
+
+ if use_unions:
+ magazine_join = polymorphic_union(
+ {
+ 'm': page_table.join(magazine_page_table),
+ 'c': page_table.join(magazine_page_table).join(classified_page_table),
+ }, None, 'page_join')
+ magazine_page_mapper = mapper(MagazinePage, magazine_page_table, select_table=magazine_join, inherits=page_mapper, polymorphic_identity='m', properties={
+ 'magazine': relation(Magazine, backref=backref('pages', order_by=magazine_join.c.page_no))
+ })
+ elif use_joins:
+ magazine_join = page_table.join(magazine_page_table).outerjoin(classified_page_table)
+ magazine_page_mapper = mapper(MagazinePage, magazine_page_table, select_table=magazine_join, inherits=page_mapper, polymorphic_identity='m', properties={
+ 'magazine': relation(Magazine, backref=backref('pages', order_by=page_table.c.page_no))
+ })
+ else:
+ magazine_page_mapper = mapper(MagazinePage, magazine_page_table, inherits=page_mapper, polymorphic_identity='m', properties={
+ 'magazine': relation(Magazine, backref=backref('pages', order_by=page_table.c.page_no))
+ })
+
+ classified_page_mapper = mapper(ClassifiedPage, classified_page_table, inherits=magazine_page_mapper, polymorphic_identity='c', primary_key=[page_table.c.id])
+ #compile_mappers()
+ #print [str(s) for s in classified_page_mapper.primary_key]
+ #print classified_page_mapper.columntoproperty[page_table.c.id]
- magazine_page_mapper = mapper(MagazinePage, magazine_page_table, select_table=magazine_join, inherits=page_mapper, polymorphic_identity='m', properties={
- 'magazine': relation(Magazine, backref=backref('pages', order_by=magazine_join.c.page_no))
- })
-
- classified_page_mapper = mapper(ClassifiedPage, classified_page_table, inherits=magazine_page_mapper, polymorphic_identity='c')
- def testone(self):
session = create_session()
pub = Publication(name='Test')
@@ -174,18 +196,25 @@ class InheritTest(testbase.ORMTest):
page2 = MagazinePage(magazine=magazine,page_no=2)
page3 = ClassifiedPage(magazine=magazine,page_no=3)
session.save(pub)
-
+
session.flush()
print [x for x in session]
session.clear()
session.flush()
session.clear()
- p = session.query(Publication).selectone_by(name='Test')
+ p = session.query(Publication).filter(Publication.name=="Test").one()
print p.issues[0].locations[0].magazine.pages
print [page, page2, page3]
- assert repr(p.issues[0].locations[0].magazine.pages) == repr([page, page2, page3])
+ assert repr(p.issues[0].locations[0].magazine.pages) == repr([page, page2, page3]), repr(p.issues[0].locations[0].magazine.pages)
+
+    test_roundtrip.__name__ = "test_%s" % (not use_unions and (use_joins and "joins" or "select") or "unions")
+ setattr(MagazineTest, test_roundtrip.__name__, test_roundtrip)
+
+for (use_union, use_join) in [(True, False), (False, True), (False, False)]:
+ generate_round_trip_test(use_union, use_join)
+
if __name__ == '__main__':
testbase.main()
diff --git a/test/orm/inheritance/manytomany.py b/test/orm/inheritance/manytomany.py
new file mode 100644
index 000000000..df00f39d0
--- /dev/null
+++ b/test/orm/inheritance/manytomany.py
@@ -0,0 +1,255 @@
+import testbase
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+
+
+class InheritTest(ORMTest):
+ """deals with inheritance and many-to-many relationships"""
+ def define_tables(self, metadata):
+ global principals
+ global users
+ global groups
+ global user_group_map
+
+ principals = Table(
+ 'principals',
+ metadata,
+ Column('principal_id', Integer, Sequence('principal_id_seq', optional=False), primary_key=True),
+ Column('name', String(50), nullable=False),
+ )
+
+ users = Table(
+ 'prin_users',
+ metadata,
+ Column('principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True),
+ Column('password', String(50), nullable=False),
+ Column('email', String(50), nullable=False),
+ Column('login_id', String(50), nullable=False),
+
+ )
+
+ groups = Table(
+ 'prin_groups',
+ metadata,
+ Column( 'principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True),
+
+ )
+
+ user_group_map = Table(
+ 'prin_user_group_map',
+ metadata,
+ Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), primary_key=True ),
+ Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"), primary_key=True ),
+ )
+
+ def testbasic(self):
+ class Principal(object):
+ def __init__(self, **kwargs):
+ for key, value in kwargs.iteritems():
+ setattr(self, key, value)
+
+ class User(Principal):
+ pass
+
+ class Group(Principal):
+ pass
+
+ mapper(Principal, principals)
+ mapper(
+ User,
+ users,
+ inherits=Principal
+ )
+
+ mapper(
+ Group,
+ groups,
+ inherits=Principal,
+ properties=dict( users = relation(User, secondary=user_group_map, lazy=True, backref="groups") )
+ )
+
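+        # User and Group each join Principal on primary key; the association
+        # table relates the two subclasses' principal_id columns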
+ g = Group(name="group1")
+ g.users.append(User(name="user1", password="pw", email="foo@bar.com", login_id="lg1"))
+ sess = create_session()
+ sess.save(g)
+ sess.flush()
+        # TODO: add an assertion here
+
+class InheritTest2(ORMTest):
+ """deals with inheritance and many-to-many relationships"""
+ def define_tables(self, metadata):
+ global foo, bar, foo_bar
+ foo = Table('foo', metadata,
+ Column('id', Integer, Sequence('foo_id_seq'), primary_key=True),
+ Column('data', String(20)),
+ )
+
+ bar = Table('bar', metadata,
+ Column('bid', Integer, ForeignKey('foo.id'), primary_key=True),
+ #Column('fid', Integer, ForeignKey('foo.id'), )
+ )
+
+ foo_bar = Table('foo_bar', metadata,
+ Column('foo_id', Integer, ForeignKey('foo.id')),
+ Column('bar_id', Integer, ForeignKey('bar.bid')))
+
+ def testget(self):
+        class Foo(object):
+ def __init__(self, data=None):
+ self.data = data
+ class Bar(Foo):pass
+
+ mapper(Foo, foo)
+ mapper(Bar, bar, inherits=Foo)
+ print foo.join(bar).primary_key
+ print class_mapper(Bar).primary_key
+ b = Bar('somedata')
+ sess = create_session()
+ sess.save(b)
+ sess.flush()
+ sess.clear()
+
+ # test that "bar.bid" does not need to be referenced in a get
+ # (ticket 185)
+ assert sess.query(Bar).get(b.id).id == b.id
+
+ def testbasic(self):
+ class Foo(object):
+ def __init__(self, data=None):
+ self.data = data
+
+ mapper(Foo, foo)
+ class Bar(Foo):
+ pass
+
+ mapper(Bar, bar, inherits=Foo, properties={
+ 'foos': relation(Foo, secondary=foo_bar, lazy=False)
+ })
+
+ sess = create_session()
+ b = Bar('barfoo')
+ sess.save(b)
+ sess.flush()
+
+ f1 = Foo('subfoo1')
+ f2 = Foo('subfoo2')
+ b.foos.append(f1)
+ b.foos.append(f2)
+
+ sess.flush()
+ sess.clear()
+
+ l = sess.query(Bar).select()
+ print l[0]
+ print l[0].foos
+ self.assert_result(l, Bar,
+# {'id':1, 'data':'barfoo', 'bid':1, 'foos':(Foo, [{'id':2,'data':'subfoo1'}, {'id':3,'data':'subfoo2'}])},
+ {'id':b.id, 'data':'barfoo', 'foos':(Foo, [{'id':f1.id,'data':'subfoo1'}, {'id':f2.id,'data':'subfoo2'}])},
+ )
+
+class InheritTest3(ORMTest):
+ """deals with inheritance and many-to-many relationships"""
+ def define_tables(self, metadata):
+ global foo, bar, blub, bar_foo, blub_bar, blub_foo
+
+        # the 'data' columns are to appease SQLite which can't handle a blank INSERT
+ foo = Table('foo', metadata,
+ Column('id', Integer, Sequence('foo_seq'), primary_key=True),
+ Column('data', String(20)))
+
+ bar = Table('bar', metadata,
+ Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
+ Column('data', String(20)))
+
+ blub = Table('blub', metadata,
+ Column('id', Integer, ForeignKey('bar.id'), primary_key=True),
+ Column('data', String(20)))
+
+ bar_foo = Table('bar_foo', metadata,
+ Column('bar_id', Integer, ForeignKey('bar.id')),
+ Column('foo_id', Integer, ForeignKey('foo.id')))
+
+ blub_bar = Table('bar_blub', metadata,
+ Column('blub_id', Integer, ForeignKey('blub.id')),
+ Column('bar_id', Integer, ForeignKey('bar.id')))
+
+ blub_foo = Table('blub_foo', metadata,
+ Column('blub_id', Integer, ForeignKey('blub.id')),
+ Column('foo_id', Integer, ForeignKey('foo.id')))
+
+ def testbasic(self):
+ class Foo(object):
+ def __init__(self, data=None):
+ self.data = data
+ def __repr__(self):
+ return "Foo id %d, data %s" % (self.id, self.data)
+ mapper(Foo, foo)
+
+ class Bar(Foo):
+ def __repr__(self):
+ return "Bar id %d, data %s" % (self.id, self.data)
+
+ mapper(Bar, bar, inherits=Foo, properties={
+ 'foos' :relation(Foo, secondary=bar_foo, lazy=True)
+ })
+
+ sess = create_session()
+ b = Bar('bar #1')
+ sess.save(b)
+ b.foos.append(Foo("foo #1"))
+ b.foos.append(Foo("foo #2"))
+ sess.flush()
+ compare = repr(b) + repr(b.foos)
+ sess.clear()
+ l = sess.query(Bar).select()
+ print repr(l[0]) + repr(l[0].foos)
+ self.assert_(repr(l[0]) + repr(l[0].foos) == compare)
+
+ def testadvanced(self):
+ class Foo(object):
+ def __init__(self, data=None):
+ self.data = data
+ def __repr__(self):
+ return "Foo id %d, data %s" % (self.id, self.data)
+ mapper(Foo, foo)
+
+ class Bar(Foo):
+ def __repr__(self):
+ return "Bar id %d, data %s" % (self.id, self.data)
+ mapper(Bar, bar, inherits=Foo)
+
+ class Blub(Bar):
+ def __repr__(self):
+ return "Blub id %d, data %s, bars %s, foos %s" % (self.id, self.data, repr([b for b in self.bars]), repr([f for f in self.foos]))
+
+ mapper(Blub, blub, inherits=Bar, properties={
+ 'bars':relation(Bar, secondary=blub_bar, lazy=False),
+ 'foos':relation(Foo, secondary=blub_foo, lazy=False),
+ })
+
+ sess = create_session()
+ f1 = Foo("foo #1")
+ b1 = Bar("bar #1")
+ b2 = Bar("bar #2")
+ bl1 = Blub("blub #1")
+ for o in (f1, b1, b2, bl1):
+ sess.save(o)
+ bl1.foos.append(f1)
+ bl1.bars.append(b2)
+ sess.flush()
+ compare = repr(bl1)
+ blubid = bl1.id
+ sess.clear()
+
+ l = sess.query(Blub).select()
+ print l
+ self.assert_(repr(l[0]) == compare)
+ sess.clear()
+ x = sess.query(Blub).get_by(id=blubid)
+ print x
+ self.assert_(repr(x) == compare)
+
+
+if __name__ == "__main__":
+ testbase.main()
diff --git a/test/orm/poly_linked_list.py b/test/orm/inheritance/poly_linked_list.py
index 30cda4bb6..7297002f5 100644
--- a/test/orm/poly_linked_list.py
+++ b/test/orm/inheritance/poly_linked_list.py
@@ -1,7 +1,10 @@
import testbase
from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
-class PolymorphicCircularTest(testbase.ORMTest):
+
+class PolymorphicCircularTest(ORMTest):
keep_mappers = True
def define_tables(self, metadata):
global Table1, Table1B, Table2, Table3, Data
@@ -26,14 +29,15 @@ class PolymorphicCircularTest(testbase.ORMTest):
Column('data', String(30))
)
- join = polymorphic_union(
- {
- 'table3' : table1.join(table3),
- 'table2' : table1.join(table2),
- 'table1' : table1.select(table1.c.type.in_('table1', 'table1b')),
- }, None, 'pjoin')
-
- # still with us so far ?
+ #join = polymorphic_union(
+ # {
+ # 'table3' : table1.join(table3),
+ # 'table2' : table1.join(table2),
+ # 'table1' : table1.select(table1.c.type.in_('table1', 'table1b')),
+ # }, None, 'pjoin')
+
+ join = table1.outerjoin(table2).outerjoin(table3).alias('pjoin')
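+    # a plain outer-join chain now serves as the polymorphic selectable,
+    # replacing the polymorphic_union commented out above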
+ #join = None
class Table1(object):
def __init__(self, name, data=None):
@@ -59,10 +63,10 @@ class PolymorphicCircularTest(testbase.ORMTest):
return "%s(%d, %s)" % (self.__class__.__name__, self.id, repr(str(self.data)))
try:
- # this is how the mapping used to work. insure that this raises an error now
+ # this is how the mapping used to work. ensure that this raises an error now
table1_mapper = mapper(Table1, table1,
select_table=join,
- polymorphic_on=join.c.type,
+ polymorphic_on=table1.c.type,
polymorphic_identity='table1',
properties={
'next': relation(Table1,
@@ -83,8 +87,8 @@ class PolymorphicCircularTest(testbase.ORMTest):
        # exception now. since eager loading would never work for that relation anyway, it's better that the user
# gets an exception instead of it silently not eager loading.
table1_mapper = mapper(Table1, table1,
- select_table=join,
- polymorphic_on=join.c.type,
+ #select_table=join,
+ polymorphic_on=table1.c.type,
polymorphic_identity='table1',
properties={
'next': relation(Table1,
@@ -101,7 +105,10 @@ class PolymorphicCircularTest(testbase.ORMTest):
polymorphic_identity='table2')
table3_mapper = mapper(Table3, table3, inherits=table1_mapper, polymorphic_identity='table3')
-
+
+ table1_mapper.compile()
+ assert table1_mapper.primary_key == [table1.c.id], table1_mapper.primary_key
+
def testone(self):
self.do_testlist([Table1, Table2, Table1, Table2])
diff --git a/test/orm/polymorph.py b/test/orm/inheritance/polymorph.py
index 9d886cf3f..3eb2e032f 100644
--- a/test/orm/polymorph.py
+++ b/test/orm/inheritance/polymorph.py
@@ -1,8 +1,11 @@
+"""tests basic polymorphic mapper loading/saving, minimal relations"""
+
import testbase
-from sqlalchemy import *
import sets
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
-# tests basic polymorphic mapper loading/saving, minimal relations
class Person(object):
def __init__(self, **kwargs):
@@ -21,6 +24,10 @@ class Engineer(Person):
class Manager(Person):
def __repr__(self):
return "Manager %s, status %s, manager_name %s" % (self.get_name(), self.status, self.manager_name)
+class Boss(Manager):
+ def __repr__(self):
+ return "Boss %s, status %s, manager_name %s golf swing %s" % (self.get_name(), self.status, self.manager_name, self.golf_swing)
+
class Company(object):
def __init__(self, **kwargs):
for key, value in kwargs.iteritems():
@@ -28,9 +35,9 @@ class Company(object):
def __repr__(self):
return "Company %s" % self.name
-class PolymorphTest(testbase.ORMTest):
+class PolymorphTest(ORMTest):
def define_tables(self, metadata):
- global companies, people, engineers, managers
+ global companies, people, engineers, managers, boss
# a table to store companies
companies = Table('companies', metadata,
@@ -58,6 +65,11 @@ class PolymorphTest(testbase.ORMTest):
Column('manager_name', String(50))
)
+ boss = Table('boss', metadata,
+ Column('boss_id', Integer, ForeignKey('managers.person_id'), primary_key=True),
+ Column('golf_swing', String(30)),
+ )
+
metadata.create_all()
class CompileTest(PolymorphTest):
@@ -100,29 +112,6 @@ class CompileTest(PolymorphTest):
#person_mapper.compile()
class_mapper(Manager).compile()
- def testcompile3(self):
- """test that a mapper referencing an inheriting mapper in a self-referential relationship does
- not allow an eager load to be set up."""
- person_join = polymorphic_union( {
- 'engineer':people.join(engineers),
- 'manager':people.join(managers),
- 'person':people.select(people.c.type=='person'),
- }, None, 'pjoin')
-
- person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=person_join.c.type,
- polymorphic_identity='person',
- properties = dict(managers = relation(Manager, lazy=False))
- )
-
- mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer')
- mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager')
-
- try:
- class_mapper(Manager).compile()
- assert False
- except exceptions.ArgumentError:
- assert True
-
class InsertOrderTest(PolymorphTest):
def test_insert_order(self):
"""test that classes of multiple types mix up mapper inserts
@@ -191,8 +180,11 @@ class RelationToSubclassTest(PolymorphTest):
sess.query(Company).get_by(company_id=c.company_id)
assert sets.Set([e.get_name() for e in c.managers]) == sets.Set(['pointy haired boss'])
assert c.managers[0].company is c
-
-def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_colprop=False, use_literal_join=False):
+
+class RoundTripTest(PolymorphTest):
+ pass
+
+def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_colprop=False, use_literal_join=False, polymorphic_fetch=None, use_outer_joins=False):
"""generates a round trip test.
include_base - whether or not to include the base 'person' type in the union.
@@ -200,117 +192,145 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
redefine_colprop - if we redefine the 'name' column to be 'people_name' on the base Person class
use_literal_join - primary join condition is explicitly specified
"""
- class RoundTripTest(PolymorphTest):
- def test_roundtrip(self):
- # create a union that represents both types of joins.
- if include_base:
+ def test_roundtrip(self):
+ # create a union that represents both types of joins.
+        if polymorphic_fetch != 'union':
+ person_join = None
+ manager_join = None
+ elif include_base:
+ if use_outer_joins:
+ person_join = people.outerjoin(engineers).outerjoin(managers).outerjoin(boss)
+ manager_join = people.join(managers).outerjoin(boss)
+ else:
person_join = polymorphic_union(
{
'engineer':people.join(engineers),
'manager':people.join(managers),
'person':people.select(people.c.type=='person'),
}, None, 'pjoin')
+
+ manager_join = people.join(managers).outerjoin(boss)
+ else:
+ if use_outer_joins:
+ person_join = people.outerjoin(engineers).outerjoin(managers).outerjoin(boss)
+ manager_join = people.join(managers).outerjoin(boss)
else:
person_join = polymorphic_union(
{
'engineer':people.join(engineers),
'manager':people.join(managers),
}, None, 'pjoin')
+ manager_join = people.join(managers).outerjoin(boss)
- if redefine_colprop:
- person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=person_join.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name})
- else:
- person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=person_join.c.type, polymorphic_identity='person')
-
- mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer')
- mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager')
-
- if use_literal_join:
- mapper(Company, companies, properties={
- 'employees': relation(Person, lazy=lazy_relation, primaryjoin=people.c.company_id==companies.c.company_id, private=True,
- backref="company"
- )
- })
- else:
- mapper(Company, companies, properties={
- 'employees': relation(Person, lazy=lazy_relation, private=True,
- backref="company"
- )
- })
-
- if redefine_colprop:
- person_attribute_name = 'person_name'
- else:
- person_attribute_name = 'name'
+ if redefine_colprop:
+ person_mapper = mapper(Person, people, select_table=person_join, polymorphic_fetch=polymorphic_fetch, polymorphic_on=people.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name})
+ else:
+ person_mapper = mapper(Person, people, select_table=person_join, polymorphic_fetch=polymorphic_fetch, polymorphic_on=people.c.type, polymorphic_identity='person')
- session = create_session()
- c = Company(name='company1')
- c.employees.append(Manager(status='AAB', manager_name='manager1', **{person_attribute_name:'pointy haired boss'}))
- c.employees.append(Engineer(status='BBA', engineer_name='engineer1', primary_language='java', **{person_attribute_name:'dilbert'}))
- if include_base:
- c.employees.append(Person(status='HHH', **{person_attribute_name:'joesmith'}))
- c.employees.append(Engineer(status='CGG', engineer_name='engineer2', primary_language='python', **{person_attribute_name:'wally'}))
- c.employees.append(Manager(status='ABA', manager_name='manager2', **{person_attribute_name:'jsmith'}))
- session.save(c)
- print session.new
- session.flush()
- session.clear()
- id = c.company_id
- c = session.query(Company).get(id)
- for e in c.employees:
- print e, e._instance_key, e.company
- if include_base:
- assert sets.Set([e.get_name() for e in c.employees]) == sets.Set(['pointy haired boss', 'dilbert', 'joesmith', 'wally', 'jsmith'])
- else:
- assert sets.Set([e.get_name() for e in c.employees]) == sets.Set(['pointy haired boss', 'dilbert', 'wally', 'jsmith'])
- print "\n"
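+        # managers select from people JOIN managers OUTER JOIN boss so the
+        # Boss columns are present in the same polymorphic select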
+ mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer')
+ mapper(Manager, managers, inherits=person_mapper, select_table=manager_join, polymorphic_identity='manager')
+ mapper(Boss, boss, inherits=Manager, polymorphic_identity='boss')
- # test selecting from the query, using the base mapped table (people) as the selection criterion.
- # in the case of the polymorphic Person query, the "people" selectable should be adapted to be "person_join"
- dilbert = session.query(Person).selectfirst(people.c.name=='dilbert')
- dilbert2 = session.query(Engineer).selectfirst(people.c.name=='dilbert')
- assert dilbert is dilbert2
-
- # test selecting from the query, joining against an alias of the base "people" table. test that
- # the "palias" alias does *not* get sucked up into the "person_join" conversion.
- palias = people.alias("palias")
- session.query(Person).selectfirst((palias.c.name=='dilbert') & (palias.c.person_id==people.c.person_id))
- dilbert2 = session.query(Engineer).selectfirst((palias.c.name=='dilbert') & (palias.c.person_id==people.c.person_id))
- assert dilbert is dilbert2
-
- session.query(Person).selectfirst((engineers.c.engineer_name=="engineer1") & (engineers.c.person_id==people.c.person_id))
- dilbert2 = session.query(Engineer).selectfirst(engineers.c.engineer_name=="engineer1")
- assert dilbert is dilbert2
+ if use_literal_join:
+ mapper(Company, companies, properties={
+ 'employees': relation(Person, lazy=lazy_relation, primaryjoin=people.c.company_id==companies.c.company_id, private=True,
+ backref="company"
+ )
+ })
+ else:
+ mapper(Company, companies, properties={
+ 'employees': relation(Person, lazy=lazy_relation, private=True,
+ backref="company"
+ )
+ })
+ if redefine_colprop:
+ person_attribute_name = 'person_name'
+ else:
+ person_attribute_name = 'name'
+
+ session = create_session()
+ c = Company(name='company1')
+ c.employees.append(Manager(status='AAB', manager_name='manager1', **{person_attribute_name:'pointy haired boss'}))
+ c.employees.append(Engineer(status='BBA', engineer_name='engineer1', primary_language='java', **{person_attribute_name:'dilbert'}))
+ if include_base:
+ c.employees.append(Person(status='HHH', **{person_attribute_name:'joesmith'}))
+ c.employees.append(Engineer(status='CGG', engineer_name='engineer2', primary_language='python', **{person_attribute_name:'wally'}))
+ c.employees.append(Manager(status='ABA', manager_name='manager2', **{person_attribute_name:'jsmith'}))
+ session.save(c)
+ print session.new
+ session.flush()
+ session.clear()
+ id = c.company_id
+ c = session.query(Company).get(id)
+ for e in c.employees:
+ print e, e._instance_key, e.company
+ if include_base:
+ assert sets.Set([(e.get_name(), getattr(e, 'status', None)) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('joesmith', None), ('wally', 'CGG'), ('jsmith', 'ABA')])
+ else:
+ assert sets.Set([(e.get_name(), e.status) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('wally', 'CGG'), ('jsmith', 'ABA')])
+ print "\n"
+
+ # test selecting from the query, using the base mapped table (people) as the selection criterion.
+ # in the case of the polymorphic Person query, the "people" selectable should be adapted to be "person_join"
+ dilbert = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
+ dilbert2 = session.query(Engineer).filter(getattr(Person, person_attribute_name)=='dilbert').first()
+ assert dilbert is dilbert2
+
+ # test selecting from the query, joining against an alias of the base "people" table. test that
+ # the "palias" alias does *not* get sucked up into the "person_join" conversion.
+ palias = people.alias("palias")
+ session.query(Person).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first()
+ dilbert2 = session.query(Engineer).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first()
+ assert dilbert is dilbert2
+
+ session.query(Person).filter((Engineer.engineer_name=="engineer1") & (Engineer.person_id==people.c.person_id)).first()
- dilbert.engineer_name = 'hes dibert!'
-
- session.flush()
- session.clear()
+ dilbert2 = session.query(Engineer).filter(Engineer.engineer_name=="engineer1")[0]
+ assert dilbert is dilbert2
+
+ dilbert.engineer_name = 'hes dibert!'
- c = session.query(Company).get(id)
- for e in c.employees:
- print e, e._instance_key
+ session.flush()
+ session.clear()
- session.delete(c)
- session.flush()
+ # save/load some managers/bosses
+ b = Boss(status='BBB', manager_name='boss', golf_swing='fore', **{person_attribute_name:'daboss'})
+ session.save(b)
+ session.flush()
+ session.clear()
+ c = session.query(Manager).all()
+ assert sets.Set([repr(x) for x in c]) == sets.Set(["Manager pointy haired boss, status AAB, manager_name manager1", "Manager jsmith, status ABA, manager_name manager2", "Boss daboss, status BBB, manager_name boss golf swing fore"]), repr([repr(x) for x in c])
+
+ c = session.query(Company).get(id)
+ for e in c.employees:
+ print e, e._instance_key
- RoundTripTest.__name__ = "Test%s%s%s%s" % (
- (lazy_relation and "Lazy" or "Eager"),
- (include_base and "Inclbase" or ""),
- (redefine_colprop and "Redefcol" or ""),
- (use_literal_join and "Litjoin" or "")
+ session.delete(c)
+ session.flush()
+
+
+ test_roundtrip.__name__ = "test_%s%s%s%s%s" % (
+ (lazy_relation and "lazy" or "eager"),
+ (include_base and "_inclbase" or ""),
+ (redefine_colprop and "_redefcol" or ""),
+ (polymorphic_fetch != 'union' and '_' + polymorphic_fetch or (use_literal_join and "_litjoin" or "")),
+ (use_outer_joins and '_outerjoins' or '')
)
- return RoundTripTest
+ setattr(RoundTripTest, test_roundtrip.__name__, test_roundtrip)
for include_base in [True, False]:
for lazy_relation in [True, False]:
for redefine_colprop in [True, False]:
for use_literal_join in [True, False]:
- testclass = generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join)
- exec("%s = testclass" % testclass.__name__)
-
+ for polymorphic_fetch in ['union', 'select', 'deferred']:
+ if polymorphic_fetch == 'union':
+ for use_outer_joins in [True, False]:
+ generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join, polymorphic_fetch, use_outer_joins)
+ else:
+ generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join, polymorphic_fetch, False)
+
if __name__ == "__main__":
testbase.main()
diff --git a/test/orm/inheritance5.py b/test/orm/inheritance/polymorph2.py
index cf7224fa4..a2f9c4a5f 100644
--- a/test/orm/inheritance5.py
+++ b/test/orm/inheritance/polymorph2.py
@@ -1,5 +1,8 @@
-from sqlalchemy import *
import testbase
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+
class AttrSettable(object):
def __init__(self, **kwargs):
@@ -8,7 +11,7 @@ class AttrSettable(object):
return self.__class__.__name__ + "(%s)" % (hex(id(self)))
-class RelationTest1(testbase.ORMTest):
+class RelationTest1(ORMTest):
"""test self-referential relationships on polymorphic mappers"""
def define_tables(self, metadata):
global people, managers
@@ -88,7 +91,7 @@ class RelationTest1(testbase.ORMTest):
print p, m, m.employee
assert m.employee is p
-class RelationTest2(testbase.ORMTest):
+class RelationTest2(ORMTest):
"""test self-referential relationships on polymorphic mappers"""
def define_tables(self, metadata):
global people, managers, data
@@ -116,6 +119,10 @@ class RelationTest2(testbase.ORMTest):
self.do_test("join1", True)
def testrelationonsubclass_j2_data(self):
self.do_test("join2", True)
+ def testrelationonsubclass_j3_nodata(self):
+ self.do_test("join3", False)
+ def testrelationonsubclass_j3_data(self):
+ self.do_test("join3", True)
def do_test(self, jointype="join1", usedata=False):
class Person(AttrSettable):
@@ -128,19 +135,24 @@ class RelationTest2(testbase.ORMTest):
'person':people.select(people.c.type=='person'),
'manager':join(people, managers, people.c.person_id==managers.c.person_id)
}, None)
+ polymorphic_on=poly_union.c.type
elif jointype == "join2":
poly_union = polymorphic_union({
'person':people.select(people.c.type=='person'),
'manager':managers.join(people, people.c.person_id==managers.c.person_id)
}, None)
-
+ polymorphic_on=poly_union.c.type
+ elif jointype == "join3":
+ poly_union = None
+ polymorphic_on = people.c.type
+
if usedata:
class Data(object):
def __init__(self, data):
self.data = data
mapper(Data, data)
- mapper(Person, people, select_table=poly_union, polymorphic_identity='person', polymorphic_on=poly_union.c.type)
+ mapper(Person, people, select_table=poly_union, polymorphic_identity='person', polymorphic_on=polymorphic_on)
if usedata:
mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, polymorphic_identity='manager',
@@ -174,7 +186,7 @@ class RelationTest2(testbase.ORMTest):
if usedata:
assert m.data.data == 'ms data'
-class RelationTest3(testbase.ORMTest):
+class RelationTest3(ORMTest):
"""test self-referential relationships on polymorphic mappers"""
def define_tables(self, metadata):
global people, managers, data
@@ -194,16 +206,8 @@ class RelationTest3(testbase.ORMTest):
Column('data', String(30))
)
- def testrelationonbaseclass_j1_nodata(self):
- self.do_test("join1", False)
- def testrelationonbaseclass_j2_nodata(self):
- self.do_test("join2", False)
- def testrelationonbaseclass_j1_data(self):
- self.do_test("join1", True)
- def testrelationonbaseclass_j2_data(self):
- self.do_test("join2", True)
-
- def do_test(self, jointype="join1", usedata=False):
+def generate_test(jointype="join1", usedata=False):
+ def do_test(self):
class Person(AttrSettable):
pass
class Manager(Person):
@@ -224,10 +228,14 @@ class RelationTest3(testbase.ORMTest):
'manager':join(people, managers, people.c.person_id==managers.c.person_id),
'person':people.select(people.c.type=='person')
}, None)
-
+ elif jointype == 'join3':
+ poly_union = people.outerjoin(managers)
+ elif jointype == "join4":
+ poly_union=None
+
if usedata:
mapper(Data, data)
-
+
mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, polymorphic_identity='manager')
if usedata:
mapper(Person, people, select_table=poly_union, polymorphic_identity='person', polymorphic_on=people.c.type,
@@ -258,7 +266,7 @@ class RelationTest3(testbase.ORMTest):
sess.save(m)
sess.save(p)
sess.flush()
-
+
sess.clear()
p = sess.query(Person).get(p.person_id)
p2 = sess.query(Person).get(p2.person_id)
@@ -271,9 +279,17 @@ class RelationTest3(testbase.ORMTest):
if usedata:
assert p.data.data == 'ps data'
assert m.data.data == 'ms data'
+
+    do_test.__name__ = 'test_relationonbaseclass_%s_%s' % (jointype, usedata and "data" or "nodata")
+ return do_test
+for jointype in ["join1", "join2", "join3", "join4"]:
+    for usedata in (True, False):
+        func = generate_test(jointype, usedata)
+ setattr(RelationTest3, func.__name__, func)
+
-class RelationTest4(testbase.ORMTest):
+class RelationTest4(ORMTest):
def define_tables(self, metadata):
global people, engineers, managers, cars
people = Table('people', metadata,
@@ -329,6 +345,9 @@ class RelationTest4(testbase.ORMTest):
manager_mapper = mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager')
car_mapper = mapper(Car, cars, properties= {'employee':relation(person_mapper)})
+ print class_mapper(Person).primary_key
+ print person_mapper.get_select_mapper().primary_key
+
# so the primaryjoin is "people.c.person_id==cars.c.owner". the "lazy" clause will be
# "people.c.person_id=?". the employee_join is two selects union'ed together, one of which
# will contain employee.c.person_id the other contains manager.c.person_id. people.c.person_id is not explicitly in
@@ -350,8 +369,8 @@ class RelationTest4(testbase.ORMTest):
session.flush()
- engineer4 = session.query(Engineer).selectfirst_by(name="E4")
- manager3 = session.query(Manager).selectfirst_by(name="M3")
+ engineer4 = session.query(Engineer).filter(Engineer.name=="E4").first()
+ manager3 = session.query(Manager).filter(Manager.name=="M3").first()
car1 = Car(employee=engineer4)
session.save(car1)
@@ -361,27 +380,32 @@ class RelationTest4(testbase.ORMTest):
session.clear()
+ print "----------------------------"
car1 = session.query(Car).get(car1.car_id)
+ print "----------------------------"
usingGet = session.query(person_mapper).get(car1.owner)
+ print "----------------------------"
usingProperty = car1.employee
+ print "----------------------------"
# All print should output the same person (engineer E4)
assert str(engineer4) == "Engineer E4, status X"
+ print str(usingGet)
assert str(usingGet) == "Engineer E4, status X"
assert str(usingProperty) == "Engineer E4, status X"
session.clear()
-
+ print "-----------------------------------------------------------------"
# and now for the lightning round, eager !
car1 = session.query(Car).options(eagerload('employee')).get(car1.car_id)
assert str(car1.employee) == "Engineer E4, status X"
session.clear()
s = session.query(Car)
- c = s.join("employee").select(employee_join.c.name=="E4")[0]
+ c = s.join("employee").filter(Person.name=="E4")[0]
assert c.car_id==car1.car_id
-class RelationTest5(testbase.ORMTest):
+class RelationTest5(ORMTest):
def define_tables(self, metadata):
global people, engineers, managers, cars
people = Table('people', metadata,
@@ -441,7 +465,7 @@ class RelationTest5(testbase.ORMTest):
assert carlist[0].manager is None
assert carlist[1].manager.person_id == car2.manager.person_id
-class RelationTest6(testbase.ORMTest):
+class RelationTest6(ORMTest):
"""test self-referential relationships on a single joined-table inheritance mapper"""
def define_tables(self, metadata):
global people, managers, data
@@ -484,7 +508,7 @@ class RelationTest6(testbase.ORMTest):
m2 = sess.query(Manager).get(m2.person_id)
assert m.colleague is m2
-class RelationTest7(testbase.ORMTest):
+class RelationTest7(ORMTest):
def define_tables(self, metadata):
global people, engineers, managers, cars, offroad_cars
cars = Table('cars', metadata,
@@ -583,7 +607,7 @@ class RelationTest7(testbase.ORMTest):
for p in r:
assert p.car_id == p.car.car_id
-class GenerativeTest(testbase.AssertMixin):
+class GenerativeTest(AssertMixin):
def setUpAll(self):
# cars---owned by--- people (abstract) --- has a --- status
# | ^ ^ |
@@ -698,7 +722,7 @@ class GenerativeTest(testbase.AssertMixin):
# test these twice because theres caching involved, as well previous issues that modified the polymorphic union
for x in range(0, 2):
- r = session.query(Person).filter_by(people.c.name.like('%2')).join('status').filter_by(name="active")
+ r = session.query(Person).filter(people.c.name.like('%2')).join('status').filter_by(name="active")
assert str(list(r)) == "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]"
r = session.query(Engineer).join('status').filter(people.c.name.in_('E2', 'E3', 'E4', 'M4', 'M2', 'M1') & (status.c.name=="active"))
assert str(list(r)) == "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]"
@@ -709,22 +733,22 @@ class GenerativeTest(testbase.AssertMixin):
r = session.query(Person).filter(exists([Car.c.owner], Car.c.owner==employee_join.c.person_id))
assert str(list(r)) == "[Engineer E4, field X, status Status dead]"
-class MultiLevelTest(testbase.ORMTest):
+class MultiLevelTest(ORMTest):
def define_tables(self, metadata):
global table_Employee, table_Engineer, table_Manager
table_Employee = Table( 'Employee', metadata,
- Column( 'name', type= String(100), ),
- Column( 'id', primary_key= True, type= Integer, ),
- Column( 'atype', type= String(100), ),
+ Column( 'name', type_= String(100), ),
+ Column( 'id', primary_key= True, type_= Integer, ),
+ Column( 'atype', type_= String(100), ),
)
table_Engineer = Table( 'Engineer', metadata,
- Column( 'machine', type= String(100), ),
+ Column( 'machine', type_= String(100), ),
Column( 'id', Integer, ForeignKey( 'Employee.id', ), primary_key= True, ),
)
table_Manager = Table( 'Manager', metadata,
- Column( 'duties', type= String(100), ),
+ Column( 'duties', type_= String(100), ),
Column( 'id', Integer, ForeignKey( 'Engineer.id', ), primary_key= True, ),
)
def test_threelevels(self):
@@ -786,7 +810,7 @@ class MultiLevelTest(testbase.ORMTest):
assert set(session.query( Engineer).select()) == set([b,c])
assert session.query( Manager).select() == [c]
-class ManyToManyPolyTest(testbase.ORMTest):
+class ManyToManyPolyTest(ORMTest):
def define_tables(self, metadata):
global base_item_table, item_table, base_item_collection_table, collection_table
base_item_table = Table(
@@ -836,7 +860,7 @@ class ManyToManyPolyTest(testbase.ORMTest):
class_mapper(BaseItem)
-class CustomPKTest(testbase.ORMTest):
+class CustomPKTest(ORMTest):
def define_tables(self, metadata):
global t1, t2
t1 = Table('t1', metadata,
@@ -847,7 +871,7 @@ class CustomPKTest(testbase.ORMTest):
t2 = Table('t2', metadata,
Column('t2id', Integer, ForeignKey('t1.id'), primary_key=True),
Column('t2data', String(30)))
-
+
def test_custompk(self):
"""test that the primary_key attribute is propigated to the polymorphic mapper"""
@@ -885,6 +909,48 @@ class CustomPKTest(testbase.ORMTest):
ot1 = sess.query(T1).get(ot1.id)
ot1.data = 'hi'
sess.flush()
+
+ def test_pk_collapses(self):
+ """test that a composite primary key attribute formed by a join is "collapsed" into its
+ minimal columns"""
+
+ class T1(object):pass
+ class T2(T1):pass
+
+ # create a polymorphic union with the select against the base table first.
+ # with the join being second, the alias of the union will
+ # pick up two "primary key" columns. technically the alias should have a
+ # 2-col pk in any case but the leading select has a NULL for the "t2id" column
+ d = util.OrderedDict()
+ d['t1'] = t1.select(t1.c.type=='t1')
+ d['t2'] = t1.join(t2)
+ pjoin = polymorphic_union(d, None, 'pjoin')
+
+ #print pjoin.original.primary_key
+ #print pjoin.primary_key
+ assert len(pjoin.primary_key) == 2
+
+ mapper(T1, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1', select_table=pjoin)
+ mapper(T2, t2, inherits=T1, polymorphic_identity='t2')
+ assert len(class_mapper(T1).primary_key) == 1
+ assert len(class_mapper(T1).get_select_mapper().compile().primary_key) == 1
+
+ print [str(c) for c in class_mapper(T1).primary_key]
+ ot1 = T1()
+ ot2 = T2()
+ sess = create_session()
+ sess.save(ot1)
+ sess.save(ot2)
+ sess.flush()
+ sess.clear()
+
+ # query using get(), using only one value. this requires the select_table mapper
+ # has the same single-col primary key.
+ assert sess.query(T1).get(ot1.id).id == ot1.id
+
+ ot1 = sess.query(T1).get(ot1.id)
+ ot1.data = 'hi'
+ sess.flush()
if __name__ == "__main__":
testbase.main()
diff --git a/test/orm/inheritance2.py b/test/orm/inheritance/productspec.py
index 906526456..2459cd36e 100644
--- a/test/orm/inheritance2.py
+++ b/test/orm/inheritance/productspec.py
@@ -1,8 +1,11 @@
import testbase
-from sqlalchemy import *
from datetime import datetime
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+
-class InheritTest(testbase.ORMTest):
+class InheritTest(ORMTest):
"""tests some various inheritance round trips involving a particular set of polymorphic inheritance relationships"""
def define_tables(self, metadata):
global products_table, specification_table, documents_table
diff --git a/test/orm/single.py b/test/orm/inheritance/single.py
index 31a90da21..68fe821af 100644
--- a/test/orm/single.py
+++ b/test/orm/inheritance/single.py
@@ -1,7 +1,10 @@
-from sqlalchemy import *
import testbase
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+
-class SingleInheritanceTest(testbase.AssertMixin):
+class SingleInheritanceTest(AssertMixin):
def setUpAll(self):
metadata = MetaData(testbase.db)
global employees_table
diff --git a/test/orm/lazy_relations.py b/test/orm/lazy_relations.py
index e9d77e09c..6684c6288 100644
--- a/test/orm/lazy_relations.py
+++ b/test/orm/lazy_relations.py
@@ -1,9 +1,9 @@
"""basic tests of lazy loaded attributes"""
+import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
-import testbase
-
+from testlib import *
from fixtures import *
from query import QueryTest
diff --git a/test/orm/lazytest1.py b/test/orm/lazytest1.py
index 2cabac3a2..b5296120b 100644
--- a/test/orm/lazytest1.py
+++ b/test/orm/lazytest1.py
@@ -1,8 +1,7 @@
-from testbase import PersistTest, AssertMixin
import testbase
-import unittest, sys, os
from sqlalchemy import *
-import datetime
+from sqlalchemy.orm import *
+from testlib import *
class LazyTest(AssertMixin):
def setUpAll(self):
diff --git a/test/orm/manytomany.py b/test/orm/manytomany.py
index f6e9197a2..8b310f86c 100644
--- a/test/orm/manytomany.py
+++ b/test/orm/manytomany.py
@@ -1,6 +1,8 @@
import testbase
from sqlalchemy import *
-import string
+from sqlalchemy.orm import *
+from testlib import *
+
class Place(object):
'''represents a place'''
@@ -25,7 +27,7 @@ class Transition(object):
def __repr__(self):
return object.__repr__(self)+ " " + repr(self.inputs) + " " + repr(self.outputs)
-class M2MTest(testbase.ORMTest):
+class M2MTest(ORMTest):
def define_tables(self, metadata):
global place
place = Table('place', metadata,
@@ -110,7 +112,7 @@ class M2MTest(testbase.ORMTest):
for p in l:
pp = p.places
- self.echo("Place " + str(p) +" places " + repr(pp))
+ print "Place " + str(p) +" places " + repr(pp)
[sess.delete(p) for p in p1,p2,p3,p4,p5,p6,p7]
sess.flush()
@@ -176,7 +178,7 @@ class M2MTest(testbase.ORMTest):
self.assert_result([t1], Transition, {'outputs': (Place, [{'name':'place3'}, {'name':'place1'}])})
self.assert_result([p2], Place, {'inputs': (Transition, [{'name':'transition1'},{'name':'transition2'}])})
-class M2MTest2(testbase.ORMTest):
+class M2MTest2(ORMTest):
def define_tables(self, metadata):
global studentTbl
studentTbl = Table('student', metadata, Column('name', String(20), primary_key=True))
@@ -243,7 +245,7 @@ class M2MTest2(testbase.ORMTest):
sess.flush()
assert enrolTbl.count().scalar() == 0
-class M2MTest3(testbase.ORMTest):
+class M2MTest3(ORMTest):
def define_tables(self, metadata):
global c, c2a1, c2a2, b, a
c = Table('c', metadata,
@@ -277,15 +279,15 @@ class M2MTest3(testbase.ORMTest):
class A(object):pass
class B(object):pass
- assign_mapper(B, b)
+ mapper(B, b)
- assign_mapper(A, a,
+ mapper(A, a,
properties = {
'tbs' : relation(B, primaryjoin=and_(b.c.a1==a.c.a1, b.c.b2 == True), lazy=False),
}
)
- assign_mapper(C, c,
+ mapper(C, c,
properties = {
'a1s' : relation(A, secondary=c2a1, lazy=False),
'a2s' : relation(A, secondary=c2a2, lazy=False)
diff --git a/test/orm/mapper.py b/test/orm/mapper.py
index c0297e514..b72a10516 100644
--- a/test/orm/mapper.py
+++ b/test/orm/mapper.py
@@ -1,13 +1,14 @@
-from testbase import PersistTest, AssertMixin
+"""tests general mapper operations with an emphasis on selecting/loading"""
+
import testbase
-import unittest, sys, os
from sqlalchemy import *
+from sqlalchemy.orm import *
import sqlalchemy.exceptions as exceptions
-from sqlalchemy.ext.sessioncontext import SessionContext
-from tables import *
-import tables
+from sqlalchemy.ext.sessioncontext import SessionContext, SessionContextExt
+from testlib import *
+from testlib.tables import *
+import testlib.tables as tables
-"""tests general mapper operations with an emphasis on selecting/loading"""
class MapperSuperTest(AssertMixin):
def setUpAll(self):
@@ -21,36 +22,6 @@ class MapperSuperTest(AssertMixin):
pass
class MapperTest(MapperSuperTest):
- # TODO: MapperTest has grown much larger than it originally was and needs
- # to be broken up among various functions, including querying, session operations,
- # mapper configurational issues
- def testget(self):
- s = create_session()
- mapper(User, users)
- self.assert_(s.get(User, 19) is None)
- u = s.get(User, 7)
- u2 = s.get(User, 7)
- self.assert_(u is u2)
- s.clear()
- u2 = s.get(User, 7)
- self.assert_(u is not u2)
-
- def testunicodeget(self):
- """test that Query.get properly sets up the type for the bind parameter. using unicode would normally fail
- on postgres, mysql and oracle unless it is converted to an encoded string"""
- metadata = MetaData(db)
- table = Table('foo', metadata,
- Column('id', Unicode(10), primary_key=True),
- Column('data', Unicode(40)))
- try:
- table.create()
- class LocalFoo(object):pass
- mapper(LocalFoo, table)
- crit = 'petit voix m\xe2\x80\x99a '.decode('utf-8')
- print repr(crit)
- create_session().query(LocalFoo).get(crit)
- finally:
- table.drop()
def testpropconflict(self):
"""test that a backref created against an existing mapper with a property name
@@ -76,25 +47,20 @@ class MapperTest(MapperSuperTest):
assert str(e) == "Invalid cascade option 'fake'"
def testcolumnprefix(self):
- mapper(User, users, column_prefix='_', properties={
- 'user_name':synonym('_user_name')
- })
+ mapper(User, users, column_prefix='_')
s = create_session()
u = s.get(User, 7)
assert u._user_name=='jack'
assert u._user_id ==7
assert not hasattr(u, 'user_name')
- u2 = s.query(User).filter_by(user_name='jack').one()
- assert u is u2
def testrefresh(self):
- mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))})
+ mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), backref='user')})
s = create_session()
u = s.get(User, 7)
u.user_name = 'foo'
a = Address()
- import sqlalchemy.orm.session
- assert sqlalchemy.orm.session.object_session(a) is None
+ assert object_session(a) is None
u.addresses.append(a)
self.assert_(a in u.addresses)
@@ -120,6 +86,7 @@ class MapperTest(MapperSuperTest):
# get the attribute, it refreshes
self.assert_(u.user_name == 'jack')
self.assert_(a not in u.addresses)
+
def testexpirecascade(self):
mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), cascade="all, refresh-expire")})
@@ -171,25 +138,28 @@ class MapperTest(MapperSuperTest):
def __init__(self):
raise ex
mapper(Foo, users)
-
+
try:
Foo()
assert False
except Exception, e:
assert e is ex
+ clear_mappers()
+ mapper(Foo, users, extension=SessionContextExt(SessionContext()))
def bad_expunge(foo):
raise Exception("this exception should be stated as a warning")
import warnings
warnings.filterwarnings("always", r".*this exception should be stated as a warning")
+
sess.expunge = bad_expunge
try:
Foo(_sa_session=sess)
assert False
except Exception, e:
assert e is ex
-
+
def testrefresh_lazy(self):
"""test that when a lazy loader is set as a trigger on an object's attribute (at the attribute level, not the class level), a refresh() operation doesnt fire the lazy loader or create any problems"""
s = create_session()
@@ -198,7 +168,7 @@ class MapperTest(MapperSuperTest):
u = q2.selectfirst(users.c.user_id==8)
def go():
s.refresh(u)
- self.assert_sql_count(db, go, 1)
+ self.assert_sql_count(testbase.db, go, 1)
def testexpire(self):
"""test the expire function"""
@@ -233,12 +203,19 @@ class MapperTest(MapperSuperTest):
def testrefresh2(self):
"""test a hang condition that was occuring on expire/refresh"""
+
s = create_session()
- mapper(Address, addresses)
-
- mapper(User, users, properties = dict(addresses=relation(Address,private=True,lazy=False)) )
+ m1 = mapper(Address, addresses)
+ m2 = mapper(User, users, properties = dict(addresses=relation(Address,private=True,lazy=False)) )
+ assert m1._Mapper__is_compiled is False
+ assert m2._Mapper__is_compiled is False
+
+# compile_mappers()
+ print "NEW USER"
u=User()
+ print "NEW USER DONE"
+ assert m2._Mapper__is_compiled is True
u.user_name='Justin'
a = Address()
a.address_id=17 # to work around the hardcoded IDs in this test suite....
@@ -259,15 +236,8 @@ class MapperTest(MapperSuperTest):
m = mapper(User, users, properties = {
'addresses' : relation(mapper(Address, addresses))
}).compile()
- self.assert_(User.addresses.property is m.props['addresses'])
+ self.assert_(User.addresses.property is m.get_property('addresses'))
- def testquery(self):
- """test a basic Query.select() operation."""
- mapper(User, users)
- l = create_session().query(User).select()
- self.assert_result(l, User, *user_result)
- l = create_session().query(User).select(users.c.user_name.endswith('ed'))
- self.assert_result(l, User, *user_result[1:3])
def testrecursiveselectby(self):
"""test that no endless loop occurs when traversing for select_by"""
@@ -302,151 +272,6 @@ class MapperTest(MapperSuperTest):
l = q.select()
self.assert_result(l, User, *result)
- def testwithparent(self):
- """test the with_parent()) method and one-to-many relationships"""
-
- m = mapper(User, users, properties={
- 'user_name_syn':synonym('user_name'),
- 'orders':relation(mapper(Order, orders, properties={
- 'items':relation(mapper(Item, orderitems)),
- 'items_syn':synonym('items')
- })),
- 'orders_syn':synonym('orders')
- })
-
- sess = create_session()
- q = sess.query(m)
- u1 = q.get_by(user_name='jack')
-
- # test auto-lookup of property
- o = sess.query(Order).with_parent(u1).list()
- self.assert_result(o, Order, *user_all_result[0]['orders'][1])
-
- # test with explicit property
- o = sess.query(Order).with_parent(u1, property='orders').list()
- self.assert_result(o, Order, *user_all_result[0]['orders'][1])
-
- # test static method
- o = Query.query_from_parent(u1, property='orders', session=sess).list()
- self.assert_result(o, Order, *user_all_result[0]['orders'][1])
-
- # test generative criterion
- o = sess.query(Order).with_parent(u1).select_by(orders.c.order_id>2)
- self.assert_result(o, Order, *user_all_result[0]['orders'][1][1:])
-
- try:
- q = sess.query(Item).with_parent(u1)
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Could not locate a property which relates instances of class 'Item' to instances of class 'User'"
-
-
- for nameprop, orderprop in (
- ('user_name', 'orders'),
- ('user_name_syn', 'orders'),
- ('user_name', 'orders_syn'),
- ('user_name_syn', 'orders_syn'),
- ):
- sess = create_session()
- q = sess.query(User)
-
- u1 = q.filter_by(**{nameprop:'jack'}).one()
-
- o = sess.query(Order).with_parent(u1, property=orderprop).list()
- self.assert_result(o, Order, *user_all_result[0]['orders'][1])
-
- def testwithparentm2m(self):
- """test the with_parent() method and many-to-many relationships"""
-
- m = mapper(Item, orderitems, properties = {
- 'keywords' : relation(mapper(Keyword, keywords), itemkeywords)
- })
- sess = create_session()
- i1 = sess.query(Item).get_by(item_id=2)
- k = sess.query(Keyword).with_parent(i1)
- self.assert_result(k, Keyword, *item_keyword_result[1]['keywords'][1])
-
-
- def test_join(self):
- """test functions derived from Query's _join_to function."""
-
- m = mapper(User, users, properties={
- 'orders':relation(mapper(Order, orders, properties={
- 'items':relation(mapper(Item, orderitems)),
- 'items_syn':synonym('items')
- })),
-
- 'orders_syn':synonym('orders'),
- })
-
- sess = create_session()
- q = sess.query(m)
-
- for j in (
- ['orders', 'items'],
- ['orders', 'items_syn'],
- ['orders_syn', 'items'],
- ['orders_syn', 'items_syn'],
- ):
- for q in (
- q.filter(orderitems.c.item_name=='item 4').join(j),
- q.filter(orderitems.c.item_name=='item 4').join(j[-1]),
- q.filter(orderitems.c.item_name=='item 4').filter(q.join_via(j)),
- q.filter(orderitems.c.item_name=='item 4').filter(q.join_to(j[-1])),
- ):
- l = q.all()
- self.assert_result(l, User, user_result[0])
-
- l = q.select_by(item_name='item 4')
- self.assert_result(l, User, user_result[0])
-
- l = q.filter(orderitems.c.item_name=='item 4').join('item_name').list()
- self.assert_result(l, User, user_result[0])
-
- l = q.filter(orderitems.c.item_name=='item 4').join('items').list()
- self.assert_result(l, User, user_result[0])
-
- # test comparing to an object instance
- item = sess.query(Item).get_by(item_name='item 4')
-
- l = sess.query(Order).select_by(items=item)
- self.assert_result(l, Order, user_all_result[0]['orders'][1][1])
-
- l = q.select_by(items=item)
- self.assert_result(l, User, user_result[0])
-
- # TODO: this works differently from:
- #q = sess.query(User).join(['orders', 'items']).select_by(order_id=3)
- # because select_by() doesnt respect query._joinpoint, whereas filter_by does
- q = sess.query(User).join(['orders', 'items']).filter_by(order_id=3).list()
- self.assert_result(l, User, user_result[0])
-
- try:
- # this should raise AttributeError
- l = q.select_by(items=5)
- assert False
- except AttributeError:
- assert True
-
- def testautojoinm2m(self):
- """test functions derived from Query's _join_to function."""
-
- m = mapper(Order, orders, properties = {
- 'items' : relation(mapper(Item, orderitems, properties = {
- 'keywords' : relation(mapper(Keyword, keywords), itemkeywords)
- }))
- })
-
- sess = create_session()
- q = sess.query(m)
-
- l = q.filter(keywords.c.name=='square').join(['items', 'keywords']).list()
- self.assert_result(l, Order, order_result[1])
-
- # test comparing to an object instance
- item = sess.query(Item).selectfirst()
- l = sess.query(Item).select_by(keywords=item.keywords[0])
- assert item == l[0]
def testcustomjoin(self):
"""test that the from_obj parameter to query.select() can be used
@@ -475,7 +300,7 @@ class MapperTest(MapperSuperTest):
# l = create_session().query(User).select(order_by=None)
- @testbase.unsupported('firebird')
+ @testing.unsupported('firebird')
def testfunction(self):
"""test mapping to a SELECT statement that has functions in it."""
s = select([users, (users.c.user_id * 2).label('concat'), func.count(addresses.c.address_id).label('count')],
@@ -488,52 +313,7 @@ class MapperTest(MapperSuperTest):
assert l[0].concat == l[0].user_id * 2 == 14
assert l[1].concat == l[1].user_id * 2 == 16
- def testexternalcolumns(self):
- """test creating mappings that reference external columns or functions"""
-
- f = (users.c.user_id *2).label('concat')
- try:
- mapper(User, users, properties={
- 'concat': f,
- })
- class_mapper(User)
- except exceptions.ArgumentError, e:
- assert str(e) == "Column '%s' is not represented in mapper's table. Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(f)
- clear_mappers()
-
- mapper(User, users, properties={
- 'concat': column_property(f),
- 'count': column_property(select([func.count(addresses.c.address_id)], users.c.user_id==addresses.c.user_id, scalar=True).label('count'))
- })
-
- sess = create_session()
- l = sess.query(User).select()
- for u in l:
- print "User", u.user_id, u.user_name, u.concat, u.count
- assert l[0].concat == l[0].user_id * 2 == 14
- assert l[1].concat == l[1].user_id * 2 == 16
-
- ### eager loads, not really working across all DBs, no column aliasing in place so
- # results still wont be good for larger situations
- clear_mappers()
- mapper(Address, addresses, properties={
- 'user':relation(User, lazy=False)
- })
-
- mapper(User, users, properties={
- 'concat': column_property(f),
- })
-
- for x in range(0, 2):
- sess.clear()
- l = sess.query(Address).select()
- for a in l:
- print "User", a.user.user_id, a.user.user_name, a.user.concat
- assert l[0].user.concat == l[0].user.user_id * 2 == 14
- assert l[1].user.concat == l[1].user.user_id * 2 == 16
-
-
- @testbase.unsupported('firebird')
+ @testing.unsupported('firebird')
def testcount(self):
"""test the count function on Query.
@@ -610,12 +390,12 @@ class MapperTest(MapperSuperTest):
def go():
u = sess.query(User).options(eagerload('adlist')).get_by(user_name='jack')
self.assert_result(u.adlist, Address, *(user_address_result[0]['addresses'][1]))
- self.assert_sql_count(db, go, 1)
+ self.assert_sql_count(testbase.db, go, 1)
def testextensionoptions(self):
sess = create_session()
class ext1(MapperExtension):
- def populate_instance(self, mapper, selectcontext, row, instance, identitykey, isnew):
+ def populate_instance(self, mapper, selectcontext, row, instance, **flags):
"""test options at the Mapper._instance level"""
instance.TEST = "hello world"
return EXT_PASS
@@ -626,7 +406,7 @@ class MapperTest(MapperSuperTest):
def select_by(self, *args, **kwargs):
"""test options at the Query level"""
return "HI"
- def populate_instance(self, mapper, selectcontext, row, instance, identitykey, isnew):
+ def populate_instance(self, mapper, selectcontext, row, instance, **flags):
"""test options at the Mapper._instance level"""
instance.TEST_2 = "also hello world"
return EXT_PASS
@@ -649,7 +429,7 @@ class MapperTest(MapperSuperTest):
def go():
self.assert_result(l, User, *user_address_result)
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
def testeageroptionswithlimit(self):
sess = create_session()
@@ -661,7 +441,7 @@ class MapperTest(MapperSuperTest):
def go():
assert u.user_id == 8
assert len(u.addresses) == 3
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
sess.clear()
@@ -670,7 +450,7 @@ class MapperTest(MapperSuperTest):
u = sess.query(User).get_by(user_id=8)
assert u.user_id == 8
assert len(u.addresses) == 3
- assert "tbl_row_count" not in self.capture_sql(db, go)
+ assert "tbl_row_count" not in self.capture_sql(testbase.db, go)
def testlazyoptionswithlimit(self):
sess = create_session()
@@ -682,7 +462,7 @@ class MapperTest(MapperSuperTest):
def go():
assert u.user_id == 8
assert len(u.addresses) == 3
- self.assert_sql_count(db, go, 1)
+ self.assert_sql_count(testbase.db, go, 1)
def testeagerdegrade(self):
"""tests that an eager relation automatically degrades to a lazy relation if eager columns are not available"""
@@ -695,7 +475,7 @@ class MapperTest(MapperSuperTest):
def go():
l = sess.query(usermapper).select()
self.assert_result(l, User, *user_address_result)
- self.assert_sql_count(db, go, 1)
+ self.assert_sql_count(testbase.db, go, 1)
sess.clear()
@@ -706,7 +486,7 @@ class MapperTest(MapperSuperTest):
r = users.select().execute()
l = usermapper.instances(r, sess)
self.assert_result(l, User, *user_address_result)
- self.assert_sql_count(db, go, 4)
+ self.assert_sql_count(testbase.db, go, 4)
clear_mappers()
@@ -733,7 +513,7 @@ class MapperTest(MapperSuperTest):
def go():
l = sess.query(usermapper).select()
self.assert_result(l, User, *user_all_result)
- self.assert_sql_count(db, go, 1)
+ self.assert_sql_count(testbase.db, go, 1)
sess.clear()
@@ -743,7 +523,7 @@ class MapperTest(MapperSuperTest):
r = users.select().execute()
l = usermapper.instances(r, sess)
self.assert_result(l, User, *user_all_result)
- self.assert_sql_count(db, go, 7)
+ self.assert_sql_count(testbase.db, go, 7)
def testlazyoptions(self):
@@ -755,7 +535,7 @@ class MapperTest(MapperSuperTest):
l = sess.query(User).options(lazyload('addresses')).select()
def go():
self.assert_result(l, User, *user_address_result)
- self.assert_sql_count(db, go, 3)
+ self.assert_sql_count(testbase.db, go, 3)
def testlatecompile(self):
"""tests mappers compiling late in the game"""
@@ -769,7 +549,7 @@ class MapperTest(MapperSuperTest):
u = sess.query(User).select()
def go():
print u[0].orders[1].items[0].keywords[1]
- self.assert_sql_count(db, go, 3)
+ self.assert_sql_count(testbase.db, go, 3)
def testdeepoptions(self):
mapper(User, users,
@@ -787,18 +567,18 @@ class MapperTest(MapperSuperTest):
u = sess.query(User).select()
def go():
print u[0].orders[1].items[0].keywords[1]
- self.assert_sql_count(db, go, 3)
+ self.assert_sql_count(testbase.db, go, 3)
sess.clear()
print "-------MARK----------"
- # eagerload orders, orders.items, orders.items.keywords
- q2 = sess.query(User).options(eagerload('orders'), eagerload('orders.items'), eagerload('orders.items.keywords'))
+ # eagerload orders.items.keywords; eagerload_all() implies eager load of orders, orders.items
+ q2 = sess.query(User).options(eagerload_all('orders.items.keywords'))
u = q2.select()
def go():
print u[0].orders[1].items[0].keywords[1]
print "-------MARK2----------"
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
sess.clear()
@@ -808,7 +588,7 @@ class MapperTest(MapperSuperTest):
def go():
print u[0].orders[1].items[0].keywords[1]
print "-------MARK3----------"
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
print "-------MARK4----------"
sess.clear()
@@ -818,7 +598,7 @@ class MapperTest(MapperSuperTest):
print "-------MARK5----------"
q3 = sess.query(User).options(eagerload('orders.items.keywords'))
u = q3.select()
- self.assert_sql_count(db, go, 2)
+ self.assert_sql_count(testbase.db, go, 2)
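
The change above swaps a chain of eagerload() options for a single eagerload_all(); the new call implies eager loading of every intermediate relation along the dotted path. A short sketch, assuming the User/Order/Item/Keyword mappings used throughout this suite:

    sess = create_session()

    # spelled out per level, 0.3 style
    q1 = sess.query(User).options(
        eagerload('orders'),
        eagerload('orders.items'),
        eagerload('orders.items.keywords'))

    # equivalent single option: the intermediate hops are implied
    q2 = sess.query(User).options(eagerload_all('orders.items.keywords'))
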
class DeferredTest(MapperSuperTest):
@@ -839,8 +619,8 @@ class DeferredTest(MapperSuperTest):
o2 = l[2]
print o2.description
- orderby = str(orders.default_order_by()[0].compile(engine=db))
- self.assert_sql(db, go, [
+ orderby = str(orders.default_order_by()[0].compile(bind=testbase.db))
+ self.assert_sql(testbase.db, go, [
("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}),
("SELECT orders.description AS orders_description FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3})
])
@@ -893,7 +673,8 @@ class DeferredTest(MapperSuperTest):
'description':deferred(orders.c.description, group='primary'),
'opened':deferred(orders.c.isopen, group='primary')
})
- q = create_session().query(m)
+ sess = create_session()
+ q = sess.query(m)
def go():
l = q.select()
o2 = l[2]
@@ -901,12 +682,43 @@ class DeferredTest(MapperSuperTest):
assert o2.opened == 1
assert o2.userident == 7
assert o2.description == 'order 3'
- orderby = str(orders.default_order_by()[0].compile(db))
- self.assert_sql(db, go, [
+ orderby = str(orders.default_order_by()[0].compile(testbase.db))
+ self.assert_sql(testbase.db, go, [
("SELECT orders.order_id AS orders_order_id FROM orders ORDER BY %s" % orderby, {}),
("SELECT orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3})
])
+ o2 = q.select()[2]
+# assert o2.opened == 1
+ assert o2.description == 'order 3'
+ assert o2 not in sess.dirty
+ o2.description = 'order 3'
+ def go():
+ sess.flush()
+ self.assert_sql_count(testbase.db, go, 0)
+
+ def testcommitsstate(self):
+ """test that when deferred elements are loaded via a group, they get the proper CommittedState
+        and don't result in changes being committed"""
+
+ m = mapper(Order, orders, properties = {
+ 'userident':deferred(orders.c.user_id, group='primary'),
+ 'description':deferred(orders.c.description, group='primary'),
+ 'opened':deferred(orders.c.isopen, group='primary')
+ })
+ sess = create_session()
+ q = sess.query(m)
+ o2 = q.select()[2]
+ # this will load the group of attributes
+ assert o2.description == 'order 3'
+ assert o2 not in sess.dirty
+ # this will mark it as 'dirty', but nothing actually changed
+ o2.description = 'order 3'
+ def go():
+            # therefore the flush() shouldn't actually issue any SQL
+ sess.flush()
+ self.assert_sql_count(testbase.db, go, 0)
+
def testoptions(self):
"""tests using options on a mapper to create deferred and undeferred columns"""
m = mapper(Order, orders)
@@ -917,8 +729,8 @@ class DeferredTest(MapperSuperTest):
l = q2.select()
print l[2].user_id
- orderby = str(orders.default_order_by()[0].compile(db))
- self.assert_sql(db, go, [
+ orderby = str(orders.default_order_by()[0].compile(testbase.db))
+ self.assert_sql(testbase.db, go, [
("SELECT orders.order_id AS orders_order_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}),
("SELECT orders.user_id AS orders_user_id FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3})
])
@@ -927,10 +739,31 @@ class DeferredTest(MapperSuperTest):
def go():
l = q3.select()
print l[3].user_id
- self.assert_sql(db, go, [
+ self.assert_sql(testbase.db, go, [
("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}),
])
+ def testundefergroup(self):
+ """tests undefer_group()"""
+ m = mapper(Order, orders, properties = {
+ 'userident':deferred(orders.c.user_id, group='primary'),
+ 'description':deferred(orders.c.description, group='primary'),
+ 'opened':deferred(orders.c.isopen, group='primary')
+ })
+ sess = create_session()
+ q = sess.query(m)
+ def go():
+ l = q.options(undefer_group('primary')).select()
+ o2 = l[2]
+ print o2.opened, o2.description, o2.userident
+ assert o2.opened == 1
+ assert o2.userident == 7
+ assert o2.description == 'order 3'
+ orderby = str(orders.default_order_by()[0].compile(testbase.db))
+ self.assert_sql(testbase.db, go, [
+ ("SELECT orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen, orders.order_id AS orders_order_id FROM orders ORDER BY %s" % orderby, {}),
+ ])
+
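
For reference, the behavior these two tests pin down: deferred columns sharing a group load together on first access and arrive as committed state, while undefer_group() folds the whole group into the main SELECT for one query. A condensed sketch against the same orders mapping:

    mapper(Order, orders, properties={
        'userident': deferred(orders.c.user_id, group='primary'),
        'description': deferred(orders.c.description, group='primary'),
        'opened': deferred(orders.c.isopen, group='primary'),
    })
    sess = create_session()

    o = sess.query(Order).select()[0]
    o.description               # one SELECT loads the whole 'primary' group
    assert o not in sess.dirty  # loaded values are committed, not "changes"

    # fold the group back into the main statement, for this query only
    l = sess.query(Order).options(undefer_group('primary')).select()
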
def testdeepoptions(self):
m = mapper(User, users, properties={
@@ -946,7 +779,7 @@ class DeferredTest(MapperSuperTest):
item = l[0].orders[1].items[1]
def go():
print item.item_name
- self.assert_sql_count(db, go, 1)
+ self.assert_sql_count(testbase.db, go, 1)
self.assert_(item.item_name == 'item 4')
sess.clear()
q2 = q.options(undefer('orders.items.item_name'))
@@ -954,10 +787,138 @@ class DeferredTest(MapperSuperTest):
item = l[0].orders[1].items[1]
def go():
print item.item_name
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
self.assert_(item.item_name == 'item 4')
-
+class CompositeTypesTest(ORMTest):
+ def define_tables(self, metadata):
+ global graphs, edges
+ graphs = Table('graphs', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('version_id', Integer, primary_key=True),
+ Column('name', String(30)))
+
+ edges = Table('edges', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('graph_id', Integer, nullable=False),
+ Column('graph_version_id', Integer, nullable=False),
+ Column('x1', Integer),
+ Column('y1', Integer),
+ Column('x2', Integer),
+ Column('y2', Integer),
+ ForeignKeyConstraint(['graph_id', 'graph_version_id'], ['graphs.id', 'graphs.version_id'])
+ )
+
+ def test_basic(self):
+ class Point(object):
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+ def __colset__(self):
+ return [self.x, self.y]
+ def __eq__(self, other):
+ return other.x == self.x and other.y == self.y
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ class Graph(object):
+ pass
+ class Edge(object):
+ def __init__(self, start, end):
+ self.start = start
+ self.end = end
+
+ mapper(Graph, graphs, properties={
+ 'edges':relation(Edge)
+ })
+ mapper(Edge, edges, properties={
+ 'start':composite(Point, edges.c.x1, edges.c.y1),
+ 'end':composite(Point, edges.c.x2, edges.c.y2)
+ })
+
+ sess = create_session()
+ g = Graph()
+ g.id = 1
+ g.version_id=1
+ g.edges.append(Edge(Point(3, 4), Point(5, 6)))
+ g.edges.append(Edge(Point(14, 5), Point(2, 7)))
+ sess.save(g)
+ sess.flush()
+
+ sess.clear()
+ g2 = sess.query(Graph).get([g.id, g.version_id])
+ for e1, e2 in zip(g.edges, g2.edges):
+ assert e1.start == e2.start
+ assert e1.end == e2.end
+
+ g2.edges[1].end = Point(18, 4)
+ sess.flush()
+ sess.clear()
+ e = sess.query(Edge).get(g2.edges[1].id)
+ assert e.end == Point(18, 4)
+
+ e.end.x = 19
+ e.end.y = 5
+ sess.flush()
+ sess.clear()
+ assert sess.query(Edge).get(g2.edges[1].id).end == Point(19, 5)
+
+ g.edges[1].end = Point(19, 5)
+
+ sess.clear()
+ def go():
+ g2 = sess.query(Graph).options(eagerload('edges')).get([g.id, g.version_id])
+ for e1, e2 in zip(g.edges, g2.edges):
+ assert e1.start == e2.start
+ assert e1.end == e2.end
+ self.assert_sql_count(testbase.db, go, 1)
+
+ # test comparison of CompositeProperties to their object instances
+ g = sess.query(Graph).get([1, 1])
+ assert sess.query(Edge).filter(Edge.start==Point(3, 4)).one() is g.edges[0]
+
+ assert sess.query(Edge).filter(Edge.start!=Point(3, 4)).first() is g.edges[1]
+
+ assert sess.query(Edge).filter(Edge.start==None).all() == []
+
+
+ def test_pk(self):
+ """test using a composite type as a primary key"""
+
+ class Version(object):
+ def __init__(self, id, version):
+ self.id = id
+ self.version = version
+ def __colset__(self):
+ return [self.id, self.version]
+ def __eq__(self, other):
+ return other.id == self.id and other.version == self.version
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ class Graph(object):
+ def __init__(self, version):
+ self.version = version
+
+ mapper(Graph, graphs, properties={
+ 'version':composite(Version, graphs.c.id, graphs.c.version_id)
+ })
+
+ sess = create_session()
+ g = Graph(Version(1, 1))
+ sess.save(g)
+ sess.flush()
+
+ sess.clear()
+ g2 = sess.query(Graph).get([1, 1])
+ assert g.version == g2.version
+ sess.clear()
+
+ g2 = sess.query(Graph).get(Version(1, 1))
+ assert g.version == g2.version
+
+
+
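
The new composite() property shown above maps several columns onto one value object. The class needs an __init__ accepting the column values, a __colset__() returning them in the order they were given to composite(), and __eq__/__ne__ so the ORM can detect changes and compare in queries. A minimal sketch of the mapping half, reusing the Point and Edge classes from test_basic:

    mapper(Edge, edges, properties={
        'start': composite(Point, edges.c.x1, edges.c.y1),
        'end':   composite(Point, edges.c.x2, edges.c.y2),
    })

    sess = create_session()
    # composite attributes compare as whole values, in Python and in SQL,
    # where the comparison expands to (x1 = :x AND y1 = :y)
    e = sess.query(Edge).filter(Edge.start == Point(3, 4)).first()
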
class NoLoadTest(MapperSuperTest):
def testbasic(self):
"""tests a basic one-to-many lazy load"""
@@ -975,7 +936,6 @@ class NoLoadTest(MapperSuperTest):
self.assert_result(l[0], User,
{'user_id' : 7, 'addresses' : (Address, [])},
)
-
def testoptions(self):
m = mapper(User, users, properties = dict(
addresses = relation(mapper(Address, addresses), lazy=None)
@@ -992,8 +952,20 @@ class NoLoadTest(MapperSuperTest):
{'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
)
-
-
+class MapperExtensionTest(MapperSuperTest):
+ def testcreateinstance(self):
+ class Ext(MapperExtension):
+ def create_instance(self, *args, **kwargs):
+ return User()
+ m = mapper(Address, addresses)
+ m = mapper(User, users, extension=Ext(), properties = dict(
+ addresses = relation(Address, lazy=True),
+ ))
+
+ q = create_session().query(m)
+        l = q.select()
+ self.assert_result(l, User, *user_address_result)
+
if __name__ == "__main__":
testbase.main()
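
One API shift visible in the MapperExtension hunks above: populate_instance() now receives its trailing arguments as keyword flags (**flags) rather than fixed identitykey/isnew positionals, so extensions keep working as flags are added. A hedged sketch of a conforming extension, assuming MapperExtension and EXT_PASS are available via the star imports these tests use:

    class StampExt(MapperExtension):
        def populate_instance(self, mapper, selectcontext, row, instance, **flags):
            # 'isnew' and similar markers now arrive in **flags
            instance.loaded_from_db = True
            return EXT_PASS   # let the mapper proceed with normal population

    mapper(User, users, extension=StampExt())
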
diff --git a/test/orm/memusage.py b/test/orm/memusage.py
index 4e961a6d7..26da7c010 100644
--- a/test/orm/memusage.py
+++ b/test/orm/memusage.py
@@ -1,21 +1,18 @@
-from sqlalchemy import *
-from sqlalchemy.orm import mapperlib, session, unitofwork, attributes
-Mapper = mapperlib.Mapper
-import gc
import testbase
-import tables
+import gc
+from sqlalchemy import MetaData, Integer, String, ForeignKey
+from sqlalchemy.orm import mapper, relation, clear_mappers, create_session
+from sqlalchemy.orm.mapper import Mapper
+from testlib import *
class A(object):pass
class B(object):pass
-class MapperCleanoutTest(testbase.AssertMixin):
+class MapperCleanoutTest(AssertMixin):
"""test that clear_mappers() removes everything related to the class.
does not include classes that use the assignmapper extension."""
- def setUp(self):
- global engine
- engine = testbase.db
-
+
def test_mapper_cleanup(self):
for x in range(0, 5):
self.do_test()
@@ -33,7 +30,7 @@ class MapperCleanoutTest(testbase.AssertMixin):
assert True
def do_test(self):
- metadata = MetaData(engine)
+ metadata = MetaData(testbase.db)
table1 = Table("mytable", metadata,
Column('col1', Integer, primary_key=True),
diff --git a/test/orm/merge.py b/test/orm/merge.py
index cca01f2a5..3dd0a95a4 100644
--- a/test/orm/merge.py
+++ b/test/orm/merge.py
@@ -1,8 +1,9 @@
-from testbase import PersistTest, AssertMixin
import testbase
from sqlalchemy import *
-from tables import *
-import tables
+from sqlalchemy.orm import *
+from testlib import *
+from testlib.tables import *
+import testlib.tables as tables
class MergeTest(AssertMixin):
"""tests session.merge() functionality"""
@@ -164,5 +165,3 @@ class MergeTest(AssertMixin):
if __name__ == "__main__":
testbase.main()
-
- \ No newline at end of file
diff --git a/test/orm/onetoone.py b/test/orm/onetoone.py
index 6ac7c514d..e41fa1d20 100644
--- a/test/orm/onetoone.py
+++ b/test/orm/onetoone.py
@@ -1,6 +1,8 @@
import testbase
from sqlalchemy import *
+from sqlalchemy.orm import *
from sqlalchemy.ext.sessioncontext import SessionContext
+from testlib import *
class Jack(object):
def __repr__(self):
@@ -22,7 +24,7 @@ class Port(object):
self.name=name
self.description = description
-class O2OTest(testbase.AssertMixin):
+class O2OTest(AssertMixin):
def setUpAll(self):
global jack, port, metadata, ctx
metadata = MetaData(testbase.db)
diff --git a/test/orm/query.py b/test/orm/query.py
index 872d1772e..3783e1fa0 100644
--- a/test/orm/query.py
+++ b/test/orm/query.py
@@ -1,43 +1,12 @@
import testbase
+import operator
from sqlalchemy import *
+from sqlalchemy import ansisql
from sqlalchemy.orm import *
+from testlib import *
from fixtures import *
-class Base(object):
- def __init__(self, **kwargs):
- for k in kwargs:
- setattr(self, k, kwargs[k])
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def __eq__(self, other):
- """'passively' compare this object to another.
-
- only look at attributes that are present on the source object.
-
- """
- # use __dict__ to avoid instrumented properties
- for attr in self.__dict__.keys():
- if attr[0] == '_':
- continue
- value = getattr(self, attr)
- if hasattr(value, '__iter__') and not isinstance(value, basestring):
- if len(value) == 0:
- continue
- for (us, them) in zip(value, getattr(other, attr)):
- if us != them:
- return False
- else:
- continue
- else:
- if value is not None:
- if value != getattr(other, attr):
- return False
- else:
- return True
-
-class QueryTest(testbase.ORMTest):
+class QueryTest(ORMTest):
keep_mappers = True
keep_data = True
@@ -53,16 +22,16 @@ class QueryTest(testbase.ORMTest):
def define_tables(self, meta):
        # a slightly dirty trick here.
meta.tables = metadata.tables
- metadata.connect(meta.engine)
+ metadata.connect(meta.bind)
def setup_mappers(self):
mapper(User, users, properties={
- 'addresses':relation(Address),
+ 'addresses':relation(Address, backref='user'),
'orders':relation(Order, backref='user'), # o2m, m2o
})
mapper(Address, addresses)
mapper(Order, orders, properties={
- 'items':relation(Item, secondary=order_items), #m2m
+ 'items':relation(Item, secondary=order_items, order_by=items.c.id), #m2m
'address':relation(Address), # m2o
})
mapper(Item, items, properties={
@@ -70,7 +39,6 @@ class QueryTest(testbase.ORMTest):
})
mapper(Keyword, keywords)
-
class GetTest(QueryTest):
def test_get(self):
s = create_session()
@@ -82,6 +50,33 @@ class GetTest(QueryTest):
u2 = s.query(User).get(7)
assert u is not u2
+ def test_load(self):
+ s = create_session()
+
+ try:
+ assert s.query(User).load(19) is None
+ assert False
+ except exceptions.InvalidRequestError:
+ assert True
+
+ u = s.query(User).load(7)
+ u2 = s.query(User).load(7)
+ assert u is u2
+ s.clear()
+ u2 = s.query(User).load(7)
+ assert u is not u2
+
+ u2.name = 'some name'
+ a = Address(name='some other name')
+ u2.addresses.append(a)
+ assert u2 in s.dirty
+ assert a in u2.addresses
+
+ s.query(User).load(7)
+ assert u2 not in s.dirty
+ assert u2.name =='jack'
+ assert a not in u2.addresses
+
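
test_load above pins down how load() differs from get(): it raises InvalidRequestError instead of returning None when no row exists, and it always refreshes from the database, discarding pending in-memory changes. In short:

    sess = create_session()

    u = sess.query(User).load(7)    # always emits SQL, even on a cache hit
    u.name = 'scratch'
    sess.query(User).load(7)        # the reload discards the pending change
    assert u not in sess.dirty

    try:
        sess.query(User).load(19)   # absent row: load() raises, get() returns None
        assert False
    except exceptions.InvalidRequestError:
        pass
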
def test_unicode(self):
"""test that Query.get properly sets up the type for the bind parameter. using unicode would normally fail
on postgres, mysql and oracle unless it is converted to an encoded string"""
@@ -90,16 +85,116 @@ class GetTest(QueryTest):
Column('id', Unicode(40), primary_key=True),
Column('data', Unicode(40)))
table.create()
- ustring = 'petit voix m\xe2\x80\x99a'.decode('utf-8')
+ ustring = 'petit voix m\xe2\x80\x99a '.decode('utf-8')
table.insert().execute(id=ustring, data=ustring)
class LocalFoo(Base):pass
mapper(LocalFoo, table)
assert create_session().query(LocalFoo).get(ustring) == LocalFoo(id=ustring, data=ustring)
+ def test_populate_existing(self):
+ s = create_session()
+
+ userlist = s.query(User).all()
+
+ u = userlist[0]
+ u.name = 'foo'
+ a = Address(name='ed')
+ u.addresses.append(a)
+
+ self.assert_(a in u.addresses)
+
+ s.query(User).populate_existing().all()
+
+ self.assert_(u not in s.dirty)
+
+ self.assert_(u.name == 'jack')
+
+ self.assert_(a not in u.addresses)
+
+ u.addresses[0].email_address = 'lala'
+ u.orders[1].items[2].description = 'item 12'
+        # test that lazy load doesn't change child items
+ s.query(User).populate_existing().all()
+ assert u.addresses[0].email_address == 'lala'
+ assert u.orders[1].items[2].description == 'item 12'
+
+ # eager load does
+ s.query(User).options(eagerload('addresses'), eagerload_all('orders.items')).populate_existing().all()
+ assert u.addresses[0].email_address == 'jack@bean.com'
+ assert u.orders[1].items[2].description == 'item 5'
+
+class OperatorTest(QueryTest):
+ """test sql.Comparator implementation for MapperProperties"""
+
+ def _test(self, clause, expected):
+ c = str(clause.compile(dialect=ansisql.ANSIDialect()))
+ assert c == expected, "%s != %s" % (c, expected)
+
+ def test_arithmetic(self):
+ create_session().query(User)
+ for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'),
+ (operator.sub, '-'), (operator.div, '/'),
+ ):
+ for (lhs, rhs, res) in (
+ (5, User.id, ':users_id %s users.id'),
+ (5, literal(6), ':literal %s :literal_1'),
+ (User.id, 5, 'users.id %s :users_id'),
+ (User.id, literal('b'), 'users.id %s :literal'),
+ (User.id, User.id, 'users.id %s users.id'),
+ (literal(5), 'b', ':literal %s :literal_1'),
+ (literal(5), User.id, ':literal %s users.id'),
+ (literal(5), literal(6), ':literal %s :literal_1'),
+ ):
+ self._test(py_op(lhs, rhs), res % sql_op)
+
+ def test_comparison(self):
+ create_session().query(User)
+ for (py_op, fwd_op, rev_op) in ((operator.lt, '<', '>'),
+ (operator.gt, '>', '<'),
+ (operator.eq, '=', '='),
+ (operator.ne, '!=', '!='),
+ (operator.le, '<=', '>='),
+ (operator.ge, '>=', '<=')):
+ for (lhs, rhs, l_sql, r_sql) in (
+ ('a', User.id, ':users_id', 'users.id'),
+ ('a', literal('b'), ':literal_1', ':literal'), # note swap!
+ (User.id, 'b', 'users.id', ':users_id'),
+ (User.id, literal('b'), 'users.id', ':literal'),
+ (User.id, User.id, 'users.id', 'users.id'),
+ (literal('a'), 'b', ':literal', ':literal_1'),
+ (literal('a'), User.id, ':literal', 'users.id'),
+ (literal('a'), literal('b'), ':literal', ':literal_1'),
+ ):
+
+ # the compiled clause should match either (e.g.):
+ # 'a' < 'b' -or- 'b' > 'a'.
+ compiled = str(py_op(lhs, rhs).compile(dialect=ansisql.ANSIDialect()))
+ fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql)
+ rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql)
+
+ self.assert_(compiled == fwd_sql or compiled == rev_sql,
+ "\n'" + compiled + "'\n does not match\n'" +
+ fwd_sql + "'\n or\n'" + rev_sql + "'")
+
+ def test_in(self):
+ self._test(User.id.in_('a', 'b'), "users.id IN (:users_id, :users_id_1)")
+
+ def test_clauses(self):
+ for (expr, compare) in (
+ (func.max(User.id), "max(users.id)"),
+ (desc(User.id), "users.id DESC"),
+ (between(5, User.id, Address.id), ":literal BETWEEN users.id AND addresses.id"),
+            # this one would require adding compile() to InstrumentedScalarAttribute. do we want this?
+ #(User.id, "users.id")
+ ):
+ c = expr.compile(dialect=ansisql.ANSIDialect())
+ assert str(c) == compare, "%s != %s" % (str(c), compare)
+
+
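
What OperatorTest verifies: instrumented class attributes such as User.id now implement the sql.Comparator protocol, so ordinary Python operators on them yield ClauseElements just as Column objects do. A small sketch (bind-parameter names are dialect-generated and may differ):

    from sqlalchemy import ansisql

    clause = (User.id * 2) > User.id
    print str(clause.compile(dialect=ansisql.ANSIDialect()))
    # renders along the lines of: users.id * :users_id > users.id
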
class CompileTest(QueryTest):
def test_deferred(self):
session = create_session()
- s = session.query(User).filter(and_(addresses.c.email_address == bindparam('emailad'), addresses.c.user_id==users.c.id)).compile()
+ s = session.query(User).filter(and_(addresses.c.email_address == bindparam('emailad'), Address.user_id==User.id)).compile()
l = session.query(User).instances(s.execute(emailad = 'jack@bean.com'))
assert [User(id=7)] == l
@@ -108,7 +203,23 @@ class SliceTest(QueryTest):
def test_first(self):
assert User(id=7) == create_session().query(User).first()
- assert create_session().query(User).filter(users.c.id==27).first() is None
+ assert create_session().query(User).filter(User.id==27).first() is None
+
+ # more slice tests are available in test/orm/generative.py
+
+class TextTest(QueryTest):
+ def test_fulltext(self):
+ assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).from_statement("select * from users").all()
+
+ def test_fragment(self):
+ assert [User(id=8), User(id=9)] == create_session().query(User).filter("id in (8, 9)").all()
+
+ assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter("id=9").all()
+
+ assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter(User.id==9).all()
+
+ def test_binds(self):
+ assert [User(id=8), User(id=9)] == create_session().query(User).filter("id in (:id1, :id2)").params(id1=8, id2=9).all()
class FilterTest(QueryTest):
def test_basic(self):
@@ -122,8 +233,75 @@ class FilterTest(QueryTest):
assert User(id=8) == create_session().query(User)[1]
def test_onefilter(self):
- assert [User(id=8), User(id=9)] == create_session().query(User).filter(users.c.name.endswith('ed')).all()
+ assert [User(id=8), User(id=9)] == create_session().query(User).filter(User.name.endswith('ed')).all()
+
+ def test_contains(self):
+ """test comparing a collection to an object instance."""
+
+ sess = create_session()
+ address = sess.query(Address).get(3)
+ assert [User(id=8)] == sess.query(User).filter(User.addresses.contains(address)).all()
+ try:
+ sess.query(User).filter(User.addresses == address)
+ assert False
+ except exceptions.InvalidRequestError:
+ assert True
+
+ assert [User(id=10)] == sess.query(User).filter(User.addresses==None).all()
+
+ try:
+ assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all()
+ assert False
+ except exceptions.InvalidRequestError:
+ assert True
+
+ #assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all()
+
+ def test_any(self):
+ sess = create_session()
+
+ assert [User(id=8), User(id=9)] == sess.query(User).filter(User.addresses.any(Address.email_address.like('%ed%'))).all()
+
+ assert [User(id=8)] == sess.query(User).filter(User.addresses.any(Address.email_address.like('%ed%'), id=4)).all()
+
+ assert [User(id=9)] == sess.query(User).filter(User.addresses.any(email_address='fred@fred.com')).all()
+
+ def test_has(self):
+ sess = create_session()
+ assert [Address(id=5)] == sess.query(Address).filter(Address.user.has(name='fred')).all()
+
+ assert [Address(id=2), Address(id=3), Address(id=4), Address(id=5)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'))).all()
+
+ assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'), id=8)).all()
+
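
any() and has() give relation-level EXISTS subqueries: any() for collections, has() for scalar references, each accepting clause criteria and/or keyword equality tests. The pattern from the two tests above, side by side:

    sess = create_session()

    # collection side: users with at least one matching address
    q = sess.query(User).filter(
        User.addresses.any(Address.email_address.like('%ed%')))

    # scalar side: addresses whose user matches (keyword form)
    q2 = sess.query(Address).filter(Address.user.has(name='fred'))
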
+ def test_contains_m2m(self):
+ sess = create_session()
+ item = sess.query(Item).get(3)
+ assert [Order(id=1), Order(id=2), Order(id=3)] == sess.query(Order).filter(Order.items.contains(item)).all()
+
+ assert [Order(id=4), Order(id=5)] == sess.query(Order).filter(~Order.items.contains(item)).all()
+
+ def test_comparison(self):
+ """test scalar comparison to an object instance"""
+
+ sess = create_session()
+ user = sess.query(User).get(8)
+ assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user==user).all()
+
+ assert [Address(id=1), Address(id=5)] == sess.query(Address).filter(Address.user!=user).all()
+
+class AggregateTest(QueryTest):
+ def test_sum(self):
+ sess = create_session()
+ orders = sess.query(Order).filter(Order.id.in_(2, 3, 4))
+ assert orders.sum(Order.user_id * Order.address_id) == 79
+
+ def test_apply(self):
+ sess = create_session()
+ assert sess.query(Order).apply_sum(Order.user_id * Order.address_id).filter(Order.id.in_(2, 3, 4)).one() == 79
+
+
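
The two aggregate spellings above are equivalent: sum() executes immediately and returns the scalar, while apply_sum() is generative, producing a Query that can be filtered further before firing with one(). Roughly:

    sess = create_session()

    total = sess.query(Order).filter(Order.id.in_(2, 3, 4))\
                .sum(Order.user_id * Order.address_id)       # executes now

    same = sess.query(Order).apply_sum(Order.user_id * Order.address_id)\
               .filter(Order.id.in_(2, 3, 4)).one()          # fires on one()
    assert total == same
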
class CountTest(QueryTest):
def test_basic(self):
assert 4 == create_session().query(User).count()
@@ -139,7 +317,7 @@ class TextTest(QueryTest):
assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter("id=9").all()
- assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter(users.c.id==9).all()
+ assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter(User.id==9).all()
def test_binds(self):
assert [User(id=8), User(id=9)] == create_session().query(User).filter("id in (:id1, :id2)").params(id1=8, id2=9).all()
@@ -188,14 +366,25 @@ class ParentTest(QueryTest):
class JoinTest(QueryTest):
+
def test_overlapping_paths(self):
- # load a user who has an order that contains item id 3 and address id 1 (order 3, owned by jack)
- result = create_session().query(User).join(['orders', 'items']).filter_by(id=3).reset_joinpoint().join(['orders','address']).filter_by(id=1).all()
- assert [User(id=7, name='jack')] == result
+ for aliased in (True,False):
+ # load a user who has an order that contains item id 3 and address id 1 (order 3, owned by jack)
+ result = create_session().query(User).join(['orders', 'items'], aliased=aliased).filter_by(id=3).join(['orders','address'], aliased=aliased).filter_by(id=1).all()
+ assert [User(id=7, name='jack')] == result
def test_overlapping_paths_outerjoin(self):
- result = create_session().query(User).outerjoin(['orders', 'items']).filter_by(id=3).reset_joinpoint().outerjoin(['orders','address']).filter_by(id=1).all()
+ result = create_session().query(User).outerjoin(['orders', 'items']).filter_by(id=3).outerjoin(['orders','address']).filter_by(id=1).all()
assert [User(id=7, name='jack')] == result
+
+ def test_reset_joinpoint(self):
+ for aliased in (True, False):
+ # load a user who has an order that contains item id 3 and address id 1 (order 3, owned by jack)
+ result = create_session().query(User).join(['orders', 'items'], aliased=aliased).filter_by(id=3).reset_joinpoint().join(['orders','address'], aliased=aliased).filter_by(id=1).all()
+ assert [User(id=7, name='jack')] == result
+
+ result = create_session().query(User).outerjoin(['orders', 'items'], aliased=aliased).filter_by(id=3).reset_joinpoint().outerjoin(['orders','address'], aliased=aliased).filter_by(id=1).all()
+ assert [User(id=7, name='jack')] == result
def test_overlap_with_aliases(self):
oalias = orders.alias('oalias')
@@ -206,7 +395,64 @@ class JoinTest(QueryTest):
result = create_session().query(User).select_from(users.join(oalias)).filter(oalias.c.description.in_("order 1", "order 2", "order 3")).join(['orders', 'items']).filter_by(id=4).all()
assert [User(id=7, name='jack')] == result
-class MultiplePathTest(testbase.ORMTest):
+ def test_aliased(self):
+ """test automatic generation of aliased joins."""
+
+ sess = create_session()
+
+        # test a basic aliased path
+ q = sess.query(User).join('addresses', aliased=True).filter_by(email_address='jack@bean.com')
+ assert [User(id=7)] == q.all()
+
+ q = sess.query(User).join('addresses', aliased=True).filter(Address.email_address=='jack@bean.com')
+ assert [User(id=7)] == q.all()
+
+        # test two aliased paths, one to 'orders' and the other to 'orders','items'.
+ # one row is returned because user 7 has order 3 and also has order 1 which has item 1
+        # this tests an o2m join and an m2m join.
+ q = sess.query(User).join('orders', aliased=True).filter(Order.description=="order 3").join(['orders', 'items'], aliased=True).filter(Item.description=="item 1")
+ assert q.count() == 1
+ assert [User(id=7)] == q.all()
+
+ # test the control version - same joins but not aliased. rows are not returned because order 3 does not have item 1
+        # additionally, by placing this test after the previous one, test that the "aliasing" step does not corrupt the
+ # join clauses that are cached by the relationship.
+ q = sess.query(User).join('orders').filter(Order.description=="order 3").join(['orders', 'items']).filter(Order.description=="item 1")
+ assert [] == q.all()
+ assert q.count() == 0
+
+ q = sess.query(User).join('orders', aliased=True).filter(Order.items.any(Item.description=='item 4'))
+ assert [User(id=7)] == q.all()
+
+ def test_aliased_add_entity(self):
+ """test the usage of aliased joins with add_entity()"""
+ sess = create_session()
+ q = sess.query(User).join('orders', aliased=True, id='order1').filter(Order.description=="order 3").join(['orders', 'items'], aliased=True, id='item1').filter(Item.description=="item 1")
+
+ try:
+ q.add_entity(Order, id='fakeid').compile()
+ assert False
+ except exceptions.InvalidRequestError, e:
+ assert str(e) == "Query has no alias identified by 'fakeid'"
+
+ try:
+ q.add_entity(Order, id='fakeid').instances(None)
+ assert False
+ except exceptions.InvalidRequestError, e:
+ assert str(e) == "Query has no alias identified by 'fakeid'"
+
+ q = q.add_entity(Order, id='order1').add_entity(Item, id='item1')
+ assert q.count() == 1
+ assert [(User(id=7), Order(description='order 3'), Item(description='item 1'))] == q.all()
+
+ q = sess.query(User).add_entity(Order).join('orders', aliased=True).filter(Order.description=="order 3").join('orders', aliased=True).filter(Order.description=='order 4')
+ try:
+ q.compile()
+ assert False
+ except exceptions.InvalidRequestError, e:
+ assert str(e) == "Ambiguous join for entity 'Mapper|Order|orders'; specify id=<someid> to query.join()/query.add_entity()"
+
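
The pattern test_aliased_add_entity builds up: aliased=True wraps each join's target in an anonymous alias so the same relation can be traversed twice in one statement, and id='...' names an alias so add_entity() can later select rows from it. Condensed:

    sess = create_session()
    q = sess.query(User)\
        .join('orders', aliased=True, id='order1')\
        .filter(Order.description == 'order 3')\
        .join(['orders', 'items'], aliased=True, id='item1')\
        .filter(Item.description == 'item 1')

    # pull the aliased rows into the result as extra entities
    q = q.add_entity(Order, id='order1').add_entity(Item, id='item1')
    rows = q.all()   # list of (User, Order, Item) tuples
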
+class MultiplePathTest(ORMTest):
def define_tables(self, metadata):
global t1, t2, t1t2_1, t1t2_2
t1 = Table('t1', metadata,
@@ -217,7 +463,7 @@ class MultiplePathTest(testbase.ORMTest):
Column('id', Integer, primary_key=True),
Column('data', String(30))
)
-
+
t1t2_1 = Table('t1t2_1', metadata,
Column('t1id', Integer, ForeignKey('t1.id')),
Column('t2id', Integer, ForeignKey('t2.id'))
@@ -227,23 +473,28 @@ class MultiplePathTest(testbase.ORMTest):
Column('t1id', Integer, ForeignKey('t1.id')),
Column('t2id', Integer, ForeignKey('t2.id'))
)
-
+
def test_basic(self):
class T1(object):pass
class T2(object):pass
-
+
mapper(T1, t1, properties={
't2s_1':relation(T2, secondary=t1t2_1),
't2s_2':relation(T2, secondary=t1t2_2),
})
mapper(T2, t2)
-
+
try:
- create_session().query(T1).join('t2s_1').filter_by(t2.c.id==5).reset_joinpoint().join('t2s_2')
+ create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint().join('t2s_2')
assert False
except exceptions.InvalidRequestError, e:
- assert str(e) == "Can't join to property 't2s_2'; a path to this table along a different secondary table already exists. Use explicit `Alias` objects."
+ assert str(e) == "Can't join to property 't2s_2'; a path to this table along a different secondary table already exists. Use the `alias=True` argument to `join()`."
+
+ create_session().query(T1).join('t2s_1', aliased=True).filter(t2.c.id==5).reset_joinpoint().join('t2s_2').all()
+ create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint().join('t2s_2', aliased=True).all()
+
+
class SynonymTest(QueryTest):
keep_mappers = True
keep_data = True
@@ -372,22 +623,48 @@ class InstancesTest(QueryTest):
l = q.instances(selectquery.execute(), Address)
assert l == expected
+ for aliased in (False, True):
+ q = sess.query(User)
+ q = q.add_entity(Address).outerjoin('addresses', aliased=aliased)
+ l = q.all()
+ assert l == expected
+
+ q = sess.query(User).add_entity(Address)
+ l = q.join('addresses', aliased=aliased).filter_by(email_address='ed@bettyboop.com').all()
+ assert l == [(user8, address3)]
+
+ q = sess.query(User, Address).join('addresses', aliased=aliased).filter_by(email_address='ed@bettyboop.com')
+ assert q.all() == [(user8, address3)]
+
+ q = sess.query(User, Address).join('addresses', aliased=aliased).options(eagerload('addresses')).filter_by(email_address='ed@bettyboop.com')
+ assert q.all() == [(user8, address3)]
+
+ def test_aliased_multi_mappers(self):
+ sess = create_session()
+
+ (user7, user8, user9, user10) = sess.query(User).all()
+ (address1, address2, address3, address4, address5) = sess.query(Address).all()
+
+        # note the result is a Cartesian product
+ expected = [(user7, address1),
+ (user8, address2),
+ (user8, address3),
+ (user8, address4),
+ (user9, address5),
+ (user10, None)]
+
q = sess.query(User)
- q = q.add_entity(Address).outerjoin('addresses')
+ adalias = addresses.alias('adalias')
+ q = q.add_entity(Address, alias=adalias).select_from(users.outerjoin(adalias))
l = q.all()
assert l == expected
- q = sess.query(User).add_entity(Address)
- l = q.join('addresses').filter_by(email_address='ed@bettyboop.com').all()
+ q = sess.query(User).add_entity(Address, alias=adalias)
+ l = q.select_from(users.outerjoin(adalias)).filter(adalias.c.email_address=='ed@bettyboop.com').all()
assert l == [(user8, address3)]
- q = sess.query(User, Address).join('addresses').filter_by(email_address='ed@bettyboop.com')
- assert q.all() == [(user8, address3)]
-
- q = sess.query(User, Address).join('addresses').options(eagerload('addresses')).filter_by(email_address='ed@bettyboop.com')
- assert q.all() == [(user8, address3)]
-
def test_multi_columns(self):
+        """test aliased/non-aliased joins using add_column()"""
sess = create_session()
(user7, user8, user9, user10) = sess.query(User).all()
expected = [(user7, 1),
@@ -395,18 +672,18 @@ class InstancesTest(QueryTest):
(user9, 1),
(user10, 0)
]
-
- q = sess.query(User)
- q = q.group_by([c for c in users.c]).order_by(User.c.id).outerjoin('addresses').add_column(func.count(addresses.c.id).label('count'))
- l = q.all()
- assert l == expected
+
+ for aliased in (False, True):
+ q = sess.query(User)
+ q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses', aliased=aliased).add_column(func.count(Address.id).label('count'))
+ l = q.all()
+ assert l == expected
- s = select([users, func.count(addresses.c.id).label('count')], from_obj=[users.outerjoin(addresses)], group_by=[c for c in users.c], order_by=users.c.id)
+ s = select([users, func.count(addresses.c.id).label('count')]).select_from(users.outerjoin(addresses)).group_by(*[c for c in users.c]).order_by(User.id)
q = sess.query(User)
l = q.add_column("count").from_statement(s).all()
assert l == expected
- @testbase.unsupported('mysql') # only because of "+" operator requiring "concat" in mysql (fix #475)
def test_two_columns(self):
sess = create_session()
(user7, user8, user9, user10) = sess.query(User).all()
@@ -416,17 +693,162 @@ class InstancesTest(QueryTest):
(user9, 1, "Name:fred"),
(user10, 0, "Name:chuck")]
+ # test with a straight statement
s = select([users, func.count(addresses.c.id).label('count'), ("Name:" + users.c.name).label('concat')], from_obj=[users.outerjoin(addresses)], group_by=[c for c in users.c], order_by=[users.c.id])
q = create_session().query(User)
l = q.add_column("count").add_column("concat").from_statement(s).all()
assert l == expected
+ # test with select_from()
q = create_session().query(User).add_column(func.count(addresses.c.id))\
.add_column(("Name:" + users.c.name)).select_from(users.outerjoin(addresses))\
.group_by([c for c in users.c]).order_by(users.c.id)
assert q.all() == expected
+ # test with outerjoin() both aliased and non
+ for aliased in (False, True):
+ q = create_session().query(User).add_column(func.count(addresses.c.id))\
+ .add_column(("Name:" + users.c.name)).outerjoin('addresses', aliased=aliased)\
+ .group_by([c for c in users.c]).order_by(users.c.id)
+
+ assert q.all() == expected
+
+class CustomJoinTest(QueryTest):
+ keep_mappers = False
+
+ def setup_mappers(self):
+ pass
+
+ def test_double_same_mappers(self):
+ """test aliasing of joins with a custom join condition"""
+ mapper(Address, addresses)
+ mapper(Order, orders, properties={
+ 'items':relation(Item, secondary=order_items, lazy=True, order_by=items.c.id),
+ })
+ mapper(Item, items)
+ mapper(User, users, properties = dict(
+ addresses = relation(Address, lazy=True),
+ open_orders = relation(Order, primaryjoin = and_(orders.c.isopen == 1, users.c.id==orders.c.user_id), lazy=True),
+ closed_orders = relation(Order, primaryjoin = and_(orders.c.isopen == 0, users.c.id==orders.c.user_id), lazy=True)
+ ))
+ q = create_session().query(User)
+
+ assert [User(id=7)] == q.join(['open_orders', 'items'], aliased=True).filter(Item.id==4).join(['closed_orders', 'items'], aliased=True).filter(Item.id==3).all()
+
+class SelfReferentialJoinTest(ORMTest):
+ def define_tables(self, metadata):
+ global nodes
+ nodes = Table('nodes', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('parent_id', Integer, ForeignKey('nodes.id')),
+ Column('data', String(30)))
+
+ def test_join(self):
+ class Node(Base):
+ def append(self, node):
+ self.children.append(node)
+
+ mapper(Node, nodes, properties={
+ 'children':relation(Node, lazy=True, join_depth=3,
+ backref=backref('parent', remote_side=[nodes.c.id])
+ )
+ })
+ sess = create_session()
+ n1 = Node(data='n1')
+ n1.append(Node(data='n11'))
+ n1.append(Node(data='n12'))
+ n1.append(Node(data='n13'))
+ n1.children[1].append(Node(data='n121'))
+ n1.children[1].append(Node(data='n122'))
+ n1.children[1].append(Node(data='n123'))
+ sess.save(n1)
+ sess.flush()
+ sess.clear()
+
+ # TODO: the aliasing of the join in query._join_to has to limit the aliasing
+        # among local_side / remote_side (add local_side as an attribute on PropertyLoader);
+        # also implement this idea in EagerLoader
+ node = sess.query(Node).join('children', aliased=True).filter_by(data='n122').first()
+ assert node.data=='n12'
+
+ node = sess.query(Node).join(['children', 'children'], aliased=True).filter_by(data='n122').first()
+ assert node.data=='n1'
+
+ node = sess.query(Node).filter_by(data='n122').join('parent', aliased=True).filter_by(data='n12').\
+ join('parent', aliased=True, from_joinpoint=True).filter_by(data='n1').first()
+ assert node.data == 'n122'
+
+class ExternalColumnsTest(QueryTest):
+ keep_mappers = False
+
+ def setup_mappers(self):
+ pass
+
+ def test_external_columns_bad(self):
+ """test that SA catches some common mis-configurations of external columns."""
+ f = (users.c.id * 2)
+ try:
+ mapper(User, users, properties={
+ 'concat': f,
+ })
+ class_mapper(User)
+ except exceptions.ArgumentError, e:
+ assert str(e) == "Column '%s' is not represented in mapper's table. Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(f)
+ else:
+            raise AssertionError('expected ArgumentError')
+ clear_mappers()
+ try:
+ mapper(User, users, properties={
+ 'concat': column_property(users.c.id * 2),
+ })
+ except exceptions.ArgumentError, e:
+ assert str(e) == 'ColumnProperties must be named for the mapper to work with them. Try .label() to fix this'
+ else:
+            raise AssertionError('expected ArgumentError')
+
+ def test_external_columns_good(self):
+ """test querying mappings that reference external columns or selectables."""
+ mapper(User, users, properties={
+ 'concat': column_property((users.c.id * 2).label('concat')),
+ 'count': column_property(select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).correlate(users).label('count'))
+ })
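+        # 'concat' is a plain SQL expression; 'count' is a correlated scalar
+        # subquery against addresses, loaded along with each User row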
+
+ mapper(Address, addresses, properties={
+ 'user':relation(User, lazy=True)
+ })
+
+ sess = create_session()
+ l = sess.query(User).select()
+ assert [
+ User(id=7, concat=14, count=1),
+ User(id=8, concat=16, count=3),
+ User(id=9, concat=18, count=1),
+ User(id=10, concat=20, count=0),
+ ] == l
+
+ address_result = [
+ Address(id=1, user=User(id=7, concat=14, count=1)),
+ Address(id=2, user=User(id=8, concat=16, count=3)),
+ Address(id=3, user=User(id=8, concat=16, count=3)),
+ Address(id=4, user=User(id=8, concat=16, count=3)),
+ Address(id=5, user=User(id=9, concat=18, count=1))
+ ]
+
+ assert address_result == sess.query(Address).all()
+
+ # run the eager version twice to test caching of aliased clauses
+ for x in range(2):
+ sess.clear()
+ def go():
+ assert address_result == sess.query(Address).options(eagerload('user')).all()
+ self.assert_sql_count(testbase.db, go, 1)
+
+ tuple_address_result = [(address, address.user) for address in address_result]
+
+        assert tuple_address_result == sess.query(Address).join('user').add_entity(User).all()
+
+ assert tuple_address_result == sess.query(Address).join('user', aliased=True, id='ualias').add_entity(User, id='ualias').all()
if __name__ == '__main__':
testbase.main()
diff --git a/test/orm/relationships.py b/test/orm/relationships.py
index 7c9bbc898..9fca22b24 100644
--- a/test/orm/relationships.py
+++ b/test/orm/relationships.py
@@ -1,12 +1,12 @@
import testbase
-import unittest, sys, datetime
-
-db = testbase.db
-
+import datetime
from sqlalchemy import *
+from sqlalchemy.orm import *
+from sqlalchemy.orm import collections
+from sqlalchemy.orm.collections import collection
+from testlib import *
-
-class RelationTest(testbase.PersistTest):
+class RelationTest(PersistTest):
"""this is essentially an extension of the "dependency.py" topological sort test.
in this test, a table is dependent on two other tables that are otherwise unrelated to each other.
the dependency sort must ensure that this childmost table is below both parent tables in the outcome
@@ -15,10 +15,8 @@ class RelationTest(testbase.PersistTest):
to subtle differences in program execution, this test case was exposing the bug whereas the simpler tests
were not."""
def setUpAll(self):
- global tbl_a
- global tbl_b
- global tbl_c
- global tbl_d
+ global metadata, tbl_a, tbl_b, tbl_c, tbl_d
+
metadata = MetaData()
tbl_a = Table("tbl_a", metadata,
Column("id", Integer, primary_key=True),
@@ -41,8 +39,8 @@ class RelationTest(testbase.PersistTest):
)
def setUp(self):
global session
- session = create_session(bind_to=testbase.db)
- conn = session.connect()
+ session = create_session(bind=testbase.db)
+ conn = testbase.db.connect()
conn.create(tbl_a)
conn.create(tbl_b)
conn.create(tbl_c)
@@ -80,14 +78,14 @@ class RelationTest(testbase.PersistTest):
session.save_or_update(b)
def tearDown(self):
- conn = session.connect()
+ conn = testbase.db.connect()
conn.drop(tbl_d)
conn.drop(tbl_c)
conn.drop(tbl_b)
conn.drop(tbl_a)
def tearDownAll(self):
- testbase.metadata.tables.clear()
+ metadata.drop_all(testbase.db)
def testDeleteRootTable(self):
session.flush()
@@ -99,7 +97,7 @@ class RelationTest(testbase.PersistTest):
session.delete(c) # fails
session.flush()
-class RelationTest2(testbase.PersistTest):
+class RelationTest2(PersistTest):
"""this test tests a relationship on a column that is included in multiple foreign keys,
as well as a self-referential relationship on a composite key where one column in the foreign key
is 'joined to itself'."""
@@ -216,7 +214,7 @@ class RelationTest2(testbase.PersistTest):
assert sess.query(Employee).get([c1.company_id, 3]).reports_to.name == 'emp1'
assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5'
-class RelationTest3(testbase.PersistTest):
+class RelationTest3(PersistTest):
def setUpAll(self):
global jobs, pageversions, pages, metadata, Job, Page, PageVersion, PageComment
import datetime
@@ -350,7 +348,7 @@ class RelationTest3(testbase.PersistTest):
s.delete(j)
s.flush()
-class RelationTest4(testbase.ORMTest):
+class RelationTest4(ORMTest):
"""test syncrules on foreign keys that are also primary"""
def define_tables(self, metadata):
global tableA, tableB
@@ -498,7 +496,7 @@ class RelationTest4(testbase.ORMTest):
assert a1 not in sess
assert b1 not in sess
-class RelationTest5(testbase.ORMTest):
+class RelationTest5(ORMTest):
"""test a map to a select that relates to a map to the table"""
def define_tables(self, metadata):
global items
@@ -554,7 +552,7 @@ class RelationTest5(testbase.ORMTest):
assert old.id == new.id
-class TypeMatchTest(testbase.ORMTest):
+class TypeMatchTest(ORMTest):
"""test errors raised when trying to add items whose type is not handled by a relation"""
def define_tables(self, metadata):
global a, b, c, d
@@ -672,7 +670,7 @@ class TypeMatchTest(testbase.ORMTest):
except exceptions.AssertionError, err:
assert str(err) == "Attribute 'a' on class '%s' doesn't handle objects of type '%s'" % (D, B)
-class TypedAssociationTable(testbase.ORMTest):
+class TypedAssociationTable(ORMTest):
def define_tables(self, metadata):
global t1, t2, t3
@@ -722,7 +720,7 @@ class TypedAssociationTable(testbase.ORMTest):
assert t3.count().scalar() == 1
# TODO: move these tests to either attributes.py test or its own module
-class CustomCollectionsTest(testbase.ORMTest):
+class CustomCollectionsTest(ORMTest):
def define_tables(self, metadata):
global sometable, someothertable
sometable = Table('sometable', metadata,
@@ -745,7 +743,7 @@ class CustomCollectionsTest(testbase.ORMTest):
})
mapper(Bar, someothertable)
f = Foo()
- assert isinstance(f.bars.data, MyList)
+ assert isinstance(f.bars, MyList)
def testlazyload(self):
"""test that a 'set' can be used as a collection and can lazyload."""
class Foo(object):
@@ -769,23 +767,27 @@ class CustomCollectionsTest(testbase.ORMTest):
def testdict(self):
"""test that a 'dict' can be used as a collection and can lazyload."""
+
class Foo(object):
pass
class Bar(object):
pass
class AppenderDict(dict):
- def append(self, item):
+ @collection.appender
+ def set(self, item):
self[id(item)] = item
- def __iter__(self):
- return iter(self.values())
+ @collection.remover
+ def remove(self, item):
+ if id(item) in self:
+ del self[id(item)]
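+        # @collection.appender/@collection.remover mark the methods the ORM
+        # calls to add and remove members of this dict-based collection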
mapper(Foo, sometable, properties={
'bars':relation(Bar, collection_class=AppenderDict)
})
mapper(Bar, someothertable)
f = Foo()
- f.bars.append(Bar())
- f.bars.append(Bar())
+ f.bars.set(Bar())
+ f.bars.set(Bar())
sess = create_session()
sess.save(f)
sess.flush()
@@ -794,6 +796,44 @@ class CustomCollectionsTest(testbase.ORMTest):
assert len(list(f.bars)) == 2
f.bars.clear()
+ def testdictwrapper(self):
+ """test that the supplied 'dict' wrapper can be used as a collection and can lazyload."""
+
+ class Foo(object):
+ pass
+ class Bar(object):
+ def __init__(self, data): self.data = data
+
+ mapper(Foo, sometable, properties={
+ 'bars':relation(Bar,
+ collection_class=collections.column_mapped_collection(someothertable.c.data))
+ })
+ mapper(Bar, someothertable)
+
+ f = Foo()
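+        # collection_adapter() returns the ORM's adapter for f.bars;
+        # append_with_event() adds a member while firing collection events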
+ col = collections.collection_adapter(f.bars)
+ col.append_with_event(Bar('a'))
+ col.append_with_event(Bar('b'))
+ sess = create_session()
+ sess.save(f)
+ sess.flush()
+ sess.clear()
+ f = sess.query(Foo).get(f.col1)
+ assert len(list(f.bars)) == 2
+
+ existing = set([id(b) for b in f.bars.values()])
+
+ col = collections.collection_adapter(f.bars)
+ col.append_with_event(Bar('b'))
+ f.bars['a'] = Bar('a')
+ sess.flush()
+ sess.clear()
+ f = sess.query(Foo).get(f.col1)
+ assert len(list(f.bars)) == 2
+
+ replaced = set([id(b) for b in f.bars.values()])
+ self.assert_(existing != replaced)
+
def testlist(self):
class Parent(object):
pass
@@ -811,13 +851,13 @@ class CustomCollectionsTest(testbase.ORMTest):
o = Child()
control.append(o)
p.children.append(o)
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
o = [Child(), Child(), Child(), Child()]
control.extend(o)
p.children.extend(o)
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
assert control[0] == p.children[0]
@@ -826,92 +866,92 @@ class CustomCollectionsTest(testbase.ORMTest):
del control[1]
del p.children[1]
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
o = [Child()]
control[1:3] = o
p.children[1:3] = o
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
o = [Child(), Child(), Child(), Child()]
control[1:3] = o
p.children[1:3] = o
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
o = [Child(), Child(), Child(), Child()]
control[-1:-2] = o
p.children[-1:-2] = o
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
o = [Child(), Child(), Child(), Child()]
control[4:] = o
p.children[4:] = o
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
o = Child()
control.insert(0, o)
p.children.insert(0, o)
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
o = Child()
control.insert(3, o)
p.children.insert(3, o)
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
o = Child()
control.insert(999, o)
p.children.insert(999, o)
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
del control[0:1]
del p.children[0:1]
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
del control[1:1]
del p.children[1:1]
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
del control[1:3]
del p.children[1:3]
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
del control[7:]
del p.children[7:]
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
assert control.pop() == p.children.pop()
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
assert control.pop(0) == p.children.pop(0)
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
assert control.pop(2) == p.children.pop(2)
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
o = Child()
control.insert(2, o)
p.children.insert(2, o)
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
control.remove(o)
p.children.remove(o)
- assert control == p.children.data
+ assert control == p.children
assert control == list(p.children)
def testobj(self):
@@ -922,9 +962,12 @@ class CustomCollectionsTest(testbase.ORMTest):
class MyCollection(object):
def __init__(self): self.data = []
+ @collection.appender
def append(self, value): self.data.append(value)
+ @collection.remover
+ def remove(self, value): self.data.remove(value)
+ @collection.iterator
def __iter__(self): return iter(self.data)
- def clear(self): self.data.clear()
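+            # a plain object works as a collection class once appender, remover
+            # and iterator roles are marked with the collection decorators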
mapper(Parent, sometable, properties={
'children':relation(Child, collection_class=MyCollection)
@@ -958,7 +1001,7 @@ class CustomCollectionsTest(testbase.ORMTest):
o = list(p2.children)
assert len(o) == 3
-class ViewOnlyTest(testbase.ORMTest):
+class ViewOnlyTest(ORMTest):
"""test a view_only mapping where a third table is pulled into the primary join condition,
using overlapping PK column names (should not produce "conflicting column" error)"""
def define_tables(self, metadata):
@@ -1009,7 +1052,7 @@ class ViewOnlyTest(testbase.ORMTest):
assert set([x.id for x in c1.t2s]) == set([c2a.id, c2b.id])
assert set([x.id for x in c1.t2_view]) == set([c2b.id])
-class ViewOnlyTest2(testbase.ORMTest):
+class ViewOnlyTest2(ORMTest):
"""test a view_only mapping where a third table is pulled into the primary join condition,
using non-overlapping PK column names (should not produce "mapper has no column X" error)"""
def define_tables(self, metadata):
diff --git a/test/orm/session.py b/test/orm/session.py
index 762722ecc..433279673 100644
--- a/test/orm/session.py
+++ b/test/orm/session.py
@@ -1,13 +1,9 @@
-from testbase import AssertMixin
import testbase
-import unittest, sys, datetime
-
-import tables
-from tables import *
-
-db = testbase.db
from sqlalchemy import *
-
+from sqlalchemy.orm import *
+from testlib import *
+from testlib.tables import *
+import testlib.tables as tables
class SessionTest(AssertMixin):
def setUpAll(self):
@@ -25,7 +21,7 @@ class SessionTest(AssertMixin):
c = testbase.db.connect()
class User(object):pass
mapper(User, users)
- s = create_session(bind_to=c)
+ s = create_session(bind=c)
s.save(User())
s.flush()
c.execute("select * from users")
@@ -38,6 +34,30 @@ class SessionTest(AssertMixin):
s.user_name = 'some other user'
s.flush()
+ def test_close_two(self):
+ c = testbase.db.connect()
+ try:
+ class User(object):pass
+ mapper(User, users)
+ s = create_session(bind=c)
+ s.begin()
+ tran = s.transaction
+ s.save(User())
+ s.flush()
+ c.execute("select * from users")
+ u = User()
+ s.save(u)
+ s.user_name = 'some user'
+ s.flush()
+ u = User()
+ s.save(u)
+ s.user_name = 'some other user'
+ s.flush()
+ assert s.transaction is tran
+ tran.close()
+ finally:
+ c.close()
+
def test_expunge_cascade(self):
tables.data()
mapper(Address, addresses)
@@ -52,49 +72,209 @@ class SessionTest(AssertMixin):
# then see if expunge fails
session.expunge(u)
-
+
+ @testing.unsupported('sqlite')
def test_transaction(self):
class User(object):pass
mapper(User, users)
- sess = create_session()
- transaction = sess.create_transaction()
+ conn1 = testbase.db.connect()
+ conn2 = testbase.db.connect()
+
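+        # two distinct connections watch the same table; rows written in
+        # conn1's transaction stay invisible to conn2 until commit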
+ sess = create_session(transactional=True, bind=conn1)
+ u = User()
+ sess.save(u)
+ sess.flush()
+ assert conn1.execute("select count(1) from users").scalar() == 1
+ assert conn2.execute("select count(1) from users").scalar() == 0
+ sess.commit()
+ assert conn1.execute("select count(1) from users").scalar() == 1
+ assert testbase.db.connect().execute("select count(1) from users").scalar() == 1
+
+ @testing.unsupported('sqlite')
+ def test_autoflush(self):
+ class User(object):pass
+ mapper(User, users)
+ conn1 = testbase.db.connect()
+ conn2 = testbase.db.connect()
+
+ sess = create_session(autoflush=True, bind=conn1)
+ u = User()
+ u.user_name='ed'
+ sess.save(u)
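+        # the query below autoflushes the pending User, so the INSERT is
+        # issued on conn1 before the SELECT runs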
+ u2 = sess.query(User).filter_by(user_name='ed').one()
+ assert u2 is u
+ assert conn1.execute("select count(1) from users").scalar() == 1
+ assert conn2.execute("select count(1) from users").scalar() == 0
+ sess.commit()
+ assert conn1.execute("select count(1) from users").scalar() == 1
+ assert testbase.db.connect().execute("select count(1) from users").scalar() == 1
+
+ @testing.unsupported('sqlite')
+ def test_autoflush_unbound(self):
+ class User(object):pass
+ mapper(User, users)
+
try:
+ sess = create_session(autoflush=True)
u = User()
+ u.user_name='ed'
sess.save(u)
+ u2 = sess.query(User).filter_by(user_name='ed').one()
+ assert u2 is u
+ assert sess.execute("select count(1) from users", mapper=User).scalar() == 1
+ assert testbase.db.connect().execute("select count(1) from users").scalar() == 0
+ sess.commit()
+ assert sess.execute("select count(1) from users", mapper=User).scalar() == 1
+ assert testbase.db.connect().execute("select count(1) from users").scalar() == 1
+ except:
+ sess.rollback()
+ raise
+
+ def test_autoflush_2(self):
+ class User(object):pass
+ mapper(User, users)
+ conn1 = testbase.db.connect()
+ conn2 = testbase.db.connect()
+
+ sess = create_session(autoflush=True, bind=conn1)
+ u = User()
+ u.user_name='ed'
+ sess.save(u)
+ sess.commit()
+ assert conn1.execute("select count(1) from users").scalar() == 1
+ assert testbase.db.connect().execute("select count(1) from users").scalar() == 1
+
+ def test_external_joined_transaction(self):
+ class User(object):pass
+ mapper(User, users)
+ conn = testbase.db.connect()
+ trans = conn.begin()
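+        # the session joins the externally-begun transaction on conn, so its
+        # own commit() is a no-op; the outer trans decides the outcome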
+ sess = create_session(bind=conn)
+ sess.begin()
+ u = User()
+ sess.save(u)
+ sess.flush()
+ sess.commit() # commit does nothing
+ trans.rollback() # rolls back
+ assert len(sess.query(User).select()) == 0
+
+ @testing.supported('postgres', 'mysql')
+ def test_external_nested_transaction(self):
+ class User(object):pass
+ mapper(User, users)
+ try:
+ conn = testbase.db.connect()
+ trans = conn.begin()
+ sess = create_session(bind=conn)
+ u1 = User()
+ sess.save(u1)
sess.flush()
- sess.delete(u)
- sess.save(User())
+
+ sess.begin_nested()
+ u2 = User()
+ sess.save(u2)
sess.flush()
- # TODO: assertion ?
- transaction.commit()
+ sess.rollback()
+
+ trans.commit()
+ assert len(sess.query(User).select()) == 1
except:
- transaction.rollback()
+ conn.close()
+ raise
+
+ @testing.supported('postgres', 'mysql')
+ def test_twophase(self):
+ # TODO: mock up a failure condition here
+ # to ensure a rollback succeeds
+ class User(object):pass
+ class Address(object):pass
+ mapper(User, users)
+ mapper(Address, addresses)
+
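+        # one bind per mapper; twophase=True coordinates a two-phase commit
+        # across both engines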
+ engine2 = create_engine(testbase.db.url)
+ sess = create_session(twophase=True)
+ sess.bind_mapper(User, testbase.db)
+ sess.bind_mapper(Address, engine2)
+ sess.begin()
+ u1 = User()
+ a1 = Address()
+ sess.save(u1)
+ sess.save(a1)
+ sess.commit()
+ sess.close()
+ engine2.dispose()
+ assert users.count().scalar() == 1
+ assert addresses.count().scalar() == 1
+
+
+
+ def test_joined_transaction(self):
+ class User(object):pass
+ mapper(User, users)
+ sess = create_session()
+ sess.begin()
+ sess.begin()
+ u = User()
+ sess.save(u)
+ sess.flush()
+ sess.commit() # commit does nothing
+ sess.rollback() # rolls back
+ assert len(sess.query(User).select()) == 0
+ @testing.supported('postgres', 'mysql')
def test_nested_transaction(self):
class User(object):pass
mapper(User, users)
sess = create_session()
- transaction = sess.create_transaction()
- trans2 = sess.create_transaction()
+ sess.begin()
+
u = User()
sess.save(u)
sess.flush()
- trans2.commit()
- transaction.rollback()
- assert len(sess.query(User).select()) == 0
+
+ sess.begin_nested() # nested transaction
+
+ u2 = User()
+ sess.save(u2)
+ sess.flush()
+
+ sess.rollback()
+
+ sess.commit()
+ assert len(sess.query(User).select()) == 1
+
+ @testing.supported('postgres', 'mysql')
+ def test_nested_autotrans(self):
+ class User(object):pass
+ mapper(User, users)
+ sess = create_session(transactional=True)
+ u = User()
+ sess.save(u)
+ sess.flush()
+
+ sess.begin_nested() # nested transaction
+
+ u2 = User()
+ sess.save(u2)
+ sess.flush()
+
+ sess.rollback()
+
+ sess.commit()
+ assert len(sess.query(User).select()) == 1
def test_bound_connection(self):
class User(object):pass
mapper(User, users)
c = testbase.db.connect()
sess = create_session(bind=c)
- transaction = sess.create_transaction()
- trans2 = sess.create_transaction()
+ sess.create_transaction()
+ transaction = sess.transaction
u = User()
sess.save(u)
sess.flush()
- assert transaction.get_or_add(testbase.db) is trans2.get_or_add(testbase.db) is transaction.get_or_add(c) is trans2.get_or_add(c) is c
-
+ assert transaction.get_or_add(testbase.db) is transaction.get_or_add(c) is c
+
try:
transaction.add(testbase.db.connect())
assert False
@@ -112,34 +292,10 @@ class SessionTest(AssertMixin):
assert False
except exceptions.InvalidRequestError, e:
assert str(e) == "Session already has a Connection associated for the given Engine"
-
- trans2.commit()
+
transaction.rollback()
assert len(sess.query(User).select()) == 0
-
- def test_close_two(self):
- c = testbase.db.connect()
- try:
- class User(object):pass
- mapper(User, users)
- s = create_session(bind_to=c)
- tran = s.create_transaction()
- s.save(User())
- s.flush()
- c.execute("select * from users")
- u = User()
- s.save(u)
- s.user_name = 'some user'
- s.flush()
- u = User()
- s.save(u)
- s.user_name = 'some other user'
- s.flush()
- assert s.transaction is tran
- tran.close()
- finally:
- c.close()
-
+
def test_update(self):
"""test that the update() method functions and doesnet blow away changes"""
tables.delete()
@@ -164,7 +320,7 @@ class SessionTest(AssertMixin):
user = s.query(User).selectone()
assert user.user_name == 'fred'
- # insure its not dirty if no changes occur
+ # ensure its not dirty if no changes occur
s.clear()
assert user not in s
s.update(user)
diff --git a/test/orm/sessioncontext.py b/test/orm/sessioncontext.py
index 83bc2f2bf..7a60b47c7 100644
--- a/test/orm/sessioncontext.py
+++ b/test/orm/sessioncontext.py
@@ -1,15 +1,15 @@
-from testbase import PersistTest, AssertMixin
-import unittest, sys, os
+import testbase
+from sqlalchemy import *
+from sqlalchemy.orm import *
from sqlalchemy.ext.sessioncontext import SessionContext
from sqlalchemy.orm.session import object_session, Session
-from sqlalchemy import *
-import testbase
+from testlib import *
+
metadata = MetaData()
users = Table('users', metadata,
Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
Column('user_name', String(40)),
- mysql_engine='innodb'
)
class SessionContextTest(AssertMixin):
diff --git a/test/orm/sharding/__init__.py b/test/orm/sharding/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/test/orm/sharding/__init__.py
diff --git a/test/orm/sharding/alltests.py b/test/orm/sharding/alltests.py
new file mode 100644
index 000000000..0cdb838a9
--- /dev/null
+++ b/test/orm/sharding/alltests.py
@@ -0,0 +1,18 @@
+import testbase
+import unittest
+
+def suite():
+ modules_to_test = (
+ 'orm.sharding.shard',
+ )
+ alltests = unittest.TestSuite()
+ for name in modules_to_test:
+ mod = __import__(name)
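+        # __import__('orm.sharding.shard') returns the top-level package, so
+        # walk the dotted path to reach the leaf module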
+ for token in name.split('.')[1:]:
+ mod = getattr(mod, token)
+ alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
+ return alltests
+
+
+if __name__ == '__main__':
+ testbase.main(suite())
diff --git a/test/orm/sharding/shard.py b/test/orm/sharding/shard.py
new file mode 100644
index 000000000..faa980cc2
--- /dev/null
+++ b/test/orm/sharding/shard.py
@@ -0,0 +1,154 @@
+import testbase
+from sqlalchemy import *
+from sqlalchemy.orm import *
+
+from sqlalchemy.orm.shard import ShardedSession
+from sqlalchemy.sql import ColumnOperators
+import datetime, operator, os
+from testlib import PersistTest
+
+# TODO: ShardTest can be turned into a base for further subclasses
+
+class ShardTest(PersistTest):
+ def setUpAll(self):
+ global db1, db2, db3, db4, weather_locations, weather_reports
+
+ db1 = create_engine('sqlite:///shard1.db')
+ db2 = create_engine('sqlite:///shard2.db')
+ db3 = create_engine('sqlite:///shard3.db')
+ db4 = create_engine('sqlite:///shard4.db')
+
+ meta = MetaData()
+ ids = Table('ids', meta,
+ Column('nextid', Integer, nullable=False))
+
+ def id_generator(ctx):
+ # in reality, might want to use a separate transaction for this.
+ c = db1.connect()
+ nextid = c.execute(ids.select(for_update=True)).scalar()
+ c.execute(ids.update(values={ids.c.nextid : ids.c.nextid + 1}))
+ return nextid
+
+ weather_locations = Table("weather_locations", meta,
+ Column('id', Integer, primary_key=True, default=id_generator),
+ Column('continent', String(30), nullable=False),
+ Column('city', String(50), nullable=False)
+ )
+
+ weather_reports = Table("weather_reports", meta,
+ Column('id', Integer, primary_key=True),
+ Column('location_id', Integer, ForeignKey('weather_locations.id')),
+ Column('temperature', Float),
+ Column('report_time', DateTime, default=datetime.datetime.now),
+ )
+
+ for db in (db1, db2, db3, db4):
+ meta.create_all(db)
+
+ db1.execute(ids.insert(), nextid=1)
+
+ self.setup_session()
+ self.setup_mappers()
+
+ def tearDownAll(self):
+ for i in range(1,5):
+ os.remove("shard%d.db" % i)
+
+ def setup_session(self):
+ global create_session
+
+ shard_lookup = {
+ 'North America':'north_america',
+ 'Asia':'asia',
+ 'Europe':'europe',
+ 'South America':'south_america'
+ }
+
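+        # shard_chooser routes flushes: a WeatherLocation maps by continent
+        # (e.g. Tokyo -> 'asia'), and any other instance follows its .location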
+ def shard_chooser(mapper, instance):
+ if isinstance(instance, WeatherLocation):
+ return shard_lookup[instance.continent]
+ else:
+ return shard_chooser(mapper, instance.location)
+
+ def id_chooser(ident):
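+            # a bare identity doesn't record its shard, so a get() may have to
+            # search all four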
+ return ['north_america', 'asia', 'europe', 'south_america']
+
+ def query_chooser(query):
+ ids = []
+
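+            # walk the query's criterion; any == or in_() comparison against
+            # weather_locations.c.continent narrows the list of shards to query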
+ class FindContinent(sql.ClauseVisitor):
+ def visit_binary(self, binary):
+ if binary.left is weather_locations.c.continent:
+ if binary.operator == operator.eq:
+ ids.append(shard_lookup[binary.right.value])
+ elif binary.operator == ColumnOperators.in_op:
+ for bind in binary.right.clauses:
+ ids.append(shard_lookup[bind.value])
+
+ FindContinent().traverse(query._criterion)
+ if len(ids) == 0:
+ return ['north_america', 'asia', 'europe', 'south_america']
+ else:
+ return ids
+
+ def create_session():
+ s = ShardedSession(shard_chooser, id_chooser, query_chooser)
+ s.bind_shard('north_america', db1)
+ s.bind_shard('asia', db2)
+ s.bind_shard('europe', db3)
+ s.bind_shard('south_america', db4)
+ return s
+
+ def setup_mappers(self):
+ global WeatherLocation, Report
+
+ class WeatherLocation(object):
+ def __init__(self, continent, city):
+ self.continent = continent
+ self.city = city
+
+ class Report(object):
+ def __init__(self, temperature):
+ self.temperature = temperature
+
+ mapper(WeatherLocation, weather_locations, properties={
+ 'reports':relation(Report, backref='location')
+ })
+
+ mapper(Report, weather_reports)
+
+ def test_roundtrip(self):
+ tokyo = WeatherLocation('Asia', 'Tokyo')
+ newyork = WeatherLocation('North America', 'New York')
+ toronto = WeatherLocation('North America', 'Toronto')
+ london = WeatherLocation('Europe', 'London')
+ dublin = WeatherLocation('Europe', 'Dublin')
+        brasilia = WeatherLocation('South America', 'Brasilia')
+ quito = WeatherLocation('South America', 'Quito')
+
+ tokyo.reports.append(Report(80.0))
+ newyork.reports.append(Report(75))
+ quito.reports.append(Report(85))
+
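+        # each location is flushed to the shard matching its continent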
+ sess = create_session()
+ for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:
+ sess.save(c)
+ sess.flush()
+
+ sess.clear()
+
+ t = sess.query(WeatherLocation).get(tokyo.id)
+ assert t.city == tokyo.city
+ assert t.reports[0].temperature == 80.0
+
+ north_american_cities = sess.query(WeatherLocation).filter(WeatherLocation.continent == 'North America')
+ assert set([c.city for c in north_american_cities]) == set(['New York', 'Toronto'])
+
+ asia_and_europe = sess.query(WeatherLocation).filter(WeatherLocation.continent.in_('Europe', 'Asia'))
+ assert set([c.city for c in asia_and_europe]) == set(['Tokyo', 'London', 'Dublin'])
+
+
+
+if __name__ == '__main__':
+ testbase.main()
+ \ No newline at end of file
diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py
index 6ba3f8c4b..ae626db84 100644
--- a/test/orm/unitofwork.py
+++ b/test/orm/unitofwork.py
@@ -1,13 +1,14 @@
-from testbase import PersistTest, AssertMixin
-from sqlalchemy import *
import testbase
import pickleable
+from sqlalchemy import *
+from sqlalchemy.orm import *
from sqlalchemy.orm.mapper import global_extensions
from sqlalchemy.orm import util as ormutil
from sqlalchemy.ext.sessioncontext import SessionContext
import sqlalchemy.ext.assignmapper as assignmapper
-from tables import *
-import tables
+from testlib import *
+from testlib.tables import *
+from testlib import tables
"""tests unitofwork operations"""
@@ -26,6 +27,7 @@ class UnitOfWorkTest(AssertMixin):
class HistoryTest(UnitOfWorkTest):
def setUpAll(self):
+ tables.metadata.bind = testbase.db
UnitOfWorkTest.setUpAll(self)
users.create()
addresses.create()
@@ -61,7 +63,7 @@ class VersioningTest(UnitOfWorkTest):
UnitOfWorkTest.setUpAll(self)
ctx.current.clear()
global version_table
- version_table = Table('version_test', db,
+ version_table = Table('version_test', MetaData(testbase.db),
Column('id', Integer, Sequence('version_test_seq'), primary_key=True ),
Column('version_id', Integer, nullable=False),
Column('value', String(40), nullable=False)
@@ -253,9 +255,9 @@ class MutableTypesTest(UnitOfWorkTest):
ctx.current.flush()
def go():
ctx.current.flush()
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
f1.value = unicode('someothervalue')
- self.assert_sql(db, lambda: ctx.current.flush(), [
+ self.assert_sql(testbase.db, lambda: ctx.current.flush(), [
(
"UPDATE mutabletest SET value=:value WHERE mutabletest.id = :mutabletest_id",
{'mutabletest_id': f1.id, 'value': u'someothervalue'}
@@ -263,7 +265,7 @@ class MutableTypesTest(UnitOfWorkTest):
])
f1.value = unicode('hi')
f1.data.x = 9
- self.assert_sql(db, lambda: ctx.current.flush(), [
+ self.assert_sql(testbase.db, lambda: ctx.current.flush(), [
(
"UPDATE mutabletest SET data=:data, value=:value WHERE mutabletest.id = :mutabletest_id",
{'mutabletest_id': f1.id, 'value': u'hi', 'data':f1.data}
@@ -281,7 +283,7 @@ class MutableTypesTest(UnitOfWorkTest):
def go():
ctx.current.flush()
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
ctx.current.clear()
@@ -289,12 +291,12 @@ class MutableTypesTest(UnitOfWorkTest):
def go():
ctx.current.flush()
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
f2.data.y = 19
def go():
ctx.current.flush()
- self.assert_sql_count(db, go, 1)
+ self.assert_sql_count(testbase.db, go, 1)
ctx.current.clear()
f3 = ctx.current.query(Foo).get_by(id=f1.id)
@@ -303,7 +305,7 @@ class MutableTypesTest(UnitOfWorkTest):
def go():
ctx.current.flush()
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
def testunicode(self):
"""test that two equivalent unicode values dont get flagged as changed.
@@ -320,47 +322,42 @@ class MutableTypesTest(UnitOfWorkTest):
f1.value = u'hi'
def go():
ctx.current.flush()
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
class PKTest(UnitOfWorkTest):
def setUpAll(self):
UnitOfWorkTest.setUpAll(self)
- global table
- global table2
- global table3
+ global table, table2, table3, metadata
+ metadata = MetaData(testbase.db)
table = Table(
- 'multipk', db,
+ 'multipk', metadata,
Column('multi_id', Integer, Sequence("multi_id_seq", optional=True), primary_key=True),
Column('multi_rev', Integer, primary_key=True),
Column('name', String(50), nullable=False),
Column('value', String(100))
)
- table2 = Table('multipk2', db,
+ table2 = Table('multipk2', metadata,
Column('pk_col_1', String(30), primary_key=True),
Column('pk_col_2', String(30), primary_key=True),
Column('data', String(30), )
)
- table3 = Table('multipk3', db,
+ table3 = Table('multipk3', metadata,
Column('pri_code', String(30), key='primary', primary_key=True),
Column('sec_code', String(30), key='secondary', primary_key=True),
Column('date_assigned', Date, key='assigned', primary_key=True),
Column('data', String(30), )
)
- table.create()
- table2.create()
- table3.create()
+ metadata.create_all()
def tearDownAll(self):
- table.drop()
- table2.drop()
- table3.drop()
+ metadata.drop_all()
UnitOfWorkTest.tearDownAll(self)
# not supported on sqlite since sqlite's auto-pk generation only works with
# single-column primary keys
- @testbase.unsupported('sqlite')
+ @testing.unsupported('sqlite')
def testprimarykey(self):
class Entry(object):
pass
@@ -448,7 +445,7 @@ class ForeignPKTest(UnitOfWorkTest):
},
)
- assert list(m2.props['sites'].foreign_keys) == [peoplesites.c.person]
+ assert list(m2.get_property('sites').foreign_keys) == [peoplesites.c.person]
p = Person()
p.person = 'im the key'
p.firstname = 'asdf'
@@ -466,7 +463,7 @@ class PassiveDeletesTest(UnitOfWorkTest):
mytable = Table('mytable', metadata,
Column('id', Integer, primary_key=True),
Column('data', String(30)),
- mysql_engine='InnoDB'
+ test_needs_fk=True,
)
myothertable = Table('myothertable', metadata,
@@ -474,7 +471,7 @@ class PassiveDeletesTest(UnitOfWorkTest):
Column('parent_id', Integer),
Column('data', String(30)),
ForeignKeyConstraint(['parent_id'],['mytable.id'], ondelete="CASCADE"),
- mysql_engine='InnoDB'
+ test_needs_fk=True,
)
metadata.create_all()
@@ -482,7 +479,7 @@ class PassiveDeletesTest(UnitOfWorkTest):
metadata.drop_all()
UnitOfWorkTest.tearDownAll(self)
- @testbase.unsupported('sqlite')
+ @testing.unsupported('sqlite')
def testbasic(self):
class MyClass(object):
pass
@@ -519,6 +516,7 @@ class DefaultTest(UnitOfWorkTest):
defaults back from the engine."""
def setUpAll(self):
UnitOfWorkTest.setUpAll(self)
+ db = testbase.db
use_string_defaults = db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle') or db.engine.__module__.endswith('sqlite')
if use_string_defaults:
@@ -529,21 +527,21 @@ class DefaultTest(UnitOfWorkTest):
hohotype = Integer
self.hohoval = 9
self.althohoval = 15
- self.table = Table('default_test', db,
+ global default_table
+ metadata = MetaData(db)
+ default_table = Table('default_test', metadata,
Column('id', Integer, Sequence("dt_seq", optional=True), primary_key=True),
Column('hoho', hohotype, PassiveDefault(str(self.hohoval))),
Column('counter', Integer, PassiveDefault("7")),
Column('foober', String(30), default="im foober", onupdate="im the update")
)
- self.table.create()
+ default_table.create()
def tearDownAll(self):
- self.table.drop()
+ default_table.drop()
UnitOfWorkTest.tearDownAll(self)
- def setUp(self):
- self.table = Table('default_test', db)
def testinsert(self):
class Hoho(object):pass
- assign_mapper(Hoho, self.table)
+ assign_mapper(Hoho, default_table)
h1 = Hoho(hoho=self.althohoval)
h2 = Hoho(counter=12)
h3 = Hoho(hoho=self.althohoval, counter=12)
@@ -571,7 +569,7 @@ class DefaultTest(UnitOfWorkTest):
def testinsertnopostfetch(self):
# populates the PassiveDefaults explicitly so there is no "post-update"
class Hoho(object):pass
- assign_mapper(Hoho, self.table)
+ assign_mapper(Hoho, default_table)
h1 = Hoho(hoho="15", counter="15")
ctx.current.flush()
self.assert_(h1.hoho=="15")
@@ -580,7 +578,7 @@ class DefaultTest(UnitOfWorkTest):
def testupdate(self):
class Hoho(object):pass
- assign_mapper(Hoho, self.table)
+ assign_mapper(Hoho, default_table)
h1 = Hoho()
ctx.current.flush()
self.assert_(h1.foober == 'im foober')
@@ -613,8 +611,7 @@ class OneToManyTest(UnitOfWorkTest):
a2 = Address()
a2.email_address = 'lala@test.org'
u.addresses.append(a2)
- self.echo( repr(u.addresses))
- self.echo( repr(u.addresses.added_items()))
+ print repr(u.addresses)
ctx.current.flush()
usertable = users.select(users.c.user_id.in_(u.user_id)).execute().fetchall()
@@ -664,7 +661,7 @@ class OneToManyTest(UnitOfWorkTest):
u2.user_name = 'user2modified'
u1.addresses.append(a3)
del u1.addresses[0]
- self.assert_sql(db, lambda: ctx.current.flush(),
+ self.assert_sql(testbase.db, lambda: ctx.current.flush(),
[
(
"UPDATE users SET user_name=:user_name WHERE users.user_id = :users_user_id",
@@ -836,7 +833,7 @@ class SaveTest(UnitOfWorkTest):
# assert the first one retrieves the same instance from the identity map
nu = ctx.current.get(m, u.user_id)
- self.echo( "U: " + repr(u) + "NU: " + repr(nu))
+ print "U: " + repr(u) + "NU: " + repr(nu)
self.assert_(u is nu)
# clear out the identity map, so next get forces a SELECT
@@ -917,7 +914,7 @@ class SaveTest(UnitOfWorkTest):
u.user_name = ""
def go():
ctx.current.flush()
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
def testmultitable(self):
"""tests a save of an object where each instance spans two tables. also tests
@@ -935,8 +932,7 @@ class SaveTest(UnitOfWorkTest):
u.email = 'multi@test.org'
ctx.current.flush()
- id = m.identity(u)
- print id
+ id = m.primary_key_from_instance(u)
ctx.current.clear()
@@ -1043,7 +1039,7 @@ class ManyToOneTest(UnitOfWorkTest):
objects[2].email_address = 'imnew@foo.bar'
objects[3].user = User()
objects[3].user.user_name = 'imnewlyadded'
- self.assert_sql(db, lambda: ctx.current.flush(), [
+ self.assert_sql(testbase.db, lambda: ctx.current.flush(), [
(
"INSERT INTO users (user_name) VALUES (:user_name)",
{'user_name': 'imnewlyadded'}
@@ -1213,7 +1209,7 @@ class ManyToManyTest(UnitOfWorkTest):
k = Keyword()
k.name = 'yellow'
objects[5].keywords.append(k)
- self.assert_sql(db, lambda:ctx.current.flush(), [
+ self.assert_sql(testbase.db, lambda:ctx.current.flush(), [
{
"UPDATE items SET item_name=:item_name WHERE items.item_id = :items_item_id":
{'item_name': 'item4updated', 'items_item_id': objects[4].item_id}
@@ -1242,7 +1238,7 @@ class ManyToManyTest(UnitOfWorkTest):
objects[2].keywords.append(k)
dkid = objects[5].keywords[1].keyword_id
del objects[5].keywords[1]
- self.assert_sql(db, lambda:ctx.current.flush(), [
+ self.assert_sql(testbase.db, lambda:ctx.current.flush(), [
(
"DELETE FROM itemkeywords WHERE itemkeywords.item_id = :item_id AND itemkeywords.keyword_id = :keyword_id",
[{'item_id': objects[5].item_id, 'keyword_id': dkid}]
@@ -1412,7 +1408,6 @@ class ManyToManyTest(UnitOfWorkTest):
k.user_name = 'keyworduser'
k.keyword_name = 'a keyword'
ctx.current.flush()
- print m.instance_key(k)
id = (k.user_id, k.keyword_id)
ctx.current.clear()
@@ -1427,7 +1422,7 @@ class SaveTest2(UnitOfWorkTest):
ctx.current.clear()
clear_mappers()
global meta, users, addresses
- meta = MetaData(db)
+ meta = MetaData(testbase.db)
users = Table('users', meta,
Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
Column('user_name', String(20)),
@@ -1459,7 +1454,7 @@ class SaveTest2(UnitOfWorkTest):
a.user = User()
a.user.user_name = elem['user_name']
objects.append(a)
- self.assert_sql(db, lambda: ctx.current.flush(), [
+ self.assert_sql(testbase.db, lambda: ctx.current.flush(), [
(
"INSERT INTO users (user_name) VALUES (:user_name)",
{'user_name': 'thesub'}
@@ -1498,30 +1493,32 @@ class SaveTest2(UnitOfWorkTest):
]
)
-class SaveTest3(UnitOfWorkTest):
+class SaveTest3(UnitOfWorkTest):
def setUpAll(self):
+ global st3_metadata, t1, t2, t3
+
UnitOfWorkTest.setUpAll(self)
- global metadata, t1, t2, t3
- metadata = testbase.metadata
- t1 = Table('items', metadata,
+
+ st3_metadata = MetaData(testbase.db)
+ t1 = Table('items', st3_metadata,
Column('item_id', INT, Sequence('items_id_seq', optional=True), primary_key = True),
Column('item_name', VARCHAR(50)),
)
- t3 = Table('keywords', metadata,
+ t3 = Table('keywords', st3_metadata,
Column('keyword_id', Integer, Sequence('keyword_id_seq', optional=True), primary_key = True),
Column('name', VARCHAR(50)),
)
- t2 = Table('assoc', metadata,
+ t2 = Table('assoc', st3_metadata,
Column('item_id', INT, ForeignKey("items")),
Column('keyword_id', INT, ForeignKey("keywords")),
Column('foo', Boolean, default=True)
)
- metadata.create_all()
+ st3_metadata.create_all()
def tearDownAll(self):
- metadata.drop_all()
+ st3_metadata.drop_all()
UnitOfWorkTest.tearDownAll(self)
def setUp(self):
diff --git a/test/perf/cascade_speed.py b/test/perf/cascade_speed.py
index d2e741442..34d046381 100644
--- a/test/perf/cascade_speed.py
+++ b/test/perf/cascade_speed.py
@@ -1,5 +1,7 @@
import testbase
from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
from timeit import Timer
import sys
diff --git a/test/perf/masscreate.py b/test/perf/masscreate.py
index e603e2c00..346a725e3 100644
--- a/test/perf/masscreate.py
+++ b/test/perf/masscreate.py
@@ -1,8 +1,7 @@
# times how long it takes to create 26000 objects
-import sys
-sys.path.insert(0, './lib/')
+import testbase
-from sqlalchemy.attributes import *
+from sqlalchemy.orm.attributes import *
import time
import gc
diff --git a/test/perf/masscreate2.py b/test/perf/masscreate2.py
index 3a68f3612..2e29a6327 100644
--- a/test/perf/masscreate2.py
+++ b/test/perf/masscreate2.py
@@ -1,11 +1,9 @@
-import sys
-sys.path.insert(0, './lib/')
-
+import testbase
import gc
import random, string
-from sqlalchemy.attributes import *
+from sqlalchemy.orm.attributes import *
# with this test, run top. make sure the Python process doesn't grow in size arbitrarily.
diff --git a/test/perf/masseagerload.py b/test/perf/masseagerload.py
index 9d77fed54..f1c0f292b 100644
--- a/test/perf/masseagerload.py
+++ b/test/perf/masseagerload.py
@@ -1,62 +1,54 @@
-from testbase import PersistTest, AssertMixin
-import unittest, sys, os
-from sqlalchemy import *
-import StringIO
import testbase
-import gc
-import time
-import hotshot
-import hotshot.stats
-
-db = testbase.db
+import hotshot, hotshot.stats
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
NUM = 500
DIVISOR = 50
-class LoadTest(AssertMixin):
- def setUpAll(self):
- global items, meta,subitems
- meta = MetaData(db)
- items = Table('items', meta,
- Column('item_id', Integer, primary_key=True),
- Column('value', String(100)))
- subitems = Table('subitems', meta,
- Column('sub_id', Integer, primary_key=True),
- Column('parent_id', Integer, ForeignKey('items.item_id')),
- Column('value', String(100)))
- meta.create_all()
- def tearDownAll(self):
- meta.drop_all()
- def setUp(self):
- clear_mappers()
+meta = MetaData(testbase.db)
+items = Table('items', meta,
+ Column('item_id', Integer, primary_key=True),
+ Column('value', String(100)))
+subitems = Table('subitems', meta,
+ Column('sub_id', Integer, primary_key=True),
+ Column('parent_id', Integer, ForeignKey('items.item_id')),
+ Column('value', String(100)))
+
+class Item(object):pass
+class SubItem(object):pass
+mapper(Item, items, properties={'subs':relation(SubItem, lazy=False)})
+mapper(SubItem, subitems)
+
+def load():
+ global l
+ l = []
+ for x in range(1,NUM/DIVISOR + 1):
+ l.append({'item_id':x, 'value':'this is item #%d' % x})
+ #print l
+ items.insert().execute(*l)
+ for x in range(1, NUM/DIVISOR + 1):
l = []
- for x in range(1,NUM/DIVISOR + 1):
- l.append({'item_id':x, 'value':'this is item #%d' % x})
+ for y in range(1, DIVISOR + 1):
+ z = ((x-1) * DIVISOR) + y
+ l.append({'sub_id':z,'value':'this is item #%d' % z, 'parent_id':x})
#print l
- items.insert().execute(*l)
- for x in range(1, NUM/DIVISOR + 1):
- l = []
- for y in range(1, DIVISOR + 1):
- z = ((x-1) * DIVISOR) + y
- l.append({'sub_id':z,'value':'this is iteim #%d' % z, 'parent_id':x})
- #print l
- subitems.insert().execute(*l)
- def testload(self):
- class Item(object):pass
- class SubItem(object):pass
- mapper(Item, items, properties={'subs':relation(SubItem, lazy=False)})
- mapper(SubItem, subitems)
- sess = create_session()
- prof = hotshot.Profile("masseagerload.prof")
- prof.start()
- query = sess.query(Item)
- l = query.select()
- print "loaded ", len(l), " items each with ", len(l[0].subs), "subitems"
- prof.stop()
- prof.close()
- stats = hotshot.stats.load("masseagerload.prof")
- stats.sort_stats('time', 'calls')
- stats.print_stats()
-
-if __name__ == "__main__":
- testbase.main()
+ subitems.insert().execute(*l)
+
+@profiling.profiled('masseagerload', always=True)
+def masseagerload(session):
+ query = session.query(Item)
+ l = query.select()
+ print "loaded ", len(l), " items each with ", len(l[0].subs), "subitems"
+
+def all():
+ meta.create_all()
+ try:
+ load()
+ masseagerload(create_session())
+ finally:
+ meta.drop_all()
+
+if __name__ == '__main__':
+ all()
diff --git a/test/perf/massload.py b/test/perf/massload.py
index 3530e4a65..92cf0fe92 100644
--- a/test/perf/massload.py
+++ b/test/perf/massload.py
@@ -1,13 +1,10 @@
-from testbase import PersistTest, AssertMixin
-import unittest, sys, os
-from sqlalchemy import *
-import sqlalchemy.orm.attributes as attributes
-import StringIO
import testbase
-import gc
import time
-
-db = testbase.db
+#import gc
+#import sqlalchemy.orm.attributes as attributes
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
NUM = 2500
@@ -20,7 +17,7 @@ for best results, dont run with sqlite :memory: database, and keep an eye on top
class LoadTest(AssertMixin):
def setUpAll(self):
global items, meta
- meta = MetaData(db)
+ meta = MetaData(testbase.db)
items = Table('items', meta,
Column('item_id', Integer, primary_key=True),
Column('value', String(100)))
@@ -28,8 +25,6 @@ class LoadTest(AssertMixin):
def tearDownAll(self):
items.drop()
def setUp(self):
- objectstore.clear()
- clear_mappers()
for x in range(1,NUM/500+1):
l = []
for y in range(x*500-500 + 1, x*500 + 1):
diff --git a/test/perf/massload2.py b/test/perf/massload2.py
index 1506ca503..d6424eb07 100644
--- a/test/perf/massload2.py
+++ b/test/perf/massload2.py
@@ -7,6 +7,7 @@ try:
except:
pass
from sqlalchemy import *
+from testbase import Table, Column
import time
metadata = create_engine('sqlite://', echo=True)
diff --git a/test/perf/masssave.py b/test/perf/masssave.py
index 5690eac3f..dd03f3962 100644
--- a/test/perf/masssave.py
+++ b/test/perf/masssave.py
@@ -1,20 +1,16 @@
-from testbase import PersistTest, AssertMixin
-import unittest, sys, os
-from sqlalchemy import *
-import sqlalchemy.attributes as attributes
-import StringIO
import testbase
-import gc
-import sqlalchemy.orm.session
import types
-db = testbase.db
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+
NUM = 250000
class SaveTest(AssertMixin):
def setUpAll(self):
global items, metadata
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
items = Table('items', metadata,
Column('item_id', Integer, primary_key=True),
Column('value', String(100)))
diff --git a/test/perf/ormsession.py b/test/perf/ormsession.py
new file mode 100644
index 000000000..a9d310ef6
--- /dev/null
+++ b/test/perf/ormsession.py
@@ -0,0 +1,225 @@
+import testbase
+import time
+from datetime import datetime
+
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+from testlib.profiling import profiled
+
+class Item(object):
+ def __repr__(self):
+ return 'Item<#%s "%s">' % (self.id, self.name)
+class SubItem(object):
+ def __repr__(self):
+ return 'SubItem<#%s "%s">' % (self.id, self.name)
+class Customer(object):
+ def __repr__(self):
+ return 'Customer<#%s "%s">' % (self.id, self.name)
+class Purchase(object):
+ def __repr__(self):
+ return 'Purchase<#%s "%s">' % (self.id, self.purchase_date)
+
+items, subitems, customers, purchases, purchaseitems = \
+ None, None, None, None, None
+
+metadata = MetaData()
+
+@profiled('table')
+def define_tables():
+ global items, subitems, customers, purchases, purchaseitems
+ items = Table('items', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String(100)),
+ test_needs_acid=True)
+ subitems = Table('subitems', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('item_id', Integer, ForeignKey('items.id'),
+ nullable=False),
+ Column('name', String(100), PassiveDefault('no name')),
+ test_needs_acid=True)
+ customers = Table('customers', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String(100)),
+ *[Column("col_%s" % chr(i), String(64), default=str(i))
+ for i in range(97,117)],
+ **dict(test_needs_acid=True))
+ purchases = Table('purchases', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('customer_id', Integer,
+ ForeignKey('customers.id'), nullable=False),
+ Column('purchase_date', DateTime,
+ default=datetime.now),
+ test_needs_acid=True)
+ purchaseitems = Table('purchaseitems', metadata,
+ Column('purchase_id', Integer,
+ ForeignKey('purchases.id'),
+ nullable=False, primary_key=True),
+ Column('item_id', Integer, ForeignKey('items.id'),
+ nullable=False, primary_key=True),
+ test_needs_acid=True)
+
+@profiled('mapper')
+def setup_mappers():
+ mapper(Item, items, properties={
+ 'subitems': relation(SubItem, backref='item', lazy=True)
+ })
+ mapper(SubItem, subitems)
+ mapper(Customer, customers, properties={
+ 'purchases': relation(Purchase, lazy=True, backref='customer')
+ })
+ mapper(Purchase, purchases, properties={
+ 'items': relation(Item, lazy=True, secondary=purchaseitems)
+ })
+
+@profiled('inserts')
+def insert_data():
+ q_items = 1000
+ q_sub_per_item = 10
+ q_customers = 1000
+
+ con = testbase.db.connect()
+
+ transaction = con.begin()
+ data, subdata = [], []
+ for item_id in xrange(1, q_items + 1):
+ data.append({'name': "item number %s" % item_id})
+ for subitem_id in xrange(1, (item_id % q_sub_per_item) + 1):
+ subdata.append({'item_id': item_id,
+ 'name': "subitem number %s" % subitem_id})
+ if item_id % 100 == 0:
+ items.insert().execute(*data)
+ subitems.insert().execute(*subdata)
+ del data[:]
+ del subdata[:]
+ if data:
+ items.insert().execute(*data)
+ if subdata:
+ subitems.insert().execute(*subdata)
+ transaction.commit()
+
+ transaction = con.begin()
+ data = []
+ for customer_id in xrange(1, q_customers):
+ data.append({'name': "customer number %s" % customer_id})
+ if customer_id % 100 == 0:
+ customers.insert().execute(*data)
+ del data[:]
+ if data:
+ customers.insert().execute(*data)
+ transaction.commit()
+
+ transaction = con.begin()
+ data, subdata = [], []
+ order_t = int(time.time()) - (5000 * 5 * 60)
+ current = xrange(1, q_customers)
+ step, purchase_id = 1, 0
+ while current:
+ next = []
+ for customer_id in current:
+ order_t += 300
+ data.append({'customer_id': customer_id,
+ 'purchase_date': datetime.fromtimestamp(order_t)})
+ purchase_id += 1
+ for item_id in range(customer_id % 200, customer_id + 1, 200):
+ if item_id != 0:
+ subdata.append({'purchase_id': purchase_id,
+ 'item_id': item_id})
+ if customer_id % 10 > step:
+ next.append(customer_id)
+
+ if len(data) >= 100:
+ purchases.insert().execute(*data)
+ if subdata:
+ purchaseitems.insert().execute(*subdata)
+ del data[:]
+ del subdata[:]
+ step, current = step + 1, next
+
+ if data:
+ purchases.insert().execute(*data)
+ if subdata:
+ purchaseitems.insert().execute(*subdata)
+ transaction.commit()
+
+@profiled('queries')
+def run_queries():
+ session = create_session()
+ # no explicit transaction here.
+
+ # build a report of summarizing the last 50 purchases and
+ # the top 20 items from all purchases
+
+ q = session.query(Purchase). \
+ limit(50).order_by(desc(Purchase.purchase_date)). \
+ options(eagerload('items'), eagerload('items.subitems'),
+ eagerload('customer'))
+
+ report = []
+ # "write" the report. pretend it's going to a web template or something,
+ # the point is to actually pull data through attributes and collections.
+ for purchase in q:
+ report.append(purchase.customer.name)
+ report.append(purchase.customer.col_a)
+ report.append(purchase.purchase_date)
+ for item in purchase.items:
+ report.append(item.name)
+ report.extend([s.name for s in item.subitems])
+
+ # mix a little low-level with orm
+ # pull a report of the top 20 items of all time
+ _item_id = purchaseitems.c.item_id
+ top_20_q = select([func.distinct(_item_id).label('id')],
+ group_by=[purchaseitems.c.purchase_id, _item_id],
+ order_by=[desc(func.count(_item_id)), _item_id],
+ limit=20)
+ ids = [r.id for r in top_20_q.execute().fetchall()]
+ q2 = session.query(Item).filter(Item.id.in_(*ids))
+
+ for num, item in enumerate(q2):
+ report.append("number %s: %s" % (num + 1, item.name))
+
+@profiled('creating')
+def create_purchase():
+ # commit a purchase
+ customer_id = 100
+ item_ids = (10,22,34,46,58)
+
+ session = create_session()
+ session.begin()
+
+ customer = session.query(Customer).get(customer_id)
+ items = session.query(Item).filter(Item.id.in_(*item_ids))
+
+ purchase = Purchase()
+ purchase.customer = customer
+ purchase.items.extend(items)
+
+ session.flush()
+ session.commit()
+ session.expire(customer)
+
+def setup_db():
+ metadata.drop_all()
+ metadata.create_all()
+def cleanup_db():
+ metadata.drop_all()
+
+@profiled('default')
+def default():
+ run_queries()
+ create_purchase()
+
+@profiled('all')
+def main():
+ metadata.bind = testbase.db
+ try:
+ define_tables()
+ setup_mappers()
+ setup_db()
+ insert_data()
+ default()
+ finally:
+ cleanup_db()
+
+if __name__ == '__main__':
+    main()
diff --git a/test/perf/poolload.py b/test/perf/poolload.py
index d096f1c67..1a2ff6978 100644
--- a/test/perf/poolload.py
+++ b/test/perf/poolload.py
@@ -1,10 +1,11 @@
# load test of connection pool
+import testbase
from sqlalchemy import *
import sqlalchemy.pool as pool
import thread,time
-db = create_engine('mysql://scott:tiger@127.0.0.1/test', pool_timeout=30, echo_pool=True)
+db = create_engine(testbase.db.url, pool_timeout=30, echo_pool=True)
metadata = MetaData(db)
users_table = Table('users', metadata,
@@ -18,7 +19,7 @@ users_table.insert().execute([{'user_name':'user#%d' % i, 'password':'pw#%d' % i
def runfast():
while True:
- c = db.connection_provider._pool.connect()
+ c = db.pool.connect()
time.sleep(.5)
c.close()
# result = users_table.select(limit=100).execute()
diff --git a/test/perf/threaded_compile.py b/test/perf/threaded_compile.py
index eb9e2f669..13ec31fd6 100644
--- a/test/perf/threaded_compile.py
+++ b/test/perf/threaded_compile.py
@@ -2,9 +2,12 @@
when additional mappers are created while the existing
collection is being compiled."""
+import testbase
from sqlalchemy import *
+from sqlalchemy.orm import *
import thread, time
from sqlalchemy.orm import mapperlib
+from testlib import *
meta = MetaData('sqlite:///foo.db')
diff --git a/test/perf/wsgi.py b/test/perf/wsgi.py
index 365956dc7..d22eeb76a 100644
--- a/test/perf/wsgi.py
+++ b/test/perf/wsgi.py
@@ -1,53 +1,55 @@
#!/usr/bin/python
+"""Uses ``wsgiref``, standard in Python 2.5 and also in the cheeseshop."""
+import testbase
from sqlalchemy import *
-import sqlalchemy.pool as pool
+from sqlalchemy.orm import *
import thread
-from sqlalchemy import exceptions
+from testlib import *
+
+port = 8000
import logging
logging.basicConfig()
logging.getLogger('sqlalchemy.pool').setLevel(logging.INFO)
threadids = set()
-#meta = MetaData('postgres://scott:tiger@127.0.0.1/test')
-
-#meta = MetaData('mysql://scott:tiger@localhost/test', poolclass=pool.SingletonThreadPool)
-meta = MetaData('mysql://scott:tiger@localhost/test')
+meta = MetaData(testbase.db)
foo = Table('foo', meta,
Column('id', Integer, primary_key=True),
Column('data', String(30)))
-
-meta.drop_all()
-meta.create_all()
-
-data = []
-for x in range(1,500):
- data.append({'id':x,'data':"this is x value %d" % x})
-foo.insert().execute(data)
-
class Foo(object):
pass
-
mapper(Foo, foo)
-root = './'
-port = 8000
+def prep():
+ meta.drop_all()
+ meta.create_all()
+
+ data = []
+ for x in range(1,500):
+ data.append({'id':x,'data':"this is x value %d" % x})
+ foo.insert().execute(data)
def serve(environ, start_response):
+ start_response("200 OK", [('Content-type', 'text/plain')])
sess = create_session()
l = sess.query(Foo).select()
-
- start_response("200 OK", [('Content-type','text/plain')])
threadids.add(thread.get_ident())
- print "sending response on thread", thread.get_ident(), " total threads ", len(threadids)
- return ["\n".join([x.data for x in l])]
+
+ print ("sending response on thread", thread.get_ident(),
+ " total threads ", len(threadids))
+ return [str("\n".join([x.data for x in l]))]
if __name__ == '__main__':
- from wsgiutils import wsgiServer
- server = wsgiServer.WSGIServer (('localhost', port), {'/': serve})
- print "Server listening on port %d" % port
- server.serve_forever()
+ from wsgiref import simple_server
+ try:
+ prep()
+ server = simple_server.make_server('localhost', port, serve)
+ print "Server listening on port %d" % port
+ server.serve_forever()
+ finally:
+ meta.drop_all()
diff --git a/test/rundocs.py b/test/rundocs.py
deleted file mode 100644
index 1918e55be..000000000
--- a/test/rundocs.py
+++ /dev/null
@@ -1,242 +0,0 @@
-from sqlalchemy import *
-import sys
-sys.path.insert(0, './lib/')
-
-engine = create_engine('sqlite://')
-
-engine.echo = True
-
-# table metadata
-users = Table('users', engine,
- Column('user_id', Integer, primary_key = True),
- Column('user_name', String(16), nullable = False),
- Column('password', String(20), nullable = False)
-)
-users.create()
-users.insert().execute(
- dict(user_name = 'fred', password='45nfss')
-)
-
-
-# class definition
-class User(object):
- pass
-assign_mapper(User, users)
-
-# select
-user = User.get_by(user_name = 'fred')
-
-# modify
-user.user_name = 'fred jones'
-
-# commit
-objectstore.commit()
-
-objectstore.clear()
-
-
-
-addresses = Table('email_addresses', engine,
- Column('address_id', Integer, primary_key = True),
- Column('user_id', Integer, ForeignKey(users.c.user_id)),
- Column('email_address', String(20)),
-)
-addresses.create()
-addresses.insert().execute(
- dict(user_id = user.user_id, email_address='fred@bar.com')
-)
-
-# second class definition
-class Address(object):
- def __init__(self, email_address = None):
- self.email_address = email_address
-
- mapper = assignmapper(addresses)
-
-# obtain a Mapper. "private=True" means deletions of the user
-# will cascade down to the child Address objects
-User.mapper = assignmapper(users, properties = dict(
- addresses = relation(Address.mapper, lazy=True, private=True)
-))
-
-# select
-user = User.mapper.select(User.c.user_name == 'fred jones')[0]
-address = user.addresses[0]
-
-# modify
-user.user_name = 'fred'
-user.addresses[0].email_address = 'fredjones@foo.com'
-user.addresses.append(Address('freddy@hi.org'))
-
-# commit
-objectstore.commit()
-
-# going to change tables, etc., start over with a new engine
-objectstore.clear()
-engine = None
-engine = sqlite.engine(':memory:', {})
-engine.echo = True
-
-# a table to store a user's preferences for a site
-prefs = Table('user_prefs', engine,
- Column('pref_id', Integer, primary_key = True),
- Column('stylename', String(20)),
- Column('save_password', Boolean, nullable = False),
- Column('timezone', CHAR(3), nullable = False)
-)
-prefs.create()
-prefs.insert().execute(
- dict(pref_id=1, stylename='green', save_password=1, timezone='EST')
-)
-
-# user table gets 'preference_id' column added
-users = Table('users', engine,
- Column('user_id', Integer, primary_key = True),
- Column('user_name', String(16), nullable = False),
- Column('password', String(20), nullable = False),
- Column('preference_id', Integer, ForeignKey(prefs.c.pref_id))
-)
-users.drop()
-users.create()
-users.insert().execute(
- dict(user_name = 'fred', password='45nfss', preference_id=1)
-)
-
-
-addresses = Table('email_addresses', engine,
- Column('address_id', Integer, primary_key = True),
- Column('user_id', Integer, ForeignKey(users.c.user_id)),
- Column('email_address', String(20)),
-)
-addresses.drop()
-addresses.create()
-
-Address.mapper = assignmapper(addresses)
-
-# class definition for preferences
-class UserPrefs(object):
- mapper = assignmapper(prefs)
-
-# set a new Mapper on the user
-User.mapper = assignmapper(users, properties = dict(
- addresses = relation(Address.mapper, lazy=True, private=True),
- preferences = relation(UserPrefs.mapper, lazy=False, private=True),
-))
-
-# select
-user = User.mapper.select(User.c.user_name == 'fred')[0]
-save_password = user.preferences.save_password
-
-# modify
-user.preferences.stylename = 'bluesteel'
-user.addresses.append(Address('freddy@hi.org'))
-
-# commit
-objectstore.commit()
-
-
-
-articles = Table('articles', engine,
- Column('article_id', Integer, primary_key = True),
- Column('article_headline', String(150), key='headline'),
- Column('article_body', CLOB, key='body'),
-)
-
-keywords = Table('keywords', engine,
- Column('keyword_id', Integer, primary_key = True),
- Column('name', String(50))
-)
-
-itemkeywords = Table('article_keywords', engine,
- Column('article_id', Integer, ForeignKey(articles.c.article_id)),
- Column('keyword_id', Integer, ForeignKey(keywords.c.keyword_id))
-)
-
-articles.create()
-keywords.create()
-itemkeywords.create()
-
-# class definitions
-class Keyword(object):
- def __init__(self, name = None):
- self.name = name
- mapper = assignmapper(keywords)
-
-class Article(object):
- def __init__(self):
- self.keywords = []
- mapper = assignmapper(articles, properties = dict(
- keywords = relation(Keyword.mapper, itemkeywords, lazy=False)
- ))
-Article.mapper
-
-article = Article()
-article.headline = 'a headline'
-article.body = 'this is the body'
-article.keywords.append(Keyword('politics'))
-article.keywords.append(Keyword('entertainment'))
-objectstore.commit()
-
-# select articles based on some keywords. the extra selection criterion
-# won't get in the way of the separate eager load of all the article's keywords
-alist = Article.mapper.select(sql.and_(
- keywords.c.keyword_id==itemkeywords.c.keyword_id,
- itemkeywords.c.article_id==articles.c.article_id,
- keywords.c.name.in_('politics', 'entertainment')))
-
-# modify
-a = alist[0]
-del a.keywords[:]
-a.keywords.append(Keyword('topstories'))
-a.keywords.append(Keyword('government'))
-
-# commit. individual INSERT/DELETE operations will take place only for the list
-# elements that changed.
-objectstore.commit()
-
-
-clear_mappers()
-itemkeywords.drop()
-itemkeywords = Table('article_keywords', engine,
- Column('article_id', Integer, ForeignKey("articles.article_id")),
- Column('keyword_id', Integer, ForeignKey("keywords.keyword_id")),
- Column('attached_by', Integer, ForeignKey("users.user_id"))
-, redefine=True)
-itemkeywords.create()
-
-# define an association class
-class KeywordAssociation(object):pass
-
-# define the mapper. when we load an article, we always want to get the keywords via
-# eager loading. but the user who added each keyword, we usually dont need so specify
-# lazy loading for that.
-m = mapper(Article, articles, properties=dict(
- keywords = relation(KeywordAssociation, itemkeywords, lazy = False,
- primary_key=[itemkeywords.c.article_id, itemkeywords.c.keyword_id],
- properties=dict(
- keyword = relation(Keyword, keywords, lazy = False),
- user = relation(User, users, lazy = True)
- )
- )
- )
-)
-
-# bonus step - well, we do want to load the users in one shot,
-# so modify the mapper via an option.
-# this returns a new mapper with the option switched on.
-m2 = m.options(eagerload('keywords.user'))
-
-# select by keyword again
-alist = m2.select(
- sql.and_(
- keywords.c.keyword_id==itemkeywords.c.keyword_id,
- itemkeywords.c.article_id==articles.c.article_id,
- keywords.c.name == 'jacks_stories'
- ))
-
-# user is available
-for a in alist:
- for k in a.keywords:
- if k.keyword.name == 'jacks_stories':
- print k.user.user_name
-
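rundocs.py exercised the retired `assign_mapper`/`objectstore` API and goes away with it. Under 0.4 the same save-and-flush flow runs through `mapper()` plus an explicit session; a hedged sketch of the equivalent, using a trimmed version of the deleted script's users table:

    # Sketch: 0.4-style replacement for the objectstore workflow deleted above.
    from sqlalchemy import *
    from sqlalchemy.orm import mapper, create_session

    metadata = MetaData('sqlite://')
    users = Table('users', metadata,
        Column('user_id', Integer, primary_key=True),
        Column('user_name', String(16), nullable=False))
    metadata.create_all()

    class User(object):
        pass
    mapper(User, users)

    sess = create_session()
    user = User()
    user.user_name = 'fred'
    sess.save(user)    # 0.4 spelling of what objectstore tracked implicitly
    sess.flush()       # replaces objectstore.commit()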
diff --git a/test/sql/alltests.py b/test/sql/alltests.py
index 7be1a3ffb..a669a25f2 100644
--- a/test/sql/alltests.py
+++ b/test/sql/alltests.py
@@ -7,6 +7,8 @@ def suite():
'sql.testtypes',
'sql.constraints',
+ 'sql.generative',
+
# SQL syntax
'sql.select',
'sql.selectable',
@@ -30,7 +32,5 @@ def suite():
alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
return alltests
-
-
if __name__ == '__main__':
testbase.main(suite())
diff --git a/test/sql/case_statement.py b/test/sql/case_statement.py
index 946279b9d..493545b22 100644
--- a/test/sql/case_statement.py
+++ b/test/sql/case_statement.py
@@ -1,13 +1,15 @@
-import sys
import testbase
+import sys
from sqlalchemy import *
+from testlib import *
-class CaseTest(testbase.PersistTest):
+class CaseTest(PersistTest):
def setUpAll(self):
+ metadata = MetaData(testbase.db)
global info_table
- info_table = Table('infos', testbase.db,
+ info_table = Table('infos', metadata,
Column('pk', Integer, primary_key=True),
Column('info', String(30)))
@@ -26,9 +28,9 @@ class CaseTest(testbase.PersistTest):
def testcase(self):
inner = select([case([
[info_table.c.pk < 3,
- literal('lessthan3', type=String)],
+ literal('lessthan3', type_=String)],
[and_(info_table.c.pk >= 3, info_table.c.pk < 7),
- literal('gt3', type=String)]]).label('x'),
+ literal('gt3', type_=String)]]).label('x'),
info_table.c.pk, info_table.c.info],
from_obj=[info_table]).alias('q_inner')
@@ -65,9 +67,9 @@ class CaseTest(testbase.PersistTest):
w_else = select([case([
[info_table.c.pk < 3,
- literal(3, type=Integer)],
+ literal(3, type_=Integer)],
[and_(info_table.c.pk >= 3, info_table.c.pk < 6),
- literal(6, type=Integer)]],
+ literal(6, type_=Integer)]],
else_ = 0).label('x'),
info_table.c.pk, info_table.c.info],
from_obj=[info_table]).alias('q_inner')
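The `type=` to `type_=` rename shown here for `literal()` recurs for `func.*()` and `column()` throughout this patch. Standalone, assuming the 0.4 expression package:

    # Sketch: the renamed type_ keyword on literal() inside a case().
    from sqlalchemy import *

    expr = case([
        [column('pk') < 3, literal('lessthan3', type_=String)],
        ]).label('x')
    # str(expr) renders a CASE WHEN ... THEN ... END expression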
diff --git a/test/sql/constraints.py b/test/sql/constraints.py
index 7e1172850..3120185d5 100644
--- a/test/sql/constraints.py
+++ b/test/sql/constraints.py
@@ -1,8 +1,8 @@
import testbase
from sqlalchemy import *
-import sys
+from testlib import *
-class ConstraintTest(testbase.AssertMixin):
+class ConstraintTest(AssertMixin):
def setUp(self):
global metadata
@@ -52,7 +52,7 @@ class ConstraintTest(testbase.AssertMixin):
)
metadata.create_all()
- @testbase.unsupported('mysql')
+ @testing.unsupported('mysql')
def test_check_constraint(self):
foo = Table('foo', metadata,
Column('id', Integer, primary_key=True),
@@ -172,12 +172,13 @@ class ConstraintTest(testbase.AssertMixin):
capt = []
connection = testbase.db.connect()
- ex = connection._execute
+        # TODO: hacky; put a real connection proxy in place
+ ex = connection._Connection__execute
def proxy(context):
capt.append(context.statement)
capt.append(repr(context.parameters))
ex(context)
- connection._execute = proxy
+ connection._Connection__execute = proxy
schemagen = testbase.db.dialect.schemagenerator(connection)
schemagen.traverse(events)
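Every module in this patch swaps `@testbase.supported`/`@testbase.unsupported` for the `testing` namespace pulled in by `from testlib import *`. testlib's real implementation is not part of this diff; a purely hypothetical sketch of what such a skip guard could look like:

    # Hypothetical sketch only; not testlib's actual code.
    import testbase

    def unsupported(*dbs):
        """Skip the decorated test when running against a listed database."""
        def decorate(fn):
            def maybe(*args, **kw):
                if testbase.db.name in dbs:
                    print "skipping %s on %s" % (fn.__name__, testbase.db.name)
                    return
                return fn(*args, **kw)
            maybe.__name__ = fn.__name__
            return maybe
        return decorate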
diff --git a/test/sql/defaults.py b/test/sql/defaults.py
index 10a3610f9..6c200232f 100644
--- a/test/sql/defaults.py
+++ b/test/sql/defaults.py
@@ -1,52 +1,59 @@
-from testbase import PersistTest
-import sqlalchemy.util as util
-import unittest, sys, os
-import sqlalchemy.schema as schema
import testbase
from sqlalchemy import *
-import sqlalchemy
-
-db = testbase.db
+import sqlalchemy.util as util
+import sqlalchemy.schema as schema
+from sqlalchemy.orm import mapper, create_session
+from testlib import *
+import datetime
class DefaultTest(PersistTest):
def setUpAll(self):
- global t, f, f2, ts, currenttime
+ global t, f, f2, ts, currenttime, metadata
+
+ db = testbase.db
+ metadata = MetaData(db)
x = {'x':50}
def mydefault():
x['x'] += 1
return x['x']
+ def mydefault_with_ctx(ctx):
+ return ctx.compiled_parameters['col1'] + 10
+
+ def myupdate_with_ctx(ctx):
+ return len(ctx.compiled_parameters['col2'])
+
use_function_defaults = db.engine.name == 'postgres' or db.engine.name == 'oracle'
is_oracle = db.engine.name == 'oracle'
# select "count(1)" returns different results on different DBs
# also correct for "current_date" compatible as column default, value differences
- currenttime = func.current_date(type=Date, engine=db);
+ currenttime = func.current_date(type_=Date, bind=db);
if is_oracle:
ts = db.func.trunc(func.sysdate(), literal_column("'DAY'")).scalar()
- f = select([func.count(1) + 5], engine=db).scalar()
- f2 = select([func.count(1) + 14], engine=db).scalar()
+ f = select([func.count(1) + 5], bind=db).scalar()
+ f2 = select([func.count(1) + 14], bind=db).scalar()
# TODO: engine propagation across nested functions not working
- currenttime = func.trunc(currenttime, literal_column("'DAY'"), engine=db)
+ currenttime = func.trunc(currenttime, literal_column("'DAY'"), bind=db)
def1 = currenttime
def2 = func.trunc(text("sysdate"), literal_column("'DAY'"))
deftype = Date
elif use_function_defaults:
- f = select([func.count(1) + 5], engine=db).scalar()
- f2 = select([func.count(1) + 14], engine=db).scalar()
+ f = select([func.count(1) + 5], bind=db).scalar()
+ f2 = select([func.count(1) + 14], bind=db).scalar()
def1 = currenttime
def2 = text("current_date")
deftype = Date
ts = db.func.current_date().scalar()
else:
- f = select([func.count(1) + 5], engine=db).scalar()
- f2 = select([func.count(1) + 14], engine=db).scalar()
+ f = select([func.count(1) + 5], bind=db).scalar()
+ f2 = select([func.count(1) + 14], bind=db).scalar()
def1 = def2 = "3"
ts = 3
deftype = Integer
- t = Table('default_test1', db,
+ t = Table('default_test1', metadata,
# python function
Column('col1', Integer, primary_key=True, default=mydefault),
@@ -66,7 +73,13 @@ class DefaultTest(PersistTest):
Column('col6', Date, default=currenttime, onupdate=currenttime),
Column('boolcol1', Boolean, default=True),
- Column('boolcol2', Boolean, default=False)
+ Column('boolcol2', Boolean, default=False),
+
+ # python function which uses ExecutionContext
+ Column('col7', Integer, default=mydefault_with_ctx, onupdate=myupdate_with_ctx),
+
+ # python builtin
+ Column('col8', Date, default=datetime.date.today, onupdate=datetime.date.today)
)
t.create()
@@ -75,9 +88,18 @@ class DefaultTest(PersistTest):
def tearDown(self):
t.delete().execute()
-
+
+ def testargsignature(self):
+ def mydefault(x, y):
+ pass
+ try:
+ c = ColumnDefault(mydefault)
+ assert False
+ except exceptions.ArgumentError, e:
+ assert str(e) == "ColumnDefault Python function takes zero or one positional arguments", str(e)
+
def teststandalone(self):
- c = db.engine.contextual_connect()
+ c = testbase.db.engine.contextual_connect()
x = c.execute(t.c.col1.default)
y = t.c.col2.default.execute()
z = c.execute(t.c.col3.default)
@@ -94,9 +116,10 @@ class DefaultTest(PersistTest):
t.insert().execute()
ctexec = currenttime.scalar()
- self.echo("Currenttime "+ repr(ctexec))
+ print "Currenttime "+ repr(ctexec)
l = t.select().execute()
- self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False), (52, 'imthedefault', f, ts, ts, ctexec, True, False), (53, 'imthedefault', f, ts, ts, ctexec, True, False)])
+ today = datetime.date.today()
+ self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False, 61, today), (52, 'imthedefault', f, ts, ts, ctexec, True, False, 62, today), (53, 'imthedefault', f, ts, ts, ctexec, True, False, 63, today)])
def testinsertvalues(self):
t.insert(values={'col3':50}).execute()
@@ -109,10 +132,10 @@ class DefaultTest(PersistTest):
pk = r.last_inserted_ids()[0]
t.update(t.c.col1==pk).execute(col4=None, col5=None)
ctexec = currenttime.scalar()
- self.echo("Currenttime "+ repr(ctexec))
+ print "Currenttime "+ repr(ctexec)
l = t.select(t.c.col1==pk).execute()
l = l.fetchone()
- self.assert_(l == (pk, 'im the update', f2, None, None, ctexec, True, False))
+ self.assert_(l == (pk, 'im the update', f2, None, None, ctexec, True, False, 13, datetime.date.today()))
# mysql/other db's return 0 or 1 for count(1)
self.assert_(14 <= f2 <= 15)
@@ -124,8 +147,35 @@ class DefaultTest(PersistTest):
l = l.fetchone()
self.assert_(l['col3'] == 55)
+ @testing.supported('postgres')
+ def testpassiveoverride(self):
+ """primarily for postgres, tests that when we get a primary key column back
+ from reflecting a table which has a default value on it, we pre-execute
+ that PassiveDefault upon insert, even though PassiveDefault says
+ "let the database execute this", because in postgres we must have all the primary
+ key values in memory before insert; otherwise we cant locate the just inserted row."""
+
+ try:
+ meta = MetaData(testbase.db)
+ testbase.db.execute("""
+ CREATE TABLE speedy_users
+ (
+ speedy_user_id SERIAL PRIMARY KEY,
+
+ user_name VARCHAR NOT NULL,
+ user_password VARCHAR NOT NULL
+ );
+ """, None)
+
+ t = Table("speedy_users", meta, autoload=True)
+ t.insert().execute(user_name='user', user_password='lala')
+ l = t.select().execute().fetchall()
+ self.assert_(l == [(1, 'user', 'lala')])
+ finally:
+ testbase.db.execute("drop table speedy_users", None)
+
class AutoIncrementTest(PersistTest):
- @testbase.supported('postgres', 'mysql')
+ @testing.supported('postgres', 'mysql')
def testnonautoincrement(self):
meta = MetaData(testbase.db)
nonai_table = Table("aitest", meta,
@@ -159,6 +209,9 @@ class AutoIncrementTest(PersistTest):
table.drop()
def testfetchid(self):
+
+        # TODO: what does this test do that all the various ORM tests don't?
+
meta = MetaData(testbase.db)
table = Table("aitest", meta,
Column('id', Integer, primary_key=True),
@@ -186,7 +239,7 @@ class AutoIncrementTest(PersistTest):
class SequenceTest(PersistTest):
- @testbase.supported('postgres', 'oracle')
+ @testing.supported('postgres', 'oracle')
def setUpAll(self):
global cartitems, sometable, metadata
metadata = MetaData(testbase.db)
@@ -197,13 +250,13 @@ class SequenceTest(PersistTest):
)
sometable = Table( 'Manager', metadata,
Column( 'obj_id', Integer, Sequence('obj_id_seq'), ),
- Column( 'name', type= String, ),
+ Column( 'name', String, ),
Column( 'id', Integer, primary_key= True, ),
)
metadata.create_all()
- @testbase.supported('postgres', 'oracle')
+ @testing.supported('postgres', 'oracle')
def testseqnonpk(self):
"""test sequences fire off as defaults on non-pk columns"""
sometable.insert().execute(name="somename")
@@ -213,7 +266,7 @@ class SequenceTest(PersistTest):
(2, "someother", 2),
]
- @testbase.supported('postgres', 'oracle')
+ @testing.supported('postgres', 'oracle')
def testsequence(self):
cartitems.insert().execute(description='hi')
cartitems.insert().execute(description='there')
@@ -222,8 +275,8 @@ class SequenceTest(PersistTest):
cartitems.select().execute().fetchall()
- @testbase.supported('postgres', 'oracle')
- def teststandalone(self):
+ @testing.supported('postgres', 'oracle')
+ def test_implicit_sequence_exec(self):
s = Sequence("my_sequence", metadata=MetaData(testbase.db))
s.create()
try:
@@ -232,7 +285,7 @@ class SequenceTest(PersistTest):
finally:
s.drop()
- @testbase.supported('postgres', 'oracle')
+ @testing.supported('postgres', 'oracle')
def teststandalone_explicit(self):
s = Sequence("my_sequence")
s.create(bind=testbase.db)
@@ -242,12 +295,20 @@ class SequenceTest(PersistTest):
finally:
s.drop(testbase.db)
- @testbase.supported('postgres', 'oracle')
+ @testing.supported('postgres', 'oracle')
+ def test_checkfirst(self):
+ s = Sequence("my_sequence")
+ s.create(testbase.db, checkfirst=False)
+ s.create(testbase.db, checkfirst=True)
+ s.drop(testbase.db, checkfirst=False)
+ s.drop(testbase.db, checkfirst=True)
+
+ @testing.supported('postgres', 'oracle')
def teststandalone2(self):
x = cartitems.c.cart_id.sequence.execute()
self.assert_(1 <= x <= 4)
- @testbase.supported('postgres', 'oracle')
+ @testing.supported('postgres', 'oracle')
def tearDownAll(self):
metadata.drop_all()
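The new `col7` column demonstrates that a Python default callable may take a single ExecutionContext argument and read the statement's other compiled parameters; `col8` shows a plain builtin as a default. In isolation, assuming an in-memory SQLite engine:

    # Sketch: a context-sensitive column default, per col7 above.
    from sqlalchemy import *

    meta = MetaData('sqlite://')

    def col7_default(ctx):
        # derive the value from whatever is being inserted into col1
        return ctx.compiled_parameters['col1'] + 10

    t = Table('demo', meta,
        Column('col1', Integer, primary_key=True),
        Column('col7', Integer, default=col7_default))
    meta.create_all()
    t.insert().execute(col1=51)    # col7 comes out as 61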
diff --git a/test/sql/generative.py b/test/sql/generative.py
new file mode 100644
index 000000000..357a66fcd
--- /dev/null
+++ b/test/sql/generative.py
@@ -0,0 +1,275 @@
+import testbase
+from sql import select as selecttests
+from sqlalchemy import *
+from testlib import *
+
+class TraversalTest(AssertMixin):
+ """test ClauseVisitor's traversal, particularly its ability to copy and modify
+ a ClauseElement in place."""
+
+ def setUpAll(self):
+ global A, B
+
+        # establish two fictitious ClauseElements.
+ # define deep equality semantics as well as deep identity semantics.
+ class A(ClauseElement):
+ def __init__(self, expr):
+ self.expr = expr
+
+ def is_other(self, other):
+ return other is self
+
+ def __eq__(self, other):
+ return other.expr == self.expr
+
+ def __ne__(self, other):
+ return other.expr != self.expr
+
+ def __str__(self):
+ return "A(%s)" % repr(self.expr)
+
+ class B(ClauseElement):
+ def __init__(self, *items):
+ self.items = items
+
+ def is_other(self, other):
+ if other is not self:
+ return False
+ for i1, i2 in zip(self.items, other.items):
+ if i1 is not i2:
+ return False
+ return True
+
+ def __eq__(self, other):
+ for i1, i2 in zip(self.items, other.items):
+ if i1 != i2:
+ return False
+ return True
+
+ def __ne__(self, other):
+ for i1, i2 in zip(self.items, other.items):
+ if i1 != i2:
+ return True
+ return False
+
+ def _copy_internals(self):
+ self.items = [i._clone() for i in self.items]
+
+ def get_children(self, **kwargs):
+ return self.items
+
+ def __str__(self):
+ return "B(%s)" % repr([str(i) for i in self.items])
+
+ def test_test_classes(self):
+ a1 = A("expr1")
+ struct = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+ struct2 = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+ struct3 = B(a1, A("expr2"), B(A("expr1b"), A("expr2bmodified")), A("expr3"))
+
+ assert a1.is_other(a1)
+ assert struct.is_other(struct)
+ assert struct == struct2
+ assert struct != struct3
+ assert not struct.is_other(struct2)
+ assert not struct.is_other(struct3)
+
+ def test_clone(self):
+ struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+
+ class Vis(ClauseVisitor):
+ def visit_a(self, a):
+ pass
+ def visit_b(self, b):
+ pass
+
+ vis = Vis()
+ s2 = vis.traverse(struct, clone=True)
+ assert struct == s2
+ assert not struct.is_other(s2)
+
+ def test_no_clone(self):
+ struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+
+ class Vis(ClauseVisitor):
+ def visit_a(self, a):
+ pass
+ def visit_b(self, b):
+ pass
+
+ vis = Vis()
+ s2 = vis.traverse(struct, clone=False)
+ assert struct == s2
+ assert struct.is_other(s2)
+
+ def test_change_in_place(self):
+ struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+ struct2 = B(A("expr1"), A("expr2modified"), B(A("expr1b"), A("expr2b")), A("expr3"))
+ struct3 = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2bmodified")), A("expr3"))
+
+ class Vis(ClauseVisitor):
+ def visit_a(self, a):
+ if a.expr == "expr2":
+ a.expr = "expr2modified"
+ def visit_b(self, b):
+ pass
+
+ vis = Vis()
+ s2 = vis.traverse(struct, clone=True)
+ assert struct != s2
+ assert not struct.is_other(s2)
+ assert struct2 == s2
+
+ class Vis2(ClauseVisitor):
+ def visit_a(self, a):
+ if a.expr == "expr2b":
+ a.expr = "expr2bmodified"
+ def visit_b(self, b):
+ pass
+
+ vis2 = Vis2()
+ s3 = vis2.traverse(struct, clone=True)
+ assert struct != s3
+ assert struct3 == s3
+
+class ClauseTest(selecttests.SQLTest):
+ """test copy-in-place behavior of various ClauseElements."""
+
+ def setUpAll(self):
+ global t1, t2
+ t1 = table("table1",
+ column("col1"),
+ column("col2"),
+ column("col3"),
+ )
+ t2 = table("table2",
+ column("col1"),
+ column("col2"),
+ column("col3"),
+ )
+
+ def test_binary(self):
+ clause = t1.c.col2 == t2.c.col2
+ assert str(clause) == ClauseVisitor().traverse(clause, clone=True)
+
+ def test_join(self):
+ clause = t1.join(t2, t1.c.col2==t2.c.col2)
+ c1 = str(clause)
+ assert str(clause) == str(ClauseVisitor().traverse(clause, clone=True))
+
+ class Vis(ClauseVisitor):
+ def visit_binary(self, binary):
+ binary.right = t2.c.col3
+
+ clause2 = Vis().traverse(clause, clone=True)
+ assert c1 == str(clause)
+ assert str(clause2) == str(t1.join(t2, t1.c.col2==t2.c.col3))
+
+ def test_select(self):
+ s = t1.select()
+ s2 = select([s])
+ s2_assert = str(s2)
+ s3_assert = str(select([t1.select()], t1.c.col2==7))
+ class Vis(ClauseVisitor):
+ def visit_select(self, select):
+ select.append_whereclause(t1.c.col2==7)
+ s3 = Vis().traverse(s2, clone=True)
+ assert str(s3) == s3_assert
+ assert str(s2) == s2_assert
+ print str(s2)
+ print str(s3)
+ Vis().traverse(s2)
+ assert str(s2) == s3_assert
+
+ print "------------------"
+
+ s4_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col3==9)))
+ class Vis(ClauseVisitor):
+ def visit_select(self, select):
+ select.append_whereclause(t1.c.col3==9)
+ s4 = Vis().traverse(s3, clone=True)
+ print str(s3)
+ print str(s4)
+ assert str(s4) == s4_assert
+ assert str(s3) == s3_assert
+
+ print "------------------"
+ s5_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col1==9)))
+ class Vis(ClauseVisitor):
+ def visit_binary(self, binary):
+ if binary.left is t1.c.col3:
+ binary.left = t1.c.col1
+ binary.right = bindparam("table1_col1")
+ s5 = Vis().traverse(s4, clone=True)
+ print str(s4)
+ print str(s5)
+ assert str(s5) == s5_assert
+ assert str(s4) == s4_assert
+
+ def test_correlated_select(self):
+ s = select(['*'], t1.c.col1==t2.c.col1, from_obj=[t1, t2]).correlate(t2)
+ class Vis(ClauseVisitor):
+ def visit_select(self, select):
+ select.append_whereclause(t1.c.col2==7)
+
+ self.runtest(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :table1_col2")
+
+ def test_clause_adapter(self):
+ from sqlalchemy import sql_util
+
+ t1alias = t1.alias('t1alias')
+
+ vis = sql_util.ClauseAdapter(t1alias)
+ ff = vis.traverse(func.count(t1.c.col1).label('foo'), clone=True)
+ assert ff._get_from_objects() == [t1alias]
+
+ self.runtest(vis.traverse(select(['*'], from_obj=[t1]), clone=True), "SELECT * FROM table1 AS t1alias")
+ self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2")
+ self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2")
+ self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 WHERE t1alias.col1 = table2.col2")
+ self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = table2.col2")
+
+ ff = vis.traverse(func.count(t1.c.col1).label('foo'), clone=True)
+ self.runtest(ff, "count(t1alias.col1) AS foo")
+ assert ff._get_from_objects() == [t1alias]
+
+# TODO:
+# self.runtest(vis.traverse(select([func.count(t1.c.col1).label('foo')]), clone=True), "SELECT count(t1alias.col1) AS foo FROM table1 AS t1alias")
+
+ t2alias = t2.alias('t2alias')
+ vis.chain(sql_util.ClauseAdapter(t2alias))
+ self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
+ self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
+ self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
+ self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2")
+
+
+
+class SelectTest(selecttests.SQLTest):
+ """tests the generative capability of Select"""
+
+ def setUpAll(self):
+ global t1, t2
+ t1 = table("table1",
+ column("col1"),
+ column("col2"),
+ column("col3"),
+ )
+ t2 = table("table2",
+ column("col1"),
+ column("col2"),
+ column("col3"),
+ )
+
+ def test_select(self):
+ self.runtest(t1.select().where(t1.c.col1==5).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1 WHERE table1.col1 = :table1_col1 ORDER BY table1.col3")
+
+ self.runtest(t1.select().select_from(select([t2], t2.c.col1==t1.c.col1)).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, (SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 WHERE table2.col1 = table1.col1) ORDER BY table1.col3")
+
+ s = select([t2], t2.c.col1==t1.c.col1, correlate=False)
+ s = s.correlate(t1).order_by(t2.c.col3)
+ self.runtest(t1.select().select_from(s).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, (SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 WHERE table2.col1 = table1.col1 ORDER BY table2.col3) ORDER BY table1.col3")
+
+
+if __name__ == '__main__':
+ testbase.main()
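The central assertion throughout generative.py: `ClauseVisitor.traverse(clause, clone=True)` hands back a modified copy and leaves the original untouched, while `clone=False` mutates in place. Condensed:

    # Sketch: clone-traversal copies, in-place traversal mutates.
    from sqlalchemy import *

    t1 = table('table1', column('col1'), column('col2'))

    class AddWhere(ClauseVisitor):
        def visit_select(self, select):
            select.append_whereclause(t1.c.col2 == 7)

    s = select([t1])
    s2 = AddWhere().traverse(s, clone=True)
    assert str(s) != str(s2)    # only the copy gained the WHERE clause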
diff --git a/test/sql/labels.py b/test/sql/labels.py
index ee9fa6bc5..553a3a3bc 100644
--- a/test/sql/labels.py
+++ b/test/sql/labels.py
@@ -1,11 +1,12 @@
import testbase
-
from sqlalchemy import *
+from testlib import *
+
# TODO: either create a mock dialect with named paramstyle and a short identifier length,
# or find a way to just use sqlite dialect and make those changes
-class LabelTypeTest(testbase.PersistTest):
+class LabelTypeTest(PersistTest):
def test_type(self):
m = MetaData()
t = Table('sometable', m,
@@ -14,21 +15,26 @@ class LabelTypeTest(testbase.PersistTest):
assert isinstance(t.c.col1.label('hi').type, Integer)
assert isinstance(select([t.c.col2], scalar=True).label('lala').type, Float)
-class LongLabelsTest(testbase.PersistTest):
+class LongLabelsTest(PersistTest):
def setUpAll(self):
- global metadata, table1
- metadata = MetaData(engine=testbase.db)
+ global metadata, table1, maxlen
+ metadata = MetaData(testbase.db)
table1 = Table("some_large_named_table", metadata,
Column("this_is_the_primarykey_column", Integer, Sequence("this_is_some_large_seq"), primary_key=True),
Column("this_is_the_data_column", String(30))
)
metadata.create_all()
+
+ maxlen = testbase.db.dialect.max_identifier_length
+ testbase.db.dialect.max_identifier_length = lambda: 29
+
def tearDown(self):
table1.delete().execute()
def tearDownAll(self):
metadata.drop_all()
+ testbase.db.dialect.max_identifier_length = maxlen
def test_result(self):
table1.insert().execute(**{"this_is_the_primarykey_column":1, "this_is_the_data_column":"data1"})
@@ -88,7 +94,7 @@ class LongLabelsTest(testbase.PersistTest):
x = select([tt], use_labels=True, order_by=tt.oid_column).compile(dialect=dialect)
#print x
        # assert it doesn't end with "ORDER BY foo.some_large_named_table_this_is_the_primarykey_column"
- assert str(x).endswith("""ORDER BY foo.some_large_named_table_t_1""")
+ assert str(x).endswith("""ORDER BY foo.some_large_named_table_t_2""")
if __name__ == '__main__':
testbase.main()
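LongLabelsTest forces label truncation by monkeypatching the dialect's `max_identifier_length` down to 29 for the duration of the class and restoring it in tearDownAll. The save/patch/restore pattern on its own:

    # Sketch: temporarily shrinking the dialect's identifier limit, as above.
    import testbase

    saved = testbase.db.dialect.max_identifier_length
    testbase.db.dialect.max_identifier_length = lambda: 29
    try:
        pass    # run statements whose labels should now be truncated
    finally:
        testbase.db.dialect.max_identifier_length = saved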
diff --git a/test/sql/query.py b/test/sql/query.py
index 8af5aafea..48a28a9a5 100644
--- a/test/sql/query.py
+++ b/test/sql/query.py
@@ -1,13 +1,9 @@
-from testbase import PersistTest
import testbase
-import unittest, sys, datetime
-
-import sqlalchemy.databases.sqlite as sqllite
-
-import tables
+import datetime
from sqlalchemy import *
-from sqlalchemy.engine import ResultProxy, RowProxy
from sqlalchemy import exceptions
+from testlib import *
+
class QueryTest(PersistTest):
@@ -24,25 +20,24 @@ class QueryTest(PersistTest):
Column('address', String(30)))
metadata.create_all()
- def setUp(self):
- self.users = users
def tearDown(self):
- self.users.delete().execute()
+ addresses.delete().execute()
+ users.delete().execute()
def tearDownAll(self):
metadata.drop_all()
def testinsert(self):
- self.users.insert().execute(user_id = 7, user_name = 'jack')
- print repr(self.users.select().execute().fetchall())
-
+ users.insert().execute(user_id = 7, user_name = 'jack')
+ assert users.count().scalar() == 1
+
def testupdate(self):
- self.users.insert().execute(user_id = 7, user_name = 'jack')
- print repr(self.users.select().execute().fetchall())
+ users.insert().execute(user_id = 7, user_name = 'jack')
+ assert users.count().scalar() == 1
- self.users.update(self.users.c.user_id == 7).execute(user_name = 'fred')
- print repr(self.users.select().execute().fetchall())
+ users.update(users.c.user_id == 7).execute(user_name = 'fred')
+ assert users.select(users.c.user_id==7).execute().fetchone()['user_name'] == 'fred'
def test_lastrow_accessor(self):
"""test the last_inserted_ids() and lastrow_has_id() functions"""
@@ -63,14 +58,15 @@ class QueryTest(PersistTest):
if result.lastrow_has_defaults():
criterion = and_(*[col==id for col, id in zip(table.primary_key, result.last_inserted_ids())])
row = table.select(criterion).execute().fetchone()
- ret.update(row)
+ for c in table.c:
+ ret[c.key] = row[c]
return ret
for supported, table, values, assertvalues in [
(
{'unsupported':['sqlite']},
Table("t1", metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, Sequence('t1_id_seq', optional=True), primary_key=True),
Column('foo', String(30), primary_key=True)),
{'foo':'hi'},
{'id':1, 'foo':'hi'}
@@ -78,7 +74,7 @@ class QueryTest(PersistTest):
(
{'unsupported':['sqlite']},
Table("t2", metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, Sequence('t2_id_seq', optional=True), primary_key=True),
Column('foo', String(30), primary_key=True),
Column('bar', String(30), PassiveDefault('hi'))
),
@@ -98,7 +94,7 @@ class QueryTest(PersistTest):
(
{'unsupported':[]},
Table("t4", metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True),
Column('foo', String(30), primary_key=True),
Column('bar', String(30), PassiveDefault('hi'))
),
@@ -124,109 +120,94 @@ class QueryTest(PersistTest):
table.drop()
def testrowiteration(self):
- self.users.insert().execute(user_id = 7, user_name = 'jack')
- self.users.insert().execute(user_id = 8, user_name = 'ed')
- self.users.insert().execute(user_id = 9, user_name = 'fred')
- r = self.users.select().execute()
+ users.insert().execute(user_id = 7, user_name = 'jack')
+ users.insert().execute(user_id = 8, user_name = 'ed')
+ users.insert().execute(user_id = 9, user_name = 'fred')
+ r = users.select().execute()
l = []
for row in r:
l.append(row)
self.assert_(len(l) == 3)
def test_fetchmany(self):
- self.users.insert().execute(user_id = 7, user_name = 'jack')
- self.users.insert().execute(user_id = 8, user_name = 'ed')
- self.users.insert().execute(user_id = 9, user_name = 'fred')
- r = self.users.select().execute()
+ users.insert().execute(user_id = 7, user_name = 'jack')
+ users.insert().execute(user_id = 8, user_name = 'ed')
+ users.insert().execute(user_id = 9, user_name = 'fred')
+ r = users.select().execute()
l = []
for row in r.fetchmany(size=2):
l.append(row)
self.assert_(len(l) == 2, "fetchmany(size=2) got %s rows" % len(l))
def test_compiled_execute(self):
- s = select([self.users], self.users.c.user_id==bindparam('id')).compile()
+ users.insert().execute(user_id = 7, user_name = 'jack')
+ s = select([users], users.c.user_id==bindparam('id')).compile()
c = testbase.db.connect()
- print repr(c.execute(s, id=7).fetchall())
-
- def test_global_metadata(self):
- t1 = Table('table1', Column('col1', Integer, primary_key=True),
- Column('col2', String(20)))
- t2 = Table('table2', Column('col1', Integer, primary_key=True),
- Column('col2', String(20)))
-
- assert t1.c.col1
- global_connect(testbase.db)
- default_metadata.create_all()
- try:
- assert t1.count().scalar() == 0
- finally:
- default_metadata.drop_all()
- default_metadata.clear()
-
+ assert c.execute(s, id=7).fetchall()[0]['user_id'] == 7
def test_repeated_bindparams(self):
"""test that a BindParam can be used more than once.
this should be run for dbs with both positional and named paramstyles."""
- self.users.insert().execute(user_id = 7, user_name = 'jack')
- self.users.insert().execute(user_id = 8, user_name = 'fred')
+ users.insert().execute(user_id = 7, user_name = 'jack')
+ users.insert().execute(user_id = 8, user_name = 'fred')
u = bindparam('userid')
- s = self.users.select(or_(self.users.c.user_name==u, self.users.c.user_name==u))
+ s = users.select(or_(users.c.user_name==u, users.c.user_name==u))
r = s.execute(userid='fred').fetchall()
assert len(r) == 1
def test_bindparam_shortname(self):
"""test the 'shortname' field on BindParamClause."""
- self.users.insert().execute(user_id = 7, user_name = 'jack')
- self.users.insert().execute(user_id = 8, user_name = 'fred')
+ users.insert().execute(user_id = 7, user_name = 'jack')
+ users.insert().execute(user_id = 8, user_name = 'fred')
u = bindparam('userid', shortname='someshortname')
- s = self.users.select(self.users.c.user_name==u)
+ s = users.select(users.c.user_name==u)
r = s.execute(someshortname='fred').fetchall()
assert len(r) == 1
def testdelete(self):
- self.users.insert().execute(user_id = 7, user_name = 'jack')
- self.users.insert().execute(user_id = 8, user_name = 'fred')
- print repr(self.users.select().execute().fetchall())
+ users.insert().execute(user_id = 7, user_name = 'jack')
+ users.insert().execute(user_id = 8, user_name = 'fred')
+ print repr(users.select().execute().fetchall())
- self.users.delete(self.users.c.user_name == 'fred').execute()
+ users.delete(users.c.user_name == 'fred').execute()
- print repr(self.users.select().execute().fetchall())
+ print repr(users.select().execute().fetchall())
def testselectlimit(self):
- self.users.insert().execute(user_id=1, user_name='john')
- self.users.insert().execute(user_id=2, user_name='jack')
- self.users.insert().execute(user_id=3, user_name='ed')
- self.users.insert().execute(user_id=4, user_name='wendy')
- self.users.insert().execute(user_id=5, user_name='laura')
- self.users.insert().execute(user_id=6, user_name='ralph')
- self.users.insert().execute(user_id=7, user_name='fido')
- r = self.users.select(limit=3, order_by=[self.users.c.user_id]).execute().fetchall()
+ users.insert().execute(user_id=1, user_name='john')
+ users.insert().execute(user_id=2, user_name='jack')
+ users.insert().execute(user_id=3, user_name='ed')
+ users.insert().execute(user_id=4, user_name='wendy')
+ users.insert().execute(user_id=5, user_name='laura')
+ users.insert().execute(user_id=6, user_name='ralph')
+ users.insert().execute(user_id=7, user_name='fido')
+ r = users.select(limit=3, order_by=[users.c.user_id]).execute().fetchall()
self.assert_(r == [(1, 'john'), (2, 'jack'), (3, 'ed')], repr(r))
- @testbase.unsupported('mssql')
+ @testing.unsupported('mssql')
def testselectlimitoffset(self):
- self.users.insert().execute(user_id=1, user_name='john')
- self.users.insert().execute(user_id=2, user_name='jack')
- self.users.insert().execute(user_id=3, user_name='ed')
- self.users.insert().execute(user_id=4, user_name='wendy')
- self.users.insert().execute(user_id=5, user_name='laura')
- self.users.insert().execute(user_id=6, user_name='ralph')
- self.users.insert().execute(user_id=7, user_name='fido')
- r = self.users.select(limit=3, offset=2, order_by=[self.users.c.user_id]).execute().fetchall()
+ users.insert().execute(user_id=1, user_name='john')
+ users.insert().execute(user_id=2, user_name='jack')
+ users.insert().execute(user_id=3, user_name='ed')
+ users.insert().execute(user_id=4, user_name='wendy')
+ users.insert().execute(user_id=5, user_name='laura')
+ users.insert().execute(user_id=6, user_name='ralph')
+ users.insert().execute(user_id=7, user_name='fido')
+ r = users.select(limit=3, offset=2, order_by=[users.c.user_id]).execute().fetchall()
self.assert_(r==[(3, 'ed'), (4, 'wendy'), (5, 'laura')])
- r = self.users.select(offset=5, order_by=[self.users.c.user_id]).execute().fetchall()
+ r = users.select(offset=5, order_by=[users.c.user_id]).execute().fetchall()
self.assert_(r==[(6, 'ralph'), (7, 'fido')])
- @testbase.supported('mssql')
+ @testing.supported('mssql')
def testselectlimitoffset_mssql(self):
try:
- r = self.users.select(limit=3, offset=2, order_by=[self.users.c.user_id]).execute().fetchall()
+ r = users.select(limit=3, offset=2, order_by=[users.c.user_id]).execute().fetchall()
assert False # InvalidRequestError should have been raised
except exceptions.InvalidRequestError:
pass
- @testbase.unsupported('mysql')
+ @testing.unsupported('mysql')
def test_scalar_select(self):
"""test that scalar subqueries with labels get their type propigated to the result set."""
# mysql and/or mysqldb has a bug here, type isnt propigated for scalar subquery.
@@ -244,18 +225,26 @@ class QueryTest(PersistTest):
datetable.drop()
def test_column_accessor(self):
- self.users.insert().execute(user_id=1, user_name='john')
- self.users.insert().execute(user_id=2, user_name='jack')
- r = self.users.select(self.users.c.user_id==2).execute().fetchone()
- self.assert_(r.user_id == r['user_id'] == r[self.users.c.user_id] == 2)
- self.assert_(r.user_name == r['user_name'] == r[self.users.c.user_name] == 'jack')
-
- r = text("select * from query_users where user_id=2", engine=testbase.db).execute().fetchone()
- self.assert_(r.user_id == r['user_id'] == r[self.users.c.user_id] == 2)
- self.assert_(r.user_name == r['user_name'] == r[self.users.c.user_name] == 'jack')
+ users.insert().execute(user_id=1, user_name='john')
+ users.insert().execute(user_id=2, user_name='jack')
+ addresses.insert().execute(address_id=1, user_id=2, address='foo@bar.com')
+
+ r = users.select(users.c.user_id==2).execute().fetchone()
+ self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2)
+ self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack')
+
+ r = text("select * from query_users where user_id=2", bind=testbase.db).execute().fetchone()
+ self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2)
+ self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack')
+ # test slices
+ r = text("select * from query_addresses", bind=testbase.db).execute().fetchone()
+ self.assert_(r[0:1] == (1,))
+ self.assert_(r[1:] == (2, 'foo@bar.com'))
+ self.assert_(r[:-1] == (1, 2))
+
def test_ambiguous_column(self):
- self.users.insert().execute(user_id=1, user_name='john')
+ users.insert().execute(user_id=1, user_name='john')
r = users.outerjoin(addresses).select().execute().fetchone()
try:
print r['user_id']
@@ -264,18 +253,18 @@ class QueryTest(PersistTest):
assert str(e) == "Ambiguous column name 'user_id' in result set! try 'use_labels' option on select statement."
def test_keys(self):
- self.users.insert().execute(user_id=1, user_name='foo')
- r = self.users.select().execute().fetchone()
+ users.insert().execute(user_id=1, user_name='foo')
+ r = users.select().execute().fetchone()
self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name'])
def test_items(self):
- self.users.insert().execute(user_id=1, user_name='foo')
- r = self.users.select().execute().fetchone()
+ users.insert().execute(user_id=1, user_name='foo')
+ r = users.select().execute().fetchone()
self.assertEqual([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')])
def test_len(self):
- self.users.insert().execute(user_id=1, user_name='foo')
- r = self.users.select().execute().fetchone()
+ users.insert().execute(user_id=1, user_name='foo')
+ r = users.select().execute().fetchone()
self.assertEqual(len(r), 2)
r.close()
r = testbase.db.execute('select user_name, user_id from query_users', {}).fetchone()
@@ -295,7 +284,11 @@ class QueryTest(PersistTest):
x = testbase.db.func.current_date().execute().scalar()
y = testbase.db.func.current_date().select().execute().scalar()
z = testbase.db.func.current_date().scalar()
- assert x == y == z
+ assert (x == y == z) is True
+
+ x = testbase.db.func.current_date(type_=Date)
+ assert isinstance(x.type, Date)
+ assert isinstance(x.execute().scalar(), datetime.date)
def test_conn_functions(self):
conn = testbase.db.connect()
@@ -305,8 +298,8 @@ class QueryTest(PersistTest):
z = conn.scalar(func.current_date())
finally:
conn.close()
- assert x == y == z
-
+ assert (x == y == z) is True
+
def test_update_functions(self):
"""test sending functions and SQL expressions to the VALUES and SET clauses of INSERT/UPDATE instances,
and that column-level defaults get overridden"""
@@ -357,7 +350,7 @@ class QueryTest(PersistTest):
finally:
meta.drop_all()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_functions_with_cols(self):
        # TODO: shouldn't this work on oracle too?
x = testbase.db.func.current_date().execute().scalar()
@@ -366,7 +359,7 @@ class QueryTest(PersistTest):
w = select(['*'], from_obj=[testbase.db.func.current_date()]).scalar()
# construct a column-based FROM object out of a function, like in [ticket:172]
- s = select([column('date', type=DateTime)], from_obj=[testbase.db.func.current_date()])
+ s = select([column('date', type_=DateTime)], from_obj=[testbase.db.func.current_date()])
q = s.execute().fetchone()[s.c.date]
r = s.alias('datequery').select().scalar()
@@ -374,8 +367,8 @@ class QueryTest(PersistTest):
def test_column_order_with_simple_query(self):
# should return values in column definition order
- self.users.insert().execute(user_id=1, user_name='foo')
- r = self.users.select(self.users.c.user_id==1).execute().fetchone()
+ users.insert().execute(user_id=1, user_name='foo')
+ r = users.select(users.c.user_id==1).execute().fetchone()
self.assertEqual(r[0], 1)
self.assertEqual(r[1], 'foo')
self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name'])
@@ -383,14 +376,14 @@ class QueryTest(PersistTest):
def test_column_order_with_text_query(self):
# should return values in query order
- self.users.insert().execute(user_id=1, user_name='foo')
+ users.insert().execute(user_id=1, user_name='foo')
r = testbase.db.execute('select user_name, user_id from query_users', {}).fetchone()
self.assertEqual(r[0], 'foo')
self.assertEqual(r[1], 1)
self.assertEqual([x.lower() for x in r.keys()], ['user_name', 'user_id'])
self.assertEqual(r.values(), ['foo', 1])
- @testbase.unsupported('oracle', 'firebird')
+ @testing.unsupported('oracle', 'firebird')
def test_column_accessor_shadow(self):
meta = MetaData(testbase.db)
shadowed = Table('test_shadowed', meta,
@@ -420,7 +413,7 @@ class QueryTest(PersistTest):
finally:
shadowed.drop(checkfirst=True)
- @testbase.supported('mssql')
+ @testing.supported('mssql')
def test_fetchid_trigger(self):
meta = MetaData(testbase.db)
t1 = Table('t1', meta,
@@ -446,7 +439,7 @@ class QueryTest(PersistTest):
con.execute("""drop trigger paj""")
meta.drop_all()
- @testbase.supported('mssql')
+ @testing.supported('mssql')
def test_insertid_schema(self):
meta = MetaData(testbase.db)
con = testbase.db.connect()
@@ -459,7 +452,7 @@ class QueryTest(PersistTest):
tbl.drop()
con.execute('drop schema paj')
- @testbase.supported('mssql')
+ @testing.supported('mssql')
def test_insertid_reserved(self):
meta = MetaData(testbase.db)
table = Table(
@@ -476,51 +469,52 @@ class QueryTest(PersistTest):
def test_in_filtering(self):
- """test the 'shortname' field on BindParamClause."""
- self.users.insert().execute(user_id = 7, user_name = 'jack')
- self.users.insert().execute(user_id = 8, user_name = 'fred')
- self.users.insert().execute(user_id = 9, user_name = None)
+ """test the behavior of the in_() function."""
+
+ users.insert().execute(user_id = 7, user_name = 'jack')
+ users.insert().execute(user_id = 8, user_name = 'fred')
+ users.insert().execute(user_id = 9, user_name = None)
- s = self.users.select(self.users.c.user_name.in_())
+ s = users.select(users.c.user_name.in_())
r = s.execute().fetchall()
# No username is in empty set
assert len(r) == 0
- s = self.users.select(not_(self.users.c.user_name.in_()))
+ s = users.select(not_(users.c.user_name.in_()))
r = s.execute().fetchall()
# All usernames with a value are outside an empty set
assert len(r) == 2
- s = self.users.select(self.users.c.user_name.in_('jack','fred'))
+ s = users.select(users.c.user_name.in_('jack','fred'))
r = s.execute().fetchall()
assert len(r) == 2
- s = self.users.select(not_(self.users.c.user_name.in_('jack','fred')))
+ s = users.select(not_(users.c.user_name.in_('jack','fred')))
r = s.execute().fetchall()
# Null values are not outside any set
assert len(r) == 0
u = bindparam('search_key')
- s = self.users.select(u.in_())
+ s = users.select(u.in_())
r = s.execute(search_key='john').fetchall()
assert len(r) == 0
r = s.execute(search_key=None).fetchall()
assert len(r) == 0
- s = self.users.select(not_(u.in_()))
+ s = users.select(not_(u.in_()))
r = s.execute(search_key='john').fetchall()
assert len(r) == 3
r = s.execute(search_key=None).fetchall()
assert len(r) == 0
- s = self.users.select(self.users.c.user_name.in_() == True)
+ s = users.select(users.c.user_name.in_() == True)
r = s.execute().fetchall()
assert len(r) == 0
- s = self.users.select(self.users.c.user_name.in_() == False)
+ s = users.select(users.c.user_name.in_() == False)
r = s.execute().fetchall()
assert len(r) == 2
- s = self.users.select(self.users.c.user_name.in_() == None)
+ s = users.select(users.c.user_name.in_() == None)
r = s.execute().fetchall()
assert len(r) == 1
@@ -577,7 +571,7 @@ class CompoundTest(PersistTest):
assert u.execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
assert u.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
- @testbase.unsupported('mysql')
+ @testing.unsupported('mysql')
def test_intersect(self):
i = intersect(
select([t2.c.col3, t2.c.col4]),
@@ -586,7 +580,7 @@ class CompoundTest(PersistTest):
assert i.execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
assert i.alias('bar').select().execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
- @testbase.unsupported('mysql', 'oracle')
+ @testing.unsupported('mysql', 'oracle')
def test_except_style1(self):
e = except_(union(
select([t1.c.col3, t1.c.col4]),
@@ -595,7 +589,7 @@ class CompoundTest(PersistTest):
), select([t2.c.col3, t2.c.col4]))
assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
- @testbase.unsupported('mysql', 'oracle')
+ @testing.unsupported('mysql', 'oracle')
def test_except_style2(self):
e = except_(union(
select([t1.c.col3, t1.c.col4]),
@@ -605,7 +599,7 @@ class CompoundTest(PersistTest):
assert e.execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
- @testbase.unsupported('sqlite', 'mysql', 'oracle')
+ @testing.unsupported('sqlite', 'mysql', 'oracle')
def test_except_style3(self):
# aaa, bbb, ccc - (aaa, bbb, ccc - (ccc)) = ccc
e = except_(
@@ -617,7 +611,7 @@ class CompoundTest(PersistTest):
)
self.assertEquals(e.execute().fetchall(), [('ccc',)])
- @testbase.unsupported('sqlite', 'mysql', 'oracle')
+ @testing.unsupported('sqlite', 'mysql', 'oracle')
def test_union_union_all(self):
e = union_all(
select([t1.c.col3]),
@@ -628,7 +622,7 @@ class CompoundTest(PersistTest):
)
self.assertEquals(e.execute().fetchall(), [('aaa',),('bbb',),('ccc',),('aaa',),('bbb',),('ccc',)])
- @testbase.unsupported('mysql')
+ @testing.unsupported('mysql')
def test_composite(self):
u = intersect(
select([t2.c.col3, t2.c.col4]),
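test_in_filtering (rewritten above) pins down the semantics of a zero-argument `in_()`: nothing is IN the empty set, and NULLs fall outside both IN and NOT IN. A self-contained restatement, assuming in-memory SQLite:

    # Sketch: empty-set in_() semantics, mirroring test_in_filtering.
    from sqlalchemy import *

    meta = MetaData('sqlite://')
    users = Table('query_users', meta,
        Column('user_id', Integer, primary_key=True),
        Column('user_name', String(30)))
    meta.create_all()
    users.insert().execute(user_id=7, user_name='jack')
    users.insert().execute(user_id=9, user_name=None)

    assert len(users.select(users.c.user_name.in_()).execute().fetchall()) == 0
    # the NULL row matches neither IN nor NOT IN:
    assert len(users.select(not_(users.c.user_name.in_())).execute().fetchall()) == 1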
diff --git a/test/sql/quote.py b/test/sql/quote.py
index bc40d52ee..2fdf9dba0 100644
--- a/test/sql/quote.py
+++ b/test/sql/quote.py
@@ -1,6 +1,7 @@
-from testbase import PersistTest
import testbase
from sqlalchemy import *
+from testlib import *
+
class QuoteTest(PersistTest):
def setUpAll(self):
@@ -78,7 +79,7 @@ class QuoteTest(PersistTest):
assert t1.c.UcCol.case_sensitive is False
assert t2.c.normalcol.case_sensitive is False
- @testbase.unsupported('oracle')
+ @testing.unsupported('oracle')
def testlabels(self):
"""test the quoting of labels.
diff --git a/test/sql/rowcount.py b/test/sql/rowcount.py
index df6a2a883..e0da96a81 100644
--- a/test/sql/rowcount.py
+++ b/test/sql/rowcount.py
@@ -1,7 +1,9 @@
-from sqlalchemy import *
import testbase
+from sqlalchemy import *
+from testlib import *
+
-class FoundRowsTest(testbase.AssertMixin):
+class FoundRowsTest(AssertMixin):
"""tests rowcount functionality"""
def setUpAll(self):
metadata = MetaData(testbase.db)
diff --git a/test/sql/select.py b/test/sql/select.py
index 4d3eb4ad7..a5cf061e2 100644
--- a/test/sql/select.py
+++ b/test/sql/select.py
@@ -1,8 +1,8 @@
-from testbase import PersistTest
import testbase
+import re, operator
from sqlalchemy import *
from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql
-import unittest, re, operator
+from testlib import *
# the select test now tests almost completely with TableClause/ColumnClause objects,
@@ -10,21 +10,21 @@ import unittest, re, operator
# so SQLAlchemy's SQL construction engine can be used with no database dependencies at all.
table1 = table('mytable',
- column('myid'),
- column('name'),
- column('description'),
+ column('myid', Integer),
+ column('name', String),
+ column('description', String),
)
table2 = table(
'myothertable',
- column('otherid'),
- column('othername'),
+ column('otherid', Integer),
+ column('othername', String),
)
table3 = table(
'thirdtable',
- column('userid'),
- column('otherstuff'),
+ column('userid', Integer),
+ column('otherstuff', String),
)
metadata = MetaData()
@@ -54,7 +54,7 @@ addresses = table('addresses',
class SQLTest(PersistTest):
def runtest(self, clause, result, dialect = None, params = None, checkparams = None):
c = clause.compile(parameters=params, dialect=dialect)
- self.echo("\nSQL String:\n" + str(c) + repr(c.get_params()))
+ print "\nSQL String:\n" + str(c) + repr(c.get_params())
cc = re.sub(r'\n', '', str(c))
self.assert_(cc == result, "\n'" + cc + "'\n does not match \n'" + result + "'")
if checkparams is not None:
@@ -130,6 +130,15 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
crit = q.c.myid == table1.c.myid
self.runtest(select(['*'], crit), """SELECT * FROM (SELECT mytable.myid AS myid FROM mytable ORDER BY mytable.myid) AS foo, mytable WHERE foo.myid = mytable.myid""", dialect=sqlite.dialect())
self.runtest(select(['*'], crit), """SELECT * FROM (SELECT mytable.myid AS myid FROM mytable) AS foo, mytable WHERE foo.myid = mytable.myid""", dialect=mssql.dialect())
+
+ def testmssql_aliases_schemas(self):
+ self.runtest(table4.select(), "SELECT remotetable.rem_id, remotetable.datatype_id, remotetable.value FROM remote_owner.remotetable")
+
+ dialect = mssql.dialect()
+ self.runtest(table4.select(), "SELECT remotetable_1.rem_id, remotetable_1.datatype_id, remotetable_1.value FROM remote_owner.remotetable AS remotetable_1", dialect=dialect)
+
+ # TODO: this is probably incorrect; no "AS <foo>" is being applied to the table
+ self.runtest(table1.join(table4, table1.c.myid==table4.c.rem_id).select(), "SELECT mytable.myid, mytable.name, mytable.description, remotetable.rem_id, remotetable.datatype_id, remotetable.value FROM mytable JOIN remote_owner.remotetable ON remotetable.rem_id = mytable.myid")
def testdontovercorrelate(self):
self.runtest(select([table1], from_obj=[table1, table1.select()]), """SELECT mytable.myid, mytable.name, mytable.description FROM mytable, (SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable)""")
@@ -142,6 +151,11 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
self.runtest(select([table1, exists([1], from_obj=[table2]).label('foo')]), "SELECT mytable.myid, mytable.name, mytable.description, EXISTS (SELECT 1 FROM myothertable) AS foo FROM mytable", params={})
def testwheresubquery(self):
+ s = select([addresses.c.street], addresses.c.user_id==users.c.user_id, correlate=True).alias('s')
+ self.runtest(
+ select([users, s.c.street], from_obj=[s]),
+ """SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street AS street FROM addresses WHERE addresses.user_id = users.user_id) AS s""")
+
        # TODO: this tests that you don't get a "SELECT column" without a FROM, but it's not working yet.
#self.runtest(
# table1.select(table1.c.myid == select([table1.c.myid], table1.c.name=='jack')), ""
@@ -194,7 +208,20 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
s = select([table1.c.myid], scalar=True)
self.runtest(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) FROM myothertable")
-
+
+ s = select([table1.c.myid]).correlate(None).as_scalar()
+ self.runtest(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) FROM mytable")
+
+ s = select([table1.c.myid]).as_scalar()
+ self.runtest(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) FROM myothertable")
+
+ # test expressions against scalar selects
+ self.runtest(select([s - literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) - :literal")
+ self.runtest(select([select([table1.c.name]).as_scalar() + literal('x')]), "SELECT (SELECT mytable.name FROM mytable) || :literal")
+ self.runtest(select([s > literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) > :literal")
+
+ self.runtest(select([select([table1.c.name]).label('foo')]), "SELECT (SELECT mytable.name FROM mytable) AS foo")
+
zips = table('zips',
column('zipcode'),
@@ -206,15 +233,17 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
column('nm')
)
zip = '12345'
- qlat = select([zips.c.latitude], zips.c.zipcode == zip, scalar=True, correlate=False)
- qlng = select([zips.c.longitude], zips.c.zipcode == zip, scalar=True, correlate=False)
+ qlat = select([zips.c.latitude], zips.c.zipcode == zip).correlate(None).as_scalar()
+ qlng = select([zips.c.longitude], zips.c.zipcode == zip).correlate(None).as_scalar()
q = select([places.c.id, places.c.nm, zips.c.zipcode, func.latlondist(qlat, qlng).label('dist')],
zips.c.zipcode==zip,
order_by = ['dist', places.c.nm]
)
- self.runtest(q,"SELECT places.id, places.nm, zips.zipcode, latlondist((SELECT zips.latitude FROM zips WHERE zips.zipcode = :zips_zipcode_1), (SELECT zips.longitude FROM zips WHERE zips.zipcode = :zips_zipcode_2)) AS dist FROM places, zips WHERE zips.zipcode = :zips_zipcode ORDER BY dist, places.nm")
+ self.runtest(q,"SELECT places.id, places.nm, zips.zipcode, latlondist((SELECT zips.latitude FROM zips WHERE "
+ "zips.zipcode = :zips_zipcode), (SELECT zips.longitude FROM zips WHERE zips.zipcode = :zips_zipcode_1)) AS dist "
+ "FROM places, zips WHERE zips.zipcode = :zips_zipcode_2 ORDER BY dist, places.nm")
zalias = zips.alias('main_zip')
qlat = select([zips.c.latitude], zips.c.zipcode == zalias.c.zipcode, scalar=True)
@@ -223,7 +252,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
order_by = ['dist', places.c.nm]
)
self.runtest(q, "SELECT places.id, places.nm, main_zip.zipcode, latlondist((SELECT zips.latitude FROM zips WHERE zips.zipcode = main_zip.zipcode), (SELECT zips.longitude FROM zips WHERE zips.zipcode = main_zip.zipcode)) AS dist FROM places, zips AS main_zip ORDER BY dist, places.nm")
-
+
a1 = table2.alias('t2alias')
s1 = select([a1.c.otherid], table1.c.myid==a1.c.otherid, scalar=True)
j1 = table1.join(table2, table1.c.myid==table2.c.otherid)
@@ -261,28 +290,20 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
)
def testoperators(self):
- self.runtest(
- table1.select((table1.c.myid != 12) & ~(table1.c.name=='john')),
- "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT mytable.name = :mytable_name"
- )
-
- self.runtest(
- literal("a") + literal("b") * literal("c"), ":literal + :literal_1 * :literal_2"
- )
# exercise arithmetic operators
for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'),
(operator.sub, '-'), (operator.div, '/'),
):
for (lhs, rhs, res) in (
- ('a', table1.c.myid, ':mytable_myid %s mytable.myid'),
- ('a', literal('b'), ':literal %s :literal_1'),
+ (5, table1.c.myid, ':mytable_myid %s mytable.myid'),
+ (5, literal(5), ':literal %s :literal_1'),
(table1.c.myid, 'b', 'mytable.myid %s :mytable_myid'),
- (table1.c.myid, literal('b'), 'mytable.myid %s :literal'),
+ (table1.c.myid, literal(2.7), 'mytable.myid %s :literal'),
(table1.c.myid, table1.c.myid, 'mytable.myid %s mytable.myid'),
- (literal('a'), 'b', ':literal %s :literal_1'),
- (literal('a'), table1.c.myid, ':literal %s mytable.myid'),
- (literal('a'), literal('b'), ':literal %s :literal_1'),
+ (literal(5), 8, ':literal %s :literal_1'),
+ (literal(6), table1.c.myid, ':literal %s mytable.myid'),
+ (literal(7), literal(5.5), ':literal %s :literal_1'),
):
self.runtest(py_op(lhs, rhs), res % sql_op)
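
The loop relies on Python operator overloading on literals and columns producing SQL expression trees; an illustrative check, under the same assumptions as the sketch above:

    import operator
    from sqlalchemy import literal

    # operator.add(literal(5), 8) builds the same expression tree as
    # literal(5) + 8 and compiles to ":literal + :literal_1"; the loop
    # above just substitutes each of +, *, -, / into that template.
    assert str(operator.add(literal(5), 8)) == str(literal(5) + 8)
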
@@ -314,6 +335,25 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
"\n'" + compiled + "'\n does not match\n'" +
fwd_sql + "'\n or\n'" + rev_sql + "'")
+ self.runtest(
+ table1.select((table1.c.myid != 12) & ~(table1.c.name=='john')),
+ "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND mytable.name != :mytable_name"
+ )
+
+ self.runtest(
+ table1.select((table1.c.myid != 12) & ~and_(table1.c.name=='john', table1.c.name=='ed', table1.c.name=='fred')),
+ "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT (mytable.name = :mytable_name AND mytable.name = :mytable_name_1 AND mytable.name = :mytable_name_2)"
+ )
+
+ self.runtest(
+ table1.select((table1.c.myid != 12) & ~table1.c.name),
+ "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT mytable.name"
+ )
+
+ self.runtest(
+ literal("a") + literal("b") * literal("c"), ":literal || :literal_1 * :literal_2"
+ )
+
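
Worth noting for the expected strings here: the + operator is now type-aware, so string-typed operands compile to the ANSI concatenation operator while numeric operands keep arithmetic addition. A hedged sketch:

    from sqlalchemy import literal

    print literal("a") + literal("b")    # :literal || :literal_1  (String type: concatenate)
    print literal(5) + literal(6)        # :literal + :literal_1   (numeric type: add)
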
# test the op() function, also that its results are further usable in expressions
self.runtest(
table1.select(table1.c.myid.op('hoho')(12)==14),
@@ -374,13 +414,18 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
def testalias(self):
        # test aliasing table1: the column names stay the same, while the table name "changes" to "foo".
self.runtest(
- select([alias(table1, 'foo')])
+ select([table1.alias('foo')])
,"SELECT foo.myid, foo.name, foo.description FROM mytable AS foo")
-
+
+ for dialect in (firebird.dialect(), oracle.dialect()):
+ self.runtest(
+ select([table1.alias('foo')])
+ ,"SELECT foo.myid, foo.name, foo.description FROM mytable foo"
+ ,dialect=dialect)
+
self.runtest(
- select([alias(table1, 'foo')])
- ,"SELECT foo.myid, foo.name, foo.description FROM mytable foo"
- ,dialect=firebird.dialect())
+ select([table1.alias()])
+ ,"SELECT mytable_1.myid, mytable_1.name, mytable_1.description FROM mytable AS mytable_1")
        # create a select for a join of two tables. use_labels means the column names will have
        # labels of the form tablename_columnname, which become the column keys accessible off the Selectable object.
@@ -401,6 +446,12 @@ myothertable.otherid AS myothertable_otherid FROM mytable, myothertable \
WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = :t2view_mytable_myid"
)
+
+ def test_prefixes(self):
+ self.runtest(table1.select().prefix_with("SQL_CALC_FOUND_ROWS").prefix_with("SQL_SOME_WEIRD_MYSQL_THING"),
+ "SELECT SQL_CALC_FOUND_ROWS SQL_SOME_WEIRD_MYSQL_THING mytable.myid, mytable.name, mytable.description FROM mytable"
+ )
+
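
For context, a hedged sketch of the new prefix_with() hook, which injects dialect-specific keywords between SELECT and the column list (MySQL hints are the motivating case):

    q = table1.select().prefix_with("SQL_CALC_FOUND_ROWS")
    print q    # SELECT SQL_CALC_FOUND_ROWS mytable.myid, ... FROM mytable
    # calls chain generatively; each prefix is emitted in the order it was added
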
def testtext(self):
self.runtest(
text("select * from foo where lala = bar") ,
@@ -429,7 +480,7 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid =
s.append_column("column2")
s.append_whereclause("column1=12")
s.append_whereclause("column2=19")
- s.order_by("column1")
+ s = s.order_by("column1")
s.append_from("table1")
self.runtest(s, "SELECT column1, column2 FROM table1 WHERE column1=12 AND column2=19 ORDER BY column1")
@@ -468,7 +519,14 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid =
checkparams={'bar':4, 'whee': 7},
params={'bar':4, 'whee': 7, 'hoho':10},
)
-
+
+ self.runtest(
+ text("select * from foo where clock='05:06:07'"),
+ "select * from foo where clock='05:06:07'",
+ checkparams={},
+ params={},
+ )
+
dialect = postgres.dialect()
self.runtest(
text("select * from foo where lala=:bar and hoho=:whee"),
@@ -477,6 +535,13 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid =
params={'bar':4, 'whee': 7, 'hoho':10},
dialect=dialect
)
+ self.runtest(
+ text("select * from foo where clock='05:06:07' and mork='\:mindy'"),
+ "select * from foo where clock='05:06:07' and mork=':mindy'",
+ checkparams={},
+ params={},
+ dialect=dialect
+ )
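
A sketch of the escaping rule asserted here (the reasoning about digits is an assumption on my part, not stated by the patch):

    t = text("select * from foo where clock='05:06:07' and mork='\:mindy'")
    # compiles to: select * from foo where clock='05:06:07' and mork=':mindy'
    # '05:06:07' needs no escaping, presumably because bind names cannot
    # start with a digit; ':mindy' would match, so the backslash marks it literal
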
dialect = sqlite.dialect()
self.runtest(
@@ -509,7 +574,7 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today
def testliteral(self):
self.runtest(select([literal("foo") + literal("bar")], from_obj=[table1]),
- "SELECT :literal + :literal_1 FROM mytable")
+ "SELECT :literal || :literal_1 FROM mytable")
def testcalculatedcolumns(self):
value_tbl = table('values',
@@ -663,7 +728,7 @@ FROM myothertable ORDER BY myid \
WHERE mytable.name = :mytable_name GROUP BY mytable.myid, mytable.name UNION SELECT mytable.myid, mytable.name, mytable.description \
FROM mytable WHERE mytable.name = :mytable_name_1"
)
-
+
def test_compound_select_grouping(self):
self.runtest(
union_all(
@@ -716,6 +781,7 @@ EXISTS (select yay from foo where boo = lar)",
dialect=postgres.dialect()
)
+
self.runtest(query,
"SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \
FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid(+) AND \
@@ -835,16 +901,16 @@ myothertable.othername != :myothertable_othername OR EXISTS (select yay from foo
self.runtest(select([table1], table1.c.myid.in_('a', literal('b'))),
"SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :literal)")
- self.runtest(select([table1], table1.c.myid.in_(literal('a') + 'a')),
+ self.runtest(select([table1], table1.c.myid.in_(literal(1) + 'a')),
"SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :literal + :literal_1")
self.runtest(select([table1], table1.c.myid.in_(literal('a') +'a', 'b')),
- "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal + :literal_1, :mytable_myid)")
+ "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal || :literal_1, :mytable_myid)")
self.runtest(select([table1], table1.c.myid.in_(literal('a') + literal('a'), literal('b'))),
- "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal + :literal_1, :literal_2)")
+ "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal || :literal_1, :literal_2)")
- self.runtest(select([table1], table1.c.myid.in_('a', literal('b') +'b')),
+ self.runtest(select([table1], table1.c.myid.in_(1, literal(3) + 4)),
"SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :literal + :literal_1)")
self.runtest(select([table1], table1.c.myid.in_(literal('a') < 'b')),
@@ -862,7 +928,7 @@ myothertable.othername != :myothertable_othername OR EXISTS (select yay from foo
self.runtest(select([table1], table1.c.myid.in_(literal('a'), table1.c.myid +'a')),
"SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal, mytable.myid + :mytable_myid)")
- self.runtest(select([table1], table1.c.myid.in_(literal('a'), 'a' + table1.c.myid)),
+ self.runtest(select([table1], table1.c.myid.in_(literal(1), 'a' + table1.c.myid)),
"SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal, :mytable_myid + mytable.myid)")
self.runtest(select([table1], table1.c.myid.in_(1, 2, 3)),
@@ -900,16 +966,6 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE
"SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (CASE WHEN (mytable.myid IS NULL) THEN NULL ELSE 0 END = 1)")
- def testlateargs(self):
- """tests that a SELECT clause will have extra "WHERE" clauses added to it at compile time if extra arguments
- are sent"""
-
- self.runtest(table1.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name = :mytable_name AND mytable.myid = :mytable_myid", params={'myid':'3', 'name':'jack'})
-
- self.runtest(table1.select(table1.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'myid':'3'})
-
- self.runtest(table1.select(table1.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'myid':'3', 'name':'fred'})
-
def testcast(self):
tbl = table('casttest',
column('id', Integer),
@@ -963,8 +1019,8 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE
"SELECT op.field FROM op WHERE :literal + (op.field IN (:op_field, :op_field_1))")
self.runtest(table.select((5 + table.c.field).in_(5,6)),
"SELECT op.field FROM op WHERE :op_field + op.field IN (:literal, :literal_1)")
- self.runtest(table.select(not_(table.c.field == 5)),
- "SELECT op.field FROM op WHERE NOT op.field = :op_field")
+ self.runtest(table.select(not_(and_(table.c.field == 5, table.c.field == 7))),
+ "SELECT op.field FROM op WHERE NOT (op.field = :op_field AND op.field = :op_field_1)")
self.runtest(table.select(not_(table.c.field) == 5),
"SELECT op.field FROM op WHERE (NOT op.field) = :literal")
self.runtest(table.select((table.c.field == table.c.field).between(False, True)),
@@ -1019,12 +1075,17 @@ class CRUDTest(SQLTest):
values = {
table1.c.name : table1.c.name + "lala",
table1.c.myid : func.do_stuff(table1.c.myid, literal('hoho'))
- }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal_2), name=mytable.name + :mytable_name WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal + mytable.name + :literal_1")
+ }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal), name=(mytable.name || :mytable_name) "
+ "WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal_1 || mytable.name || :literal_2")
def testcorrelatedupdate(self):
# test against a straight text subquery
- u = update(table1, values = {table1.c.name : text("select name from mytable where id=mytable.id")})
+ u = update(table1, values = {table1.c.name : text("(select name from mytable where id=mytable.id)")})
self.runtest(u, "UPDATE mytable SET name=(select name from mytable where id=mytable.id)")
+
+ mt = table1.alias()
+ u = update(table1, values = {table1.c.name : select([mt.c.name], mt.c.myid==table1.c.myid)})
+ self.runtest(u, "UPDATE mytable SET name=(SELECT mytable_1.name FROM mytable AS mytable_1 WHERE mytable_1.myid = mytable.myid)")
# test against a regular constructed subquery
s = select([table2], table2.c.otherid == table1.c.myid)
@@ -1043,7 +1104,18 @@ class CRUDTest(SQLTest):
def testdelete(self):
self.runtest(delete(table1, table1.c.myid == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid")
-
+
+ def testcorrelateddelete(self):
+ # test a non-correlated WHERE clause
+ s = select([table2.c.othername], table2.c.otherid == 7)
+ u = delete(table1, table1.c.name==s)
+ self.runtest(u, "DELETE FROM mytable WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = :myothertable_otherid)")
+
+ # test one that is actually correlated...
+ s = select([table2.c.othername], table2.c.otherid == table1.c.myid)
+ u = table1.delete(table1.c.name==s)
+ self.runtest(u, "DELETE FROM mytable WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)")
+
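
For orientation, a compact sketch of the correlated DML forms covered above, using the same fixture tables:

    mt = table1.alias()
    u = table1.update(values={table1.c.name:
            select([mt.c.name], mt.c.myid == table1.c.myid)})
    # SET name=(SELECT mytable_1.name FROM mytable AS mytable_1 WHERE ...)

    d = table1.delete(table1.c.name ==
            select([table2.c.othername], table2.c.otherid == table1.c.myid))
    # WHERE mytable.name = (SELECT ... WHERE myothertable.otherid = mytable.myid)
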
class SchemaTest(SQLTest):
def testselect(self):
# these tests will fail with the MS-SQL compiler since it will alias schema-qualified tables
diff --git a/test/sql/selectable.py b/test/sql/selectable.py
index ecd8253b8..dcc855074 100755
--- a/test/sql/selectable.py
+++ b/test/sql/selectable.py
@@ -1,17 +1,13 @@
-"""tests that various From objects properly export their columns, as well as useable primary keys
-and foreign keys. Full relational algebra depends on every selectable unit behaving
-nicely with others.."""
-
+"""tests that various From objects properly export their columns, as well as
+useable primary keys and foreign keys. Full relational algebra depends on
+every selectable unit behaving nicely with others.."""
+
import testbase
-import unittest, sys, datetime
-
-
-db = testbase.db
-
from sqlalchemy import *
+from testlib import *
-
-table = Table('table1', db,
+metadata = MetaData()
+table = Table('table1', metadata,
Column('col1', Integer, primary_key=True),
Column('col2', String(20)),
Column('col3', Integer),
@@ -19,14 +15,14 @@ table = Table('table1', db,
)
-table2 = Table('table2', db,
+table2 = Table('table2', metadata,
Column('col1', Integer, primary_key=True),
Column('col2', Integer, ForeignKey('table1.col1')),
Column('col3', String(20)),
Column('coly', Integer),
)
-class SelectableTest(testbase.AssertMixin):
+class SelectableTest(AssertMixin):
def testdistance(self):
s = select([table.c.col1.label('c2'), table.c.col1, table.c.col1.label('c1')])
@@ -57,7 +53,7 @@ class SelectableTest(testbase.AssertMixin):
jj = select([ table.c.col1.label('bar_col1')],from_obj=[j]).alias('foo')
jjj = join(table, jj, table.c.col1==jj.c.bar_col1)
assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1
-
+
j2 = jjj.alias('foo')
print j2.corresponding_column(jjj.c.table1_col1)
assert j2.corresponding_column(jjj.c.table1_col1) is j2.c.table1_col1
@@ -170,8 +166,9 @@ class SelectableTest(testbase.AssertMixin):
print str(criterion)
print str(j.onclause)
self.assert_(criterion.compare(j.onclause))
+
-class PrimaryKeyTest(testbase.AssertMixin):
+class PrimaryKeyTest(AssertMixin):
def test_join_pk_collapse_implicit(self):
"""test that redundant columns in a join get 'collapsed' into a minimal primary key,
which is the root column along a chain of foreign key relationships."""
@@ -224,8 +221,7 @@ class PrimaryKeyTest(testbase.AssertMixin):
j.foreign_keys
assert list(j.primary_key) == [a.c.id]
-
-
+
if __name__ == "__main__":
testbase.main()
- \ No newline at end of file
+
diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py
index ed9de0912..659033016 100644
--- a/test/sql/testtypes.py
+++ b/test/sql/testtypes.py
@@ -1,14 +1,11 @@
-from testbase import PersistTest, AssertMixin
import testbase
import pickleable
+import datetime, os
from sqlalchemy import *
-import string,datetime, re, sys, os
import sqlalchemy.engine.url as url
-
-import sqlalchemy.types
from sqlalchemy.databases import mssql, oracle, mysql
+from testlib import *
-db = testbase.db
class MyType(types.TypeEngine):
def get_col_spec(self):
@@ -107,7 +104,7 @@ class OverrideTest(PersistTest):
def setUpAll(self):
global users
- users = Table('type_users', db,
+ users = Table('type_users', MetaData(testbase.db),
Column('user_id', Integer, primary_key = True),
            # totally custom type
Column('goofy', MyType, nullable = False),
@@ -138,11 +135,12 @@ class ColumnsTest(AssertMixin):
'float_column': 'float_column NUMERIC(25, 2)'
}
+ db = testbase.db
if not db.name=='sqlite' and not db.name=='oracle':
expectedResults['float_column'] = 'float_column FLOAT(25)'
print db.engine.__module__
- testTable = Table('testColumns', db,
+ testTable = Table('testColumns', MetaData(db),
Column('int_column', Integer),
Column('smallint_column', Smallinteger),
Column('varchar_column', String(20)),
@@ -157,7 +155,8 @@ class UnicodeTest(AssertMixin):
"""tests the Unicode type. also tests the TypeDecorator with instances in the types package."""
def setUpAll(self):
global unicode_table
- unicode_table = Table('unicode_table', db,
+ metadata = MetaData(testbase.db)
+ unicode_table = Table('unicode_table', metadata,
Column('id', Integer, Sequence('uni_id_seq', optional=True), primary_key=True),
Column('unicode_varchar', Unicode(250)),
Column('unicode_text', Unicode),
@@ -175,49 +174,49 @@ class UnicodeTest(AssertMixin):
unicode_text=unicodedata,
plain_varchar=rawdata)
x = unicode_table.select().execute().fetchone()
- self.echo(repr(x['unicode_varchar']))
- self.echo(repr(x['unicode_text']))
- self.echo(repr(x['plain_varchar']))
+ print repr(x['unicode_varchar'])
+ print repr(x['unicode_text'])
+ print repr(x['plain_varchar'])
self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
if isinstance(x['plain_varchar'], unicode):
            # SQLite and MSSQL return non-unicode data as unicode
- self.assert_(db.name in ('sqlite', 'mssql'))
+ self.assert_(testbase.db.name in ('sqlite', 'mssql'))
self.assert_(x['plain_varchar'] == unicodedata)
- self.echo("it's %s!" % db.name)
+ print "it's %s!" % testbase.db.name
else:
self.assert_(not isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == rawdata)
def testengineparam(self):
"""tests engine-wide unicode conversion"""
- prev_unicode = db.engine.dialect.convert_unicode
+ prev_unicode = testbase.db.engine.dialect.convert_unicode
try:
- db.engine.dialect.convert_unicode = True
+ testbase.db.engine.dialect.convert_unicode = True
rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
unicodedata = rawdata.decode('utf-8')
unicode_table.insert().execute(unicode_varchar=unicodedata,
unicode_text=unicodedata,
plain_varchar=rawdata)
x = unicode_table.select().execute().fetchone()
- self.echo(repr(x['unicode_varchar']))
- self.echo(repr(x['unicode_text']))
- self.echo(repr(x['plain_varchar']))
+ print repr(x['unicode_varchar'])
+ print repr(x['unicode_text'])
+ print repr(x['plain_varchar'])
self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
self.assert_(isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == unicodedata)
finally:
- db.engine.dialect.convert_unicode = prev_unicode
+ testbase.db.engine.dialect.convert_unicode = prev_unicode
- @testbase.unsupported('oracle')
+ @testing.unsupported('oracle')
def testlength(self):
"""checks the database correctly understands the length of a unicode string"""
teststr = u'aaa\x1234'
- self.assert_(db.func.length(teststr).scalar() == len(teststr))
+ self.assert_(testbase.db.func.length(teststr).scalar() == len(teststr))
class BinaryTest(AssertMixin):
def setUpAll(self):
global binary_table
- binary_table = Table('binary_table', db,
+ binary_table = Table('binary_table', MetaData(testbase.db),
Column('primary_id', Integer, Sequence('binary_id_seq', optional=True), primary_key=True),
Column('data', Binary),
Column('data_slice', Binary(100)),
@@ -244,39 +243,31 @@ class BinaryTest(AssertMixin):
binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat', data=stream1, data_slice=stream1[0:100], pickled=testobj1)
binary_table.insert().execute(primary_id=2, misc='binary_data_two.dat', data=stream2, data_slice=stream2[0:99], pickled=testobj2)
binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_slice=stream2[0:99], pickled=None)
- l = binary_table.select(order_by=binary_table.c.primary_id).execute().fetchall()
- print type(stream1), type(l[0]['data']), type(l[0]['data_slice'])
- print len(stream1), len(l[0]['data']), len(l[0]['data_slice'])
- self.assert_(list(stream1) == list(l[0]['data']))
- self.assert_(list(stream1[0:100]) == list(l[0]['data_slice']))
- self.assert_(list(stream2) == list(l[1]['data']))
- self.assert_(testobj1 == l[0]['pickled'])
- self.assert_(testobj2 == l[1]['pickled'])
+
+ for stmt in (
+ binary_table.select(order_by=binary_table.c.primary_id),
+ text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType}, bind=testbase.db)
+ ):
+ l = stmt.execute().fetchall()
+ print type(stream1), type(l[0]['data']), type(l[0]['data_slice'])
+ print len(stream1), len(l[0]['data']), len(l[0]['data_slice'])
+ self.assert_(list(stream1) == list(l[0]['data']))
+ self.assert_(list(stream1[0:100]) == list(l[0]['data_slice']))
+ self.assert_(list(stream2) == list(l[1]['data']))
+ self.assert_(testobj1 == l[0]['pickled'])
+ self.assert_(testobj2 == l[1]['pickled'])
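
The loop runs identical assertions against a constructed select and a raw text() statement; the typemap argument is what lets the textual query apply PickleType result processing. Mirroring the patch:

    stmt = text("select * from binary_table order by binary_table.primary_id",
                typemap={'pickled': PickleType}, bind=testbase.db)
    row = stmt.execute().fetchone()
    row['pickled']    # comes back as the unpickled Python object, not raw bytes
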
def load_stream(self, name, len=12579):
f = os.path.join(os.path.dirname(testbase.__file__), name)
# put a number less than the typical MySQL default BLOB size
return file(f).read(len)
- @testbase.supported('oracle')
- def test_oracle_autobinary(self):
- stream1 =self.load_stream('binary_data_one.dat')
- stream2 =self.load_stream('binary_data_two.dat')
- binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat', data=stream1, data_slice=stream1[0:100])
- binary_table.insert().execute(primary_id=2, misc='binary_data_two.dat', data=stream2, data_slice=stream2[0:99])
- binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_slice=stream2[0:99], pickled=None)
- result = testbase.db.connect().execute("select primary_id, misc, data, data_slice from binary_table")
- l = result.fetchall()
- l[0]['data']
- self.assert_(list(stream1) == list(l[0]['data']))
- self.assert_(list(stream1[0:100]) == list(l[0]['data_slice']))
- self.assert_(list(stream2) == list(l[1]['data']))
-
class DateTest(AssertMixin):
def setUpAll(self):
global users_with_date, insert_data
+ db = testbase.db
if db.engine.name == 'oracle':
import sqlalchemy.databases.oracle as oracle
insert_data = [
@@ -314,13 +305,14 @@ class DateTest(AssertMixin):
if db.engine.name == 'mssql':
# MSSQL Datetime values have only a 3.33 milliseconds precision
insert_data[2] = [9, 'foo', datetime.datetime(2005, 11, 10, 11, 52, 35, 547000), datetime.date(1970,4,1), datetime.time(23,59,59,997000)]
-
+
fnames = ['user_id', 'user_name', 'user_datetime', 'user_date', 'user_time']
collist = [Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), Column('user_datetime', DateTime(timezone=False)),
Column('user_date', Date), Column('user_time', Time)]
- users_with_date = Table('query_users_with_date', db, *collist)
+ users_with_date = Table('query_users_with_date',
+ MetaData(testbase.db), *collist)
users_with_date.create()
insert_dicts = [dict(zip(fnames, d)) for d in insert_data]
@@ -338,7 +330,7 @@ class DateTest(AssertMixin):
def testtextdate(self):
- x = db.text("select user_datetime from query_users_with_date", typemap={'user_datetime':DateTime}).execute().fetchall()
+ x = testbase.db.text("select user_datetime from query_users_with_date", typemap={'user_datetime':DateTime}).execute().fetchall()
print repr(x)
self.assert_(isinstance(x[0][0], datetime.datetime))
@@ -347,9 +339,13 @@ class DateTest(AssertMixin):
#print repr(x)
def testdate2(self):
- t = Table('testdate', testbase.metadata, Column('id', Integer, Sequence('datetest_id_seq', optional=True), primary_key=True),
+ meta = MetaData(testbase.db)
+ t = Table('testdate', meta,
+ Column('id', Integer,
+ Sequence('datetest_id_seq', optional=True),
+ primary_key=True),
Column('adate', Date), Column('adatetime', DateTime))
- t.create()
+ t.create(checkfirst=True)
try:
d1 = datetime.date(2007, 10, 30)
t.insert().execute(adate=d1, adatetime=d1)
@@ -361,8 +357,43 @@ class DateTest(AssertMixin):
self.assert_(x.adatetime.__class__ == datetime.datetime)
finally:
- t.drop()
+ t.drop(checkfirst=True)
+class NumericTest(AssertMixin):
+ def setUpAll(self):
+ global numeric_table, metadata
+ metadata = MetaData(testbase.db)
+ numeric_table = Table('numeric_table', metadata,
+ Column('id', Integer, Sequence('numeric_id_seq', optional=True), primary_key=True),
+ Column('numericcol', Numeric(asdecimal=False)),
+ Column('floatcol', Float),
+ Column('ncasdec', Numeric),
+ Column('fcasdec', Float(asdecimal=True))
+ )
+ metadata.create_all()
+
+ def tearDownAll(self):
+ metadata.drop_all()
+
+ def tearDown(self):
+ numeric_table.delete().execute()
+
+ def test_decimal(self):
+ from decimal import Decimal
+ numeric_table.insert().execute(numericcol=3.5, floatcol=5.6, ncasdec=12.4, fcasdec=15.78)
+ numeric_table.insert().execute(numericcol=Decimal("3.5"), floatcol=Decimal("5.6"), ncasdec=Decimal("12.4"), fcasdec=Decimal("15.78"))
+ l = numeric_table.select().execute().fetchall()
+ print l
+ rounded = [
+ (l[0][0], l[0][1], round(l[0][2], 5), l[0][3], l[0][4]),
+ (l[1][0], l[1][1], round(l[1][2], 5), l[1][3], l[1][4]),
+ ]
+ assert rounded == [
+ (1, 3.5, 5.6, Decimal("12.4"), Decimal("15.78")),
+ (2, 3.5, 5.6, Decimal("12.4"), Decimal("15.78")),
+ ]
+
+
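
A summary of the asdecimal behavior this test pins down (illustrative comments only):

    Numeric()                   # results come back as decimal.Decimal
    Numeric(asdecimal=False)    # results come back as plain floats
    Float()                     # plain floats
    Float(asdecimal=True)       # decimal.Decimal
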
class IntervalTest(AssertMixin):
def setUpAll(self):
global interval_table, metadata
diff --git a/test/sql/unicode.py b/test/sql/unicode.py
index 7ce42bf4c..f882c2a5f 100644
--- a/test/sql/unicode.py
+++ b/test/sql/unicode.py
@@ -1,23 +1,27 @@
# coding: utf-8
-import testbase
+"""verrrrry basic unicode column name testing"""
+import testbase
from sqlalchemy import *
+from sqlalchemy.orm import mapper, relation, create_session, eagerload
+from testlib import *
-"""verrrrry basic unicode column name testing"""
-class UnicodeSchemaTest(testbase.PersistTest):
+class UnicodeSchemaTest(PersistTest):
def setUpAll(self):
- global metadata, t1, t2
- metadata = MetaData(engine=testbase.db)
+ global unicode_bind, metadata, t1, t2
+
+ unicode_bind = self._unicode_bind()
+
+ metadata = MetaData(unicode_bind)
t1 = Table('unitable1', metadata,
Column(u'méil', Integer, primary_key=True),
- Column(u'éXXm', Integer),
+ Column(u'\u6e2c\u8a66', Integer),
)
- t2 = Table(u'unitéble2', metadata,
+ t2 = Table(u'Unitéble2', metadata,
Column(u'méil', Integer, primary_key=True, key="a"),
- Column(u'éXXm', Integer, ForeignKey(u'unitable1.méil'), key="b"),
-
+ Column(u'\u6e2c\u8a66', Integer, ForeignKey(u'unitable1.méil'), key="b"),
)
metadata.create_all()
@@ -26,24 +30,46 @@ class UnicodeSchemaTest(testbase.PersistTest):
t1.delete().execute()
def tearDownAll(self):
+ global unicode_bind
metadata.drop_all()
+ del unicode_bind
+
+ def _unicode_bind(self):
+ if testbase.db.name != 'mysql':
+ return testbase.db
+ else:
+ # most mysql installations don't default to utf8 connections
+ version = testbase.db.dialect.get_version_info(testbase.db)
+ if version < (4, 1):
+ raise AssertionError("Unicode not supported on MySQL < 4.1")
+
+ c = testbase.db.connect()
+ if not hasattr(c.connection.connection, 'set_character_set'):
+ raise AssertionError(
+ "Unicode not supported on this MySQL-python version")
+ else:
+ c.connection.set_character_set('utf8')
+ c.detach()
+
+ return c
def test_insert(self):
- t1.insert().execute({u'méil':1, u'éXXm':5})
+ t1.insert().execute({u'méil':1, u'\u6e2c\u8a66':5})
t2.insert().execute({'a':1, 'b':1})
assert t1.select().execute().fetchall() == [(1, 5)]
assert t2.select().execute().fetchall() == [(1, 1)]
def test_reflect(self):
- t1.insert().execute({u'méil':2, u'éXXm':7})
+ t1.insert().execute({u'méil':2, u'\u6e2c\u8a66':7})
t2.insert().execute({'a':2, 'b':2})
- meta = MetaData(testbase.db)
+ meta = MetaData(unicode_bind)
tt1 = Table(t1.name, meta, autoload=True)
tt2 = Table(t2.name, meta, autoload=True)
- tt1.insert().execute({u'méil':1, u'éXXm':5})
- tt2.insert().execute({u'méil':1, u'éXXm':1})
+
+ tt1.insert().execute({u'méil':1, u'\u6e2c\u8a66':5})
+ tt2.insert().execute({u'méil':1, u'\u6e2c\u8a66':1})
assert tt1.select(order_by=desc(u'méil')).execute().fetchall() == [(2, 7), (1, 5)]
assert tt2.select(order_by=desc(u'méil')).execute().fetchall() == [(2, 2), (1, 1)]
@@ -57,7 +83,7 @@ class UnicodeSchemaTest(testbase.PersistTest):
mapper(A, t1, properties={
't2s':relation(B),
'a':t1.c[u'méil'],
- 'b':t1.c[u'éXXm']
+ 'b':t1.c[u'\u6e2c\u8a66']
})
mapper(B, t2)
sess = create_session()
diff --git a/test/testbase.py b/test/testbase.py
index 7c5095d1a..1195db340 100644
--- a/test/testbase.py
+++ b/test/testbase.py
@@ -1,470 +1,14 @@
-import sys
-import os, unittest, StringIO, re, ConfigParser
-sys.path.insert(0, os.path.join(os.getcwd(), 'lib'))
-import sqlalchemy
-from sqlalchemy import sql, engine, pool
-import sqlalchemy.engine.base as base
-import optparse
-from sqlalchemy.schema import MetaData
-from sqlalchemy.orm import clear_mappers
-
-db = None
-metadata = None
-db_uri = None
-echo = True
-
-# redefine sys.stdout so all those print statements go to the echo func
-local_stdout = sys.stdout
-class Logger(object):
- def write(self, msg):
- if echo:
- local_stdout.write(msg)
- def flush(self):
- pass
-
-def echo_text(text):
- print text
-
-def parse_argv():
- # we are using the unittest main runner, so we are just popping out the
- # arguments we need instead of using our own getopt type of thing
- global db, db_uri, metadata
-
- DBTYPE = 'sqlite'
- PROXY = False
-
- base_config = """
-[db]
-sqlite=sqlite:///:memory:
-sqlite_file=sqlite:///querytest.db
-postgres=postgres://scott:tiger@127.0.0.1:5432/test
-mysql=mysql://scott:tiger@127.0.0.1:3306/test
-oracle=oracle://scott:tiger@127.0.0.1:1521
-oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
-mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test
-firebird=firebird://sysdba:s@localhost/tmp/test.fdb
-"""
- config = ConfigParser.ConfigParser()
- config.readfp(StringIO.StringIO(base_config))
- config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
-
- parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]")
- parser.add_option("--dburi", action="store", dest="dburi", help="database uri (overrides --db)")
- parser.add_option("--db", action="store", dest="db", default="sqlite", help="prefab database uri (%s)" % ', '.join(config.options('db')))
- parser.add_option("--mockpool", action="store_true", dest="mockpool", help="use mock pool (asserts only one connection used)")
- parser.add_option("--verbose", action="store_true", dest="verbose", help="enable stdout echoing/printing")
- parser.add_option("--quiet", action="store_true", dest="quiet", help="suppress unittest output")
- parser.add_option("--log-info", action="append", dest="log_info", help="turn on info logging for <LOG> (multiple OK)")
- parser.add_option("--log-debug", action="append", dest="log_debug", help="turn on debug logging for <LOG> (multiple OK)")
- parser.add_option("--nothreadlocal", action="store_true", dest="nothreadlocal", help="dont use thread-local mod")
- parser.add_option("--enginestrategy", action="store", default=None, dest="enginestrategy", help="engine strategy (plain or threadlocal, defaults to plain)")
- parser.add_option("--coverage", action="store_true", dest="coverage", help="Dump a full coverage report after running")
- parser.add_option("--reversetop", action="store_true", dest="topological", help="Reverse the collection ordering for topological sorts (helps reveal dependency issues)")
- parser.add_option("--serverside", action="store_true", dest="serverside", help="Turn on server side cursors for PG")
- parser.add_option("--require", action="append", dest="require", help="Require a particular driver or module version", default=[])
-
- (options, args) = parser.parse_args()
- sys.argv[1:] = args
-
- if options.dburi:
- db_uri = param = options.dburi
- DBTYPE = db_uri[:db_uri.index(':')]
- elif options.db:
- DBTYPE = param = options.db
-
- if options.require or (config.has_section('require') and
- config.items('require')):
- try:
- import pkg_resources
- except ImportError:
- raise "setuptools is required for version requirements"
-
- cmdline = []
- for requirement in options.require:
- pkg_resources.require(requirement)
- cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0])
-
- if config.has_section('require'):
- for label, requirement in config.items('require'):
- if not label == DBTYPE or label.startswith('%s.' % DBTYPE):
- continue
- seen = [c for c in cmdline if requirement.startswith(c)]
- if seen:
- continue
- pkg_resources.require(requirement)
-
- opts = {}
- if (None == db_uri):
- if DBTYPE not in config.options('db'):
- raise ("Could not create engine. specify --db <%s> to "
- "test runner." % '|'.join(config.options('db')))
-
- db_uri = config.get('db', DBTYPE)
-
- if not db_uri:
- raise "Could not create engine. specify --db <sqlite|sqlite_file|postgres|mysql|oracle|oracle8|mssql|firebird> to test runner."
-
- if not options.nothreadlocal:
- __import__('sqlalchemy.mods.threadlocal')
- sqlalchemy.mods.threadlocal.uninstall_plugin()
-
- global echo
- echo = options.verbose and not options.quiet
-
- global quiet
- quiet = options.quiet
-
- global with_coverage
- with_coverage = options.coverage
-
- if options.serverside:
- opts['server_side_cursors'] = True
-
- if options.enginestrategy is not None:
- opts['strategy'] = options.enginestrategy
- if options.mockpool:
- db = engine.create_engine(db_uri, poolclass=pool.AssertionPool, **opts)
- else:
- db = engine.create_engine(db_uri, **opts)
-
- # decorate the dialect's create_execution_context() method
- # to produce a wrapper
- create_context = db.dialect.create_execution_context
- def create_exec_context(*args, **kwargs):
- return ExecutionContextWrapper(create_context(*args, **kwargs))
- db.dialect.create_execution_context = create_exec_context
-
- global testdata
- testdata = TestData(db)
-
- if options.topological:
- from sqlalchemy.orm import unitofwork
- from sqlalchemy import topological
- class RevQueueDepSort(topological.QueueDependencySorter):
- def __init__(self, tuples, allitems):
- self.tuples = list(tuples)
- self.allitems = list(allitems)
- self.tuples.reverse()
- self.allitems.reverse()
- topological.QueueDependencySorter = RevQueueDepSort
- unitofwork.DependencySorter = RevQueueDepSort
-
- import logging
- logging.basicConfig()
- if options.log_info is not None:
- for elem in options.log_info:
- logging.getLogger(elem).setLevel(logging.INFO)
- if options.log_debug is not None:
- for elem in options.log_debug:
- logging.getLogger(elem).setLevel(logging.DEBUG)
- metadata = sqlalchemy.MetaData(db)
-
-def unsupported(*dbs):
- """a decorator that marks a test as unsupported by one or more database implementations"""
- def decorate(func):
- name = db.name
- for d in dbs:
- if d == name:
- def lala(self):
- echo_text("'" + func.__name__ + "' unsupported on DB implementation '" + name + "'")
- lala.__name__ = func.__name__
- return lala
- else:
- return func
- return decorate
-
-def supported(*dbs):
- """a decorator that marks a test as supported by one or more database implementations"""
- def decorate(func):
- name = db.name
- for d in dbs:
- if d == name:
- return func
- else:
- def lala(self):
- echo_text("'" + func.__name__ + "' unsupported on DB implementation '" + name + "'")
- lala.__name__ = func.__name__
- return lala
- return decorate
+"""First import for all test cases, sets sys.path and loads configuration."""
-
-class PersistTest(unittest.TestCase):
- """persist base class, provides default setUpAll, tearDownAll and echo functionality"""
- def __init__(self, *args, **params):
- unittest.TestCase.__init__(self, *args, **params)
- def echo(self, text):
- echo_text(text)
- def install_threadlocal(self):
- sqlalchemy.mods.threadlocal.install_plugin()
- def uninstall_threadlocal(self):
- sqlalchemy.mods.threadlocal.uninstall_plugin()
- def setUpAll(self):
- pass
- def tearDownAll(self):
- pass
- def shortDescription(self):
- """overridden to not return docstrings"""
- return None
+__all__ = 'db',
-class AssertMixin(PersistTest):
- """given a list-based structure of keys/properties which represent information within an object structure, and
- a list of actual objects, asserts that the list of objects corresponds to the structure."""
- def assert_result(self, result, class_, *objects):
- result = list(result)
- if echo:
- print repr(result)
- self.assert_list(result, class_, objects)
- def assert_list(self, result, class_, list):
- self.assert_(len(result) == len(list), "result list is not the same size as test list, for class " + class_.__name__)
- for i in range(0, len(list)):
- self.assert_row(class_, result[i], list[i])
- def assert_row(self, class_, rowobj, desc):
- self.assert_(rowobj.__class__ is class_, "item class is not " + repr(class_))
- for key, value in desc.iteritems():
- if isinstance(value, tuple):
- if isinstance(value[1], list):
- self.assert_list(getattr(rowobj, key), value[0], value[1])
- else:
- self.assert_row(value[0], getattr(rowobj, key), value[1])
- else:
- self.assert_(getattr(rowobj, key) == value, "attribute %s value %s does not match %s" % (key, getattr(rowobj, key), value))
- def assert_sql(self, db, callable_, list, with_sequences=None):
- global testdata
- testdata = TestData(db)
- if with_sequences is not None and (db.engine.name == 'postgres' or db.engine.name == 'oracle'):
- testdata.set_assert_list(self, with_sequences)
- else:
- testdata.set_assert_list(self, list)
- try:
- callable_()
- finally:
- testdata.set_assert_list(None, None)
-
- def assert_sql_count(self, db, callable_, count):
- global testdata
- testdata = TestData(db)
- try:
- callable_()
- finally:
- self.assert_(testdata.sql_count == count, "desired statement count %d does not match %d" % (count, testdata.sql_count))
-
- def capture_sql(self, db, callable_):
- global testdata
- testdata = TestData(db)
- buffer = StringIO.StringIO()
- testdata.buffer = buffer
- try:
- callable_()
- return buffer.getvalue()
- finally:
- testdata.buffer = None
-
-class ORMTest(AssertMixin):
- keep_mappers = False
- keep_data = False
- def setUpAll(self):
- global metadata
- metadata = MetaData(db)
- self.define_tables(metadata)
- metadata.create_all()
- self.insert_data()
- def define_tables(self, metadata):
- raise NotImplementedError()
- def insert_data(self):
- pass
- def get_metadata(self):
- return metadata
- def tearDownAll(self):
- metadata.drop_all()
- def tearDown(self):
- if not self.keep_mappers:
- clear_mappers()
- if not self.keep_data:
- for t in metadata.table_iterator(reverse=True):
- t.delete().execute().close()
-
-class TestData(object):
- def __init__(self, engine):
- self._engine = engine
- self.logger = engine.logger
- self.set_assert_list(None, None)
- self.sql_count = 0
- self.buffer = None
-
- def set_assert_list(self, unittest, list):
- self.unittest = unittest
- self.assert_list = list
- if list is not None:
- self.assert_list.reverse()
-
-class ExecutionContextWrapper(object):
- def __init__(self, ctx):
- self.__dict__['ctx'] = ctx
- def __getattr__(self, key):
- return getattr(self.ctx, key)
- def __setattr__(self, key, value):
- setattr(self.ctx, key, value)
-
- def post_exec(self):
- ctx = self.ctx
- statement = unicode(ctx.compiled)
- statement = re.sub(r'\n', '', ctx.statement)
- if testdata.buffer is not None:
- testdata.buffer.write(statement + "\n")
-
- if testdata.assert_list is not None:
- item = testdata.assert_list[-1]
- if not isinstance(item, dict):
- item = testdata.assert_list.pop()
- else:
- # asserting a dictionary of statements->parameters
- # this is to specify query assertions where the queries can be in
- # multiple orderings
- if not item.has_key('_converted'):
- for key in item.keys():
- ckey = self.convert_statement(key)
- item[ckey] = item[key]
- if ckey != key:
- del item[key]
- item['_converted'] = True
- try:
- entry = item.pop(statement)
- if len(item) == 1:
- testdata.assert_list.pop()
- item = (statement, entry)
- except KeyError:
- self.unittest.assert_(False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement))
-
- (query, params) = item
- if callable(params):
- params = params(ctx)
- if params is not None and isinstance(params, list) and len(params) == 1:
- params = params[0]
-
- if isinstance(ctx.compiled_parameters, sql.ClauseParameters):
- parameters = ctx.compiled_parameters.get_original_dict()
- elif isinstance(ctx.compiled_parameters, list):
- parameters = [p.get_original_dict() for p in ctx.compiled_parameters]
-
- query = self.convert_statement(query)
- if db.engine.name == 'mssql' and statement.endswith('; select scope_identity()'):
- statement = statement[:-25]
- testdata.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
- testdata.sql_count += 1
- self.ctx.post_exec()
-
- def convert_statement(self, query):
- paramstyle = self.ctx.dialect.paramstyle
- if paramstyle == 'named':
- pass
- elif paramstyle =='pyformat':
- query = re.sub(r':([\w_]+)', r"%(\1)s", query)
- else:
- # positional params
- repl = None
- if paramstyle=='qmark':
- repl = "?"
- elif paramstyle=='format':
- repl = r"%s"
- elif paramstyle=='numeric':
- repl = None
- query = re.sub(r':([\w_]+)', repl, query)
- return query
-
-class TTestSuite(unittest.TestSuite):
- """override unittest.TestSuite to provide per-TestCase class setUpAll() and tearDownAll() functionality"""
- def __init__(self, tests=()):
- if len(tests) >0 and isinstance(tests[0], PersistTest):
- self._initTest = tests[0]
- else:
- self._initTest = None
- unittest.TestSuite.__init__(self, tests)
-
- def do_run(self, result):
- """nice job unittest ! you switched __call__ and run() between py2.3 and 2.4 thereby
- making straight subclassing impossible !"""
- for test in self._tests:
- if result.shouldStop:
- break
- test(result)
- return result
-
- def run(self, result):
- return self(result)
-
- def __call__(self, result):
- try:
- if self._initTest is not None:
- self._initTest.setUpAll()
- except:
- result.addError(self._initTest, self.__exc_info())
- pass
- try:
- return self.do_run(result)
- finally:
- try:
- if self._initTest is not None:
- self._initTest.tearDownAll()
- except:
- result.addError(self._initTest, self.__exc_info())
- pass
-
- def __exc_info(self):
- """Return a version of sys.exc_info() with the traceback frame
- minimised; usually the top level of the traceback frame is not
- needed.
- ripped off out of unittest module since its double __
- """
- exctype, excvalue, tb = sys.exc_info()
- if sys.platform[:4] == 'java': ## tracebacks look different in Jython
- return (exctype, excvalue, tb)
- return (exctype, excvalue, tb)
-
-unittest.TestLoader.suiteClass = TTestSuite
-
-parse_argv()
-
-
-def runTests(suite):
- sys.stdout = Logger()
- runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2)
- if with_coverage:
- return cover(lambda:runner.run(suite))
- else:
- return runner.run(suite)
-
-def covered_files():
- for rec in os.walk(os.path.dirname(sqlalchemy.__file__)):
- for x in rec[2]:
- if x.endswith('.py'):
- yield os.path.join(rec[0], x)
-
-def cover(callable_):
- import coverage
- coverage_client = coverage.the_coverage
- coverage_client.get_ready()
- coverage_client.exclude('#pragma[: ]+[nN][oO] [cC][oO][vV][eE][rR]')
- coverage_client.erase()
- coverage_client.start()
- try:
- return callable_()
- finally:
- global echo
- echo=True
- coverage_client.stop()
- coverage_client.save()
- coverage_client.report(list(covered_files()), show_missing=False, ignore_errors=False)
-
-def main(suite=None):
-
- if not suite:
- if len(sys.argv[1:]):
- suite =unittest.TestLoader().loadTestsFromNames(sys.argv[1:], __import__('__main__'))
- else:
- suite = unittest.TestLoader().loadTestsFromModule(__import__('__main__'))
-
- result = runTests(suite)
- sys.exit(not result.wasSuccessful())
+import sys, os, logging
+sys.path.insert(0, os.path.join(os.getcwd(), 'lib'))
+logging.basicConfig()
+import testlib.config
+testlib.config.configure()
+from testlib.testing import main
+db = testlib.config.db
diff --git a/test/testlib/__init__.py b/test/testlib/__init__.py
new file mode 100644
index 000000000..ff5c4c125
--- /dev/null
+++ b/test/testlib/__init__.py
@@ -0,0 +1,11 @@
+"""Enhance unittest and instrument SQLAlchemy classes for testing.
+
+Load after sqlalchemy imports to use instrumented stand-ins like Table.
+"""
+
+import testlib.config
+from testlib.schema import Table, Column
+import testlib.testing as testing
+from testlib.testing import PersistTest, AssertMixin, ORMTest
+import testlib.profiling
+
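
Illustrative of the intended import order under the new layout, mirroring what the updated test modules do:

    import testbase            # sets sys.path, parses options, creates the engine
    from sqlalchemy import *
    from testlib import *      # instrumented Table/Column plus PersistTest, AssertMixin, ORMTest
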
diff --git a/test/testlib/config.py b/test/testlib/config.py
new file mode 100644
index 000000000..f05cda46d
--- /dev/null
+++ b/test/testlib/config.py
@@ -0,0 +1,255 @@
+import optparse, os, re, sys, ConfigParser, StringIO
+logging, require = None, None
+
+__all__ = 'parser', 'configure', 'options',
+
+db, db_uri, db_type, db_label = None, None, None, None
+
+options = None
+file_config = None
+
+base_config = """
+[db]
+sqlite=sqlite:///:memory:
+sqlite_file=sqlite:///querytest.db
+postgres=postgres://scott:tiger@127.0.0.1:5432/test
+mysql=mysql://scott:tiger@127.0.0.1:3306/test
+oracle=oracle://scott:tiger@127.0.0.1:1521
+oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
+mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test
+firebird=firebird://sysdba:s@localhost/tmp/test.fdb
+"""
+
+parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]")
+
+def configure():
+    global options, file_config
+
+ file_config = ConfigParser.ConfigParser()
+ file_config.readfp(StringIO.StringIO(base_config))
+ file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
+
+ # Opt parsing can fire immediate actions, like logging and coverage
+ (options, args) = parser.parse_args()
+ sys.argv[1:] = args
+
+ # Lazy setup of other options (post coverage)
+ for fn in post_configure:
+ fn(options, file_config)
+
+ return options, file_config
+
+def _log(option, opt_str, value, parser):
+ global logging
+ if not logging:
+ import logging
+ logging.basicConfig()
+
+ if opt_str.endswith('-info'):
+ logging.getLogger(value).setLevel(logging.INFO)
+ elif opt_str.endswith('-debug'):
+ logging.getLogger(value).setLevel(logging.DEBUG)
+
+def _start_coverage(option, opt_str, value, parser):
+ import sys, atexit, coverage
+ true_out = sys.stdout
+
+ def _iter_covered_files():
+ import sqlalchemy
+ for rec in os.walk(os.path.dirname(sqlalchemy.__file__)):
+ for x in rec[2]:
+ if x.endswith('.py'):
+ yield os.path.join(rec[0], x)
+ def _stop():
+ coverage.stop()
+ true_out.write("\nPreparing coverage report...\n")
+ coverage.report(list(_iter_covered_files()),
+ show_missing=False, ignore_errors=False,
+ file=true_out)
+ atexit.register(_stop)
+ coverage.erase()
+ coverage.start()
+
+def _list_dbs(*args):
+ print "Available --db options (use --dburi to override)"
+ for macro in sorted(file_config.options('db')):
+ print "%20s\t%s" % (macro, file_config.get('db', macro))
+ sys.exit(0)
+
+opt = parser.add_option
+opt("--verbose", action="store_true", dest="verbose",
+ help="enable stdout echoing/printing")
+opt("--quiet", action="store_true", dest="quiet", help="suppress output")
+opt("--log-info", action="callback", type="string", callback=_log,
+ help="turn on info logging for <LOG> (multiple OK)")
+opt("--log-debug", action="callback", type="string", callback=_log,
+ help="turn on debug logging for <LOG> (multiple OK)")
+opt("--require", action="append", dest="require", default=[],
+ help="require a particular driver or module version (multiple OK)")
+opt("--db", action="store", dest="db", default="sqlite",
+ help="Use prefab database uri")
+opt('--dbs', action='callback', callback=_list_dbs,
+ help="List available prefab dbs")
+opt("--dburi", action="store", dest="dburi",
+ help="Database uri (overrides --db)")
+opt("--mockpool", action="store_true", dest="mockpool",
+ help="Use mock pool (asserts only one connection used)")
+opt("--enginestrategy", action="store", dest="enginestrategy", default=None,
+ help="Engine strategy (plain or threadlocal, defaults toplain)")
+opt("--reversetop", action="store_true", dest="reversetop", default=False,
+ help="Reverse the collection ordering for topological sorts (helps "
+ "reveal dependency issues)")
+opt("--serverside", action="store_true", dest="serverside",
+ help="Turn on server side cursors for PG")
+opt("--mysql-engine", action="store", dest="mysql_engine", default=None,
+ help="Use the specified MySQL storage engine for all tables, default is "
+ "a db-default/InnoDB combo.")
+opt("--table-option", action="append", dest="tableopts", default=[],
+ help="Add a dialect-specific table option, key=value")
+opt("--coverage", action="callback", callback=_start_coverage,
+ help="Dump a full coverage report after running tests")
+opt("--profile", action="append", dest="profile_targets", default=[],
+ help="Enable a named profile target (multiple OK.)")
+opt("--profile-sort", action="store", dest="profile_sort", default=None,
+ help="Sort profile stats with this comma-separated sort order")
+opt("--profile-limit", type="int", action="store", dest="profile_limit",
+ default=None,
+ help="Limit function count in profile stats")
+
+class _ordered_map(object):
+ def __init__(self):
+ self._keys = list()
+ self._data = dict()
+
+ def __setitem__(self, key, value):
+ if key not in self._keys:
+ self._keys.append(key)
+ self._data[key] = value
+
+ def __iter__(self):
+ for key in self._keys:
+ yield self._data[key]
+
+post_configure = _ordered_map()
+
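
post_configure uses the small _ordered_map above instead of a dict because the hooks must run in registration order (engine URI resolution before engine creation, and so on), and dicts of this Python era do not preserve insertion order. An illustrative use, with hypothetical fn_a/fn_b/other_fn:

    hooks = _ordered_map()
    hooks['a'] = fn_a          # runs first
    hooks['b'] = fn_b          # runs second
    hooks['a'] = other_fn      # re-assignment keeps 'a' in its original slot
    for fn in hooks:           # iterates in first-registration order
        fn(options, file_config)
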
+def _engine_uri(options, file_config):
+ global db_label, db_uri
+ db_label = 'sqlite'
+ if options.dburi:
+ db_uri = options.dburi
+ db_label = db_uri[:db_uri.index(':')]
+ elif options.db:
+ db_label = options.db
+ db_uri = None
+
+ if db_uri is None:
+ if db_label not in file_config.options('db'):
+ raise RuntimeError(
+ "Unknown engine. Specify --dbs for known engines.")
+ db_uri = file_config.get('db', db_label)
+post_configure['engine_uri'] = _engine_uri
+
+def _require(options, file_config):
+ if not(options.require or
+ (file_config.has_section('require') and
+ file_config.items('require'))):
+ return
+
+ try:
+ import pkg_resources
+ except ImportError:
+ raise RuntimeError("setuptools is required for version requirements")
+
+ cmdline = []
+ for requirement in options.require:
+ pkg_resources.require(requirement)
+ cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0])
+
+ if file_config.has_section('require'):
+ for label, requirement in file_config.items('require'):
+ if not label == db_label or label.startswith('%s.' % db_label):
+ continue
+ seen = [c for c in cmdline if requirement.startswith(c)]
+ if seen:
+ continue
+ pkg_resources.require(requirement)
+post_configure['require'] = _require
+
+def _create_testing_engine(options, file_config):
+    from sqlalchemy import engine, pool
+ global db, db_type
+ engine_opts = {}
+ if options.serverside:
+ engine_opts['server_side_cursors'] = True
+
+ if options.enginestrategy is not None:
+ engine_opts['strategy'] = options.enginestrategy
+
+ if options.mockpool:
+ db = engine.create_engine(db_uri, poolclass=pool.AssertionPool,
+ **engine_opts)
+ else:
+ db = engine.create_engine(db_uri, **engine_opts)
+ db_type = db.name
+
+ # decorate the dialect's create_execution_context() method
+ # to produce a wrapper
+ from testlib.testing import ExecutionContextWrapper
+
+ create_context = db.dialect.create_execution_context
+ def create_exec_context(*args, **kwargs):
+ return ExecutionContextWrapper(create_context(*args, **kwargs))
+ db.dialect.create_execution_context = create_exec_context
+post_configure['create_engine'] = _create_testing_engine
+
+def _set_table_options(options, file_config):
+ import testlib.schema
+
+ table_options = testlib.schema.table_options
+ for spec in options.tableopts:
+ key, value = spec.split('=')
+ table_options[key] = value
+
+ if options.mysql_engine:
+ table_options['mysql_engine'] = options.mysql_engine
+post_configure['table_options'] = _set_table_options
+
+def _reverse_topological(options, file_config):
+ if options.reversetop:
+ from sqlalchemy.orm import unitofwork
+ from sqlalchemy import topological
+ class RevQueueDepSort(topological.QueueDependencySorter):
+ def __init__(self, tuples, allitems):
+ self.tuples = list(tuples)
+ self.allitems = list(allitems)
+ self.tuples.reverse()
+ self.allitems.reverse()
+ topological.QueueDependencySorter = RevQueueDepSort
+ unitofwork.DependencySorter = RevQueueDepSort
+post_configure['topological'] = _reverse_topological
+
+def _set_profile_targets(options, file_config):
+ from testlib import profiling
+
+ profile_config = profiling.profile_config
+
+ for target in options.profile_targets:
+ profile_config['targets'].add(target)
+
+ if options.profile_sort:
+ profile_config['sort'] = options.profile_sort.split(',')
+
+ if options.profile_limit:
+ profile_config['limit'] = options.profile_limit
+
+ if options.quiet:
+ profile_config['report'] = False
+
+ # magic "all" target
+ if 'all' in profiling.all_targets:
+ targets = profile_config['targets']
+ if 'all' in targets and len(targets) != 1:
+ targets.clear()
+ targets.add('all')
+post_configure['profile_targets'] = _set_profile_targets
diff --git a/test/coverage.py b/test/testlib/coverage.py
index 66e55e0c4..0203dbf7d 100644
--- a/test/coverage.py
+++ b/test/testlib/coverage.py
@@ -22,7 +22,8 @@
# interface and limitations. See [GDR 2001-12-04b] for requirements and
# design.
-r"""Usage:
+r"""\
+Usage:
coverage.py -x [-p] MODULE.py [ARG1 ARG2 ...]
Execute module, passing the given command-line arguments, collecting
@@ -54,18 +55,27 @@ coverage.py -a [-d dir] [-o dir1,dir2,...] FILE1 FILE2 ...
Coverage data is saved in the file .coverage by default. Set the
COVERAGE_FILE environment variable to save it somewhere else."""
-__version__ = "2.6.20060823" # see detailed history at the end of this file.
+__version__ = "2.75.20070722" # see detailed history at the end of this file.
import compiler
import compiler.visitor
+import glob
import os
import re
import string
+import symbol
import sys
import threading
+import token
import types
from socket import gethostname
+# Python version compatibility
+try:
+ strclass = basestring # new to 2.3
+except:
+ strclass = str
+
# 2. IMPLEMENTATION
#
# This uses the "singleton" pattern.
@@ -87,6 +97,9 @@ from socket import gethostname
# names to increase speed.
class StatementFindingAstVisitor(compiler.visitor.ASTVisitor):
+ """ A visitor for a parsed Abstract Syntax Tree which finds executable
+ statements.
+ """
def __init__(self, statements, excluded, suite_spots):
compiler.visitor.ASTVisitor.__init__(self)
self.statements = statements
@@ -95,7 +108,6 @@ class StatementFindingAstVisitor(compiler.visitor.ASTVisitor):
self.excluding_suite = 0
def doRecursive(self, node):
- self.recordNodeLine(node)
for n in node.getChildNodes():
self.dispatch(n)
@@ -131,12 +143,35 @@ class StatementFindingAstVisitor(compiler.visitor.ASTVisitor):
def doStatement(self, node):
self.recordLine(self.getFirstLine(node))
- visitAssert = visitAssign = visitAssTuple = visitDiscard = visitPrint = \
+ visitAssert = visitAssign = visitAssTuple = visitPrint = \
visitPrintnl = visitRaise = visitSubscript = visitDecorators = \
doStatement
+ def visitPass(self, node):
+ # Pass statements have weird interactions with docstrings. If this
+ # pass statement is part of one of those pairs, claim that the statement
+ # is on the later of the two lines.
+ l = node.lineno
+ if l:
+ lines = self.suite_spots.get(l, [l,l])
+ self.statements[lines[1]] = 1
+
+ def visitDiscard(self, node):
+ # Discard nodes are statements that execute an expression, but then
+ # discard the results. This includes function calls, so we can't
+ # ignore them all. But if the expression is a constant, the statement
+ # won't be "executed", so don't count it now.
+ if node.expr.__class__.__name__ != 'Const':
+ self.doStatement(node)
+
def recordNodeLine(self, node):
- return self.recordLine(node.lineno)
+ # Stmt nodes often have a lineno of None, and shouldn't claim the
+ # first line of their children (because the first child might be an
+ # ignorable line like "global a").
+ if node.__class__.__name__ != 'Stmt':
+ return self.recordLine(self.getFirstLine(node))
+ else:
+ return 0
def recordLine(self, lineno):
# Returns a bool, whether the line is included or excluded.
@@ -145,7 +180,7 @@ class StatementFindingAstVisitor(compiler.visitor.ASTVisitor):
# keyword.
if lineno in self.suite_spots:
lineno = self.suite_spots[lineno][0]
- # If we're inside an exluded suite, record that this line was
+ # If we're inside an excluded suite, record that this line was
# excluded.
if self.excluding_suite:
self.excluded[lineno] = 1
@@ -197,6 +232,8 @@ class StatementFindingAstVisitor(compiler.visitor.ASTVisitor):
self.doSuite(node, node.body)
self.doElse(node.body, node)
+ visitWhile = visitFor
+
def visitIf(self, node):
# The first test has to be handled separately from the rest.
# The first test is credited to the line with the "if", but the others
@@ -206,10 +243,6 @@ class StatementFindingAstVisitor(compiler.visitor.ASTVisitor):
self.doSuite(t, n)
self.doElse(node.tests[-1][1], node)
- def visitWhile(self, node):
- self.doSuite(node, node.body)
- self.doElse(node.body, node)
-
def visitTryExcept(self, node):
self.doSuite(node, node.body)
for i in range(len(node.handlers)):
@@ -268,11 +301,13 @@ class coverage:
raise CoverageException, "Only one coverage object allowed."
self.usecache = 1
self.cache = None
+ self.parallel_mode = False
self.exclude_re = ''
self.nesting = 0
self.cstack = []
self.xstack = []
- self.relative_dir = os.path.normcase(os.path.abspath(os.curdir)+os.path.sep)
+ self.relative_dir = os.path.normcase(os.path.abspath(os.curdir)+os.sep)
+ self.exclude('# *pragma[: ]*[nN][oO] *[cC][oO][vV][eE][rR]')
# t(f, x, y). This method is passed to sys.settrace as a trace function.
# See [van Rossum 2001-07-20b, 9.2] for an explanation of sys.settrace and
@@ -280,23 +315,24 @@ class coverage:
# See [van Rossum 2001-07-20a, 3.2] for a description of frame and code
# objects.
- def t(self, f, w, a): #pragma: no cover
+ def t(self, f, w, unused): #pragma: no cover
if w == 'line':
+ #print "Executing %s @ %d" % (f.f_code.co_filename, f.f_lineno)
self.c[(f.f_code.co_filename, f.f_lineno)] = 1
for c in self.cstack:
c[(f.f_code.co_filename, f.f_lineno)] = 1
return self.t
- def help(self, error=None):
+ def help(self, error=None): #pragma: no cover
if error:
print error
print
print __doc__
sys.exit(1)
- def command_line(self, argv, help=None):
+ def command_line(self, argv, help_fn=None):
import getopt
- help = help or self.help
+ help_fn = help_fn or self.help
settings = {}
optmap = {
'-a': 'annotate',
@@ -327,12 +363,12 @@ class coverage:
pass # Can't get here, because getopt won't return anything unknown.
if settings.get('help'):
- help()
+ help_fn()
for i in ['erase', 'execute']:
for j in ['annotate', 'report', 'collect']:
if settings.get(i) and settings.get(j):
- help("You can't specify the '%s' and '%s' "
+ help_fn("You can't specify the '%s' and '%s' "
"options at the same time." % (i, j))
args_needed = (settings.get('execute')
@@ -342,18 +378,18 @@ class coverage:
or settings.get('collect')
or args_needed)
if not action:
- help("You must specify at least one of -e, -x, -c, -r, or -a.")
+ help_fn("You must specify at least one of -e, -x, -c, -r, or -a.")
if not args_needed and args:
- help("Unexpected arguments: %s" % " ".join(args))
+ help_fn("Unexpected arguments: %s" % " ".join(args))
- self.get_ready(settings.get('parallel-mode'))
- self.exclude('#pragma[: ]+[nN][oO] [cC][oO][vV][eE][rR]')
+ self.parallel_mode = settings.get('parallel-mode')
+ self.get_ready()
if settings.get('erase'):
self.erase()
if settings.get('execute'):
if not args:
- help("Nothing to do.")
+ help_fn("Nothing to do.")
sys.argv = args
self.start()
import __main__
@@ -387,13 +423,13 @@ class coverage:
def get_ready(self, parallel_mode=False):
if self.usecache and not self.cache:
self.cache = os.environ.get(self.cache_env, self.cache_default)
- if parallel_mode:
+ if self.parallel_mode:
self.cache += "." + gethostname() + "." + str(os.getpid())
self.restore()
self.analysis_cache = {}
def start(self, parallel_mode=False):
- self.get_ready(parallel_mode)
+ self.get_ready()
if self.nesting == 0: #pragma: no cover
sys.settrace(self.t)
if hasattr(threading, 'settrace'):
@@ -408,12 +444,12 @@ class coverage:
threading.settrace(None)
def erase(self):
+ self.get_ready()
self.c = {}
self.analysis_cache = {}
self.cexecuted = {}
if self.cache and os.path.exists(self.cache):
os.remove(self.cache)
- self.exclude_re = ""
def exclude(self, re):
if self.exclude_re:
@@ -464,11 +500,11 @@ class coverage:
def collect(self):
cache_dir, local = os.path.split(self.cache)
- for file in os.listdir(cache_dir):
- if not file.startswith(local):
+ for f in os.listdir(cache_dir or '.'):
+ if not f.startswith(local):
continue
- full_path = os.path.join(cache_dir, file)
+ full_path = os.path.join(cache_dir, f)
cexecuted = self.restore_file(full_path)
self.merge_data(cexecuted)
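
With the two changes above, collect() now tolerates a cache path with no directory component (listing '.' instead of ''), and pairs with the parallel-mode naming set up in get_ready(): each process writes to a cache file suffixed with its host and pid, and collect() merges every file sharing the base name. A sketch of the naming scheme only, mirroring the get_ready() logic:

    import os
    from socket import gethostname

    cache = ".coverage"
    # per-process cache name in parallel mode, e.g. ".coverage.myhost.4242"
    per_process = "%s.%s.%d" % (cache, gethostname(), os.getpid())
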
@@ -508,6 +544,9 @@ class coverage:
def canonicalize_filenames(self):
for filename, lineno in self.c.keys():
+ if filename == '<string>':
+ # Can't do anything useful with exec'd strings, so skip them.
+ continue
f = self.canonical_filename(filename)
if not self.cexecuted.has_key(f):
self.cexecuted[f] = {}
@@ -520,17 +559,19 @@ class coverage:
if isinstance(morf, types.ModuleType):
if not hasattr(morf, '__file__'):
raise CoverageException, "Module has no __file__ attribute."
- file = morf.__file__
+ f = morf.__file__
else:
- file = morf
- return self.canonical_filename(file)
+ f = morf
+ return self.canonical_filename(f)
# analyze_morf(morf). Analyze the module or filename passed as
# the argument. If the source code can't be found, raise an error.
# Otherwise, return a tuple of (1) the canonical filename of the
# source code for the module, (2) a list of lines of statements
- # in the source code, and (3) a list of lines of excluded statements.
-
+ # in the source code, (3) a list of lines of excluded statements,
+ # and (4) a map of line numbers to multi-line line-number ranges, for
+ # statements that cross lines.
+
def analyze_morf(self, morf):
if self.analysis_cache.has_key(morf):
return self.analysis_cache[morf]
@@ -544,16 +585,53 @@ class coverage:
elif ext != '.py':
raise CoverageException, "File '%s' not Python source." % filename
source = open(filename, 'r')
- lines, excluded_lines = self.find_executable_statements(
+ lines, excluded_lines, line_map = self.find_executable_statements(
source.read(), exclude=self.exclude_re
)
source.close()
- result = filename, lines, excluded_lines
+ result = filename, lines, excluded_lines, line_map
self.analysis_cache[morf] = result
return result
+ def first_line_of_tree(self, tree):
+ while True:
+ if len(tree) == 3 and type(tree[2]) == type(1):
+ return tree[2]
+ tree = tree[1]
+
+ def last_line_of_tree(self, tree):
+ while True:
+ if len(tree) == 3 and type(tree[2]) == type(1):
+ return tree[2]
+ tree = tree[-1]
+
+ def find_docstring_pass_pair(self, tree, spots):
+ for i in range(1, len(tree)):
+ if self.is_string_constant(tree[i]) and self.is_pass_stmt(tree[i+1]):
+ first_line = self.first_line_of_tree(tree[i])
+ last_line = self.last_line_of_tree(tree[i+1])
+ self.record_multiline(spots, first_line, last_line)
+
+ def is_string_constant(self, tree):
+ try:
+ return tree[0] == symbol.stmt and tree[1][1][1][0] == symbol.expr_stmt
+ except:
+ return False
+
+ def is_pass_stmt(self, tree):
+ try:
+ return tree[0] == symbol.stmt and tree[1][1][1][0] == symbol.pass_stmt
+ except:
+ return False
+
+ def record_multiline(self, spots, i, j):
+ for l in range(i, j+1):
+ spots[l] = (i, j)
+
def get_suite_spots(self, tree, spots):
- import symbol, token
+ """ Analyze a parse tree to find suite introducers which span a number
+ of lines.
+ """
for i in range(1, len(tree)):
if type(tree[i]) == type(()):
if tree[i][0] == symbol.suite:
@@ -561,7 +639,9 @@ class coverage:
lineno_colon = lineno_word = None
for j in range(i-1, 0, -1):
if tree[j][0] == token.COLON:
- lineno_colon = tree[j][2]
+ # Colons are never executed themselves: we want the
+ # line number of the last token before the colon.
+ lineno_colon = self.last_line_of_tree(tree[j-1])
elif tree[j][0] == token.NAME:
if tree[j][1] == 'elif':
# Find the line number of the first non-terminal
@@ -583,8 +663,18 @@ class coverage:
if lineno_colon and lineno_word:
# Found colon and keyword, mark all the lines
# between the two with the two line numbers.
- for l in range(lineno_word, lineno_colon+1):
- spots[l] = (lineno_word, lineno_colon)
+ self.record_multiline(spots, lineno_word, lineno_colon)
+
+ # "pass" statements are tricky: different versions of Python
+ # treat them differently, especially in the common case of a
+ # function with a doc string and a single pass statement.
+ self.find_docstring_pass_pair(tree[i], spots)
+
+ elif tree[i][0] == symbol.simple_stmt:
+ first_line = self.first_line_of_tree(tree[i])
+ last_line = self.last_line_of_tree(tree[i])
+ if first_line != last_line:
+ self.record_multiline(spots, first_line, last_line)
self.get_suite_spots(tree[i], spots)
def find_executable_statements(self, text, exclude=None):
@@ -598,10 +688,13 @@ class coverage:
if reExclude.search(lines[i]):
excluded[i+1] = 1
+ # Parse the code and analyze the parse tree to find out which statements
+ # are multiline, and where suites begin and end.
import parser
tree = parser.suite(text+'\n\n').totuple(1)
self.get_suite_spots(tree, suite_spots)
-
+ #print "Suite spots:", suite_spots
+
# Use the compiler module to parse the text and find the executable
# statements. We add newlines to be impervious to final partial lines.
statements = {}
@@ -613,7 +706,7 @@ class coverage:
lines.sort()
excluded_lines = excluded.keys()
excluded_lines.sort()
- return lines, excluded_lines
+ return lines, excluded_lines, suite_spots
# format_lines(statements, lines). Format a list of line numbers
# for printing by coalescing groups of lines as long as the lines
@@ -646,7 +739,8 @@ class coverage:
return "%d" % start
else:
return "%d-%d" % (start, end)
- return string.join(map(stringify, pairs), ", ")
+ ret = string.join(map(stringify, pairs), ", ")
+ return ret
# Backward compatibility with version 1.
def analysis(self, morf):
@@ -654,13 +748,17 @@ class coverage:
return f, s, m, mf
def analysis2(self, morf):
- filename, statements, excluded = self.analyze_morf(morf)
+ filename, statements, excluded, line_map = self.analyze_morf(morf)
self.canonicalize_filenames()
if not self.cexecuted.has_key(filename):
self.cexecuted[filename] = {}
missing = []
for line in statements:
- if not self.cexecuted[filename].has_key(line):
+ lines = line_map.get(line, [line, line])
+ for l in range(lines[0], lines[1]+1):
+ if self.cexecuted[filename].has_key(l):
+ break
+ else:
missing.append(line)
return (filename, statements, excluded, missing,
self.format_lines(statements, missing))
@@ -698,6 +796,15 @@ class coverage:
def report(self, morfs, show_missing=1, ignore_errors=0, file=None, omit_prefixes=[]):
if not isinstance(morfs, types.ListType):
morfs = [morfs]
+ # On Windows, the shell doesn't expand wildcards. Do it here.
+ globbed = []
+ for morf in morfs:
+ if isinstance(morf, strclass):
+ globbed.extend(glob.glob(morf))
+ else:
+ globbed.append(morf)
+ morfs = globbed
+
morfs = self.filter_by_prefix(morfs, omit_prefixes)
morfs.sort(self.morf_name_compare)
@@ -735,8 +842,8 @@ class coverage:
raise
except:
if not ignore_errors:
- type, msg = sys.exc_info()[0:2]
- print >>file, fmt_err % (name, type, msg)
+ typ, msg = sys.exc_info()[0:2]
+ print >>file, fmt_err % (name, typ, msg)
if len(morfs) > 1:
print >>file, "-" * len(header)
if total_statements > 0:
@@ -816,18 +923,41 @@ class coverage:
the_coverage = coverage()
# Module functions call methods in the singleton object.
-def use_cache(*args, **kw): return the_coverage.use_cache(*args, **kw)
-def start(*args, **kw): return the_coverage.start(*args, **kw)
-def stop(*args, **kw): return the_coverage.stop(*args, **kw)
-def erase(*args, **kw): return the_coverage.erase(*args, **kw)
-def begin_recursive(*args, **kw): return the_coverage.begin_recursive(*args, **kw)
-def end_recursive(*args, **kw): return the_coverage.end_recursive(*args, **kw)
-def exclude(*args, **kw): return the_coverage.exclude(*args, **kw)
-def analysis(*args, **kw): return the_coverage.analysis(*args, **kw)
-def analysis2(*args, **kw): return the_coverage.analysis2(*args, **kw)
-def report(*args, **kw): return the_coverage.report(*args, **kw)
-def annotate(*args, **kw): return the_coverage.annotate(*args, **kw)
-def annotate_file(*args, **kw): return the_coverage.annotate_file(*args, **kw)
+def use_cache(*args, **kw):
+ return the_coverage.use_cache(*args, **kw)
+
+def start(*args, **kw):
+ return the_coverage.start(*args, **kw)
+
+def stop(*args, **kw):
+ return the_coverage.stop(*args, **kw)
+
+def erase(*args, **kw):
+ return the_coverage.erase(*args, **kw)
+
+def begin_recursive(*args, **kw):
+ return the_coverage.begin_recursive(*args, **kw)
+
+def end_recursive(*args, **kw):
+ return the_coverage.end_recursive(*args, **kw)
+
+def exclude(*args, **kw):
+ return the_coverage.exclude(*args, **kw)
+
+def analysis(*args, **kw):
+ return the_coverage.analysis(*args, **kw)
+
+def analysis2(*args, **kw):
+ return the_coverage.analysis2(*args, **kw)
+
+def report(*args, **kw):
+ return the_coverage.report(*args, **kw)
+
+def annotate(*args, **kw):
+ return the_coverage.annotate(*args, **kw)
+
+def annotate_file(*args, **kw):
+ return the_coverage.annotate_file(*args, **kw)
# Save coverage data when Python exits. (The atexit module wasn't
# introduced until Python 2.0, so use sys.exitfunc when it's not
@@ -918,11 +1048,32 @@ if __name__ == '__main__':
#
# 2006-08-23 NMB Refactorings to improve testability. Fixes to command-line
# logic for parallel mode and collect.
+#
+# 2006-08-25 NMB "#pragma: nocover" is excluded by default.
+#
+# 2006-09-10 NMB Properly ignore docstrings and other constant expressions that
+# appear in the middle of a function, a problem reported by Tim Leslie.
+# Minor changes to avoid lint warnings.
+#
+# 2006-09-17 NMB coverage.erase() shouldn't clobber the exclude regex.
+# Change how parallel mode is invoked, and fix erase() so that it erases the
+# cache when called programmatically.
+#
+# 2007-07-21 NMB In reports, ignore code executed from strings, since we can't
+# do anything useful with it anyway.
+# Better file handling on Linux, thanks Guillaume Chazarain.
+# Better shell support on Windows, thanks Noel O'Boyle.
+# Python 2.2 support maintained, thanks Catherine Proulx.
+#
+# 2007-07-22 NMB Python 2.5 now fully supported. The method of dealing with
+# multi-line statements is now less sensitive to the exact line that Python
+# reports during execution. Pass statements are handled specially so that their
+# disappearance during execution won't throw off the measurement.
# C. COPYRIGHT AND LICENCE
#
# Copyright 2001 Gareth Rees. All rights reserved.
-# Copyright 2004-2006 Ned Batchelder. All rights reserved.
+# Copyright 2004-2007 Ned Batchelder. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
@@ -949,4 +1100,4 @@ if __name__ == '__main__':
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
#
-# $Id: coverage.py 47 2006-08-24 01:08:48Z Ned $
+# $Id: coverage.py 67 2007-07-21 19:51:07Z nedbat $
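
The substantive change in this coverage.py upgrade is the fourth element now returned by analyze_morf(): a map from each physical line of a multi-line statement to its (first, last) span, which analysis2() consults so that a statement counts as executed if any line in its span was hit. A standalone illustration of that lookup with made-up line numbers; this mirrors the analysis2() loop above but is not coverage.py's own code:

    # a statement spanning physical lines 10-12: every line maps to the span
    line_map = {10: (10, 12), 11: (10, 12), 12: (10, 12)}
    executed = {11: 1}                # Python reported execution on line 11

    def statement_hit(line):
        first, last = line_map.get(line, (line, line))
        return any(l in executed for l in range(first, last + 1))

    assert statement_hit(10)          # credited via line 11
    assert not statement_hit(20)      # single-line statement, never run
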
diff --git a/test/testlib/profiling.py b/test/testlib/profiling.py
new file mode 100644
index 000000000..697df4ea2
--- /dev/null
+++ b/test/testlib/profiling.py
@@ -0,0 +1,74 @@
+"""Profiling support for unit and performance tests."""
+
+from testlib.config import parser, post_configure
+import testlib.config
+
+__all__ = ('profiled',)
+
+all_targets = set()
+profile_config = { 'targets': set(),
+ 'report': True,
+ 'sort': ('time', 'calls'),
+ 'limit': None }
+
+def profiled(target, **target_opts):
+ """Optional function profiling.
+
+ @profiled('label')
+ or
+ @profiled('label', report=True, sort=('calls',), limit=20)
+
+ Enables profiling for a function when 'label' is targeted for
+ profiling. Report options can be supplied, and override the global
+ configuration and command-line options.
+ """
+
+ import time, hotshot, hotshot.stats
+
+ # manual or automatic namespacing by module would remove conflict issues
+ if target in all_targets:
+ print "Warning: redefining profile target '%s'" % target
+ all_targets.add(target)
+
+ filename = "%s.prof" % target
+
+ def decorator(fn):
+ def profiled(*args, **kw):
+ if (target not in profile_config['targets'] and
+ not target_opts.get('always', None)):
+ return fn(*args, **kw)
+
+ prof = hotshot.Profile(filename)
+ began = time.time()
+ prof.start()
+ try:
+ result = fn(*args, **kw)
+ finally:
+ prof.stop()
+ ended = time.time()
+ prof.close()
+
+ if not testlib.config.options.quiet:
+ print "Profiled target '%s', wall time: %.2f seconds" % (
+ target, ended - began)
+
+ report = target_opts.get('report', profile_config['report'])
+ if report:
+ sort_ = target_opts.get('sort', profile_config['sort'])
+ limit = target_opts.get('limit', profile_config['limit'])
+ print "Profile report for target '%s' (%s)" % (
+ target, filename)
+
+ stats = hotshot.stats.load(filename)
+ stats.sort_stats(*sort_)
+ if limit:
+ stats.print_stats(limit)
+ else:
+ stats.print_stats()
+ return result
+ try:
+ profiled.__name__ = fn.__name__
+ except:
+ pass
+ return profiled
+ return decorator
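
A sketch of the decorator in use; the target name 'bulk_insert' and the workload are invented. The body runs unprofiled unless 'bulk_insert' is among the configured profile targets (or always=True is passed), in which case hotshot writes bulk_insert.prof and, unless quieted, prints a report:

    from testlib.profiling import profiled

    @profiled('bulk_insert', sort=('cumulative',), limit=25)
    def test_bulk_insert():
        # hypothetical workload under measurement
        total = 0
        for i in range(1000):
            total += i
        return total
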
diff --git a/test/testlib/schema.py b/test/testlib/schema.py
new file mode 100644
index 000000000..a2fc91265
--- /dev/null
+++ b/test/testlib/schema.py
@@ -0,0 +1,28 @@
+import testbase
+from sqlalchemy import schema
+
+__all__ = ('Table', 'Column')
+
+table_options = {}
+
+def Table(*args, **kw):
+ """A schema.Table wrapper/hook for dialect-specific tweaks."""
+
+ test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
+ if k.startswith('test_')])
+
+ kw.update(table_options)
+
+ if testbase.db.name == 'mysql':
+ if 'mysql_engine' not in kw and 'mysql_type' not in kw:
+ if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts:
+ kw['mysql_engine'] = 'InnoDB'
+
+ return schema.Table(*args, **kw)
+
+def Column(*args, **kw):
+ """A schema.Column wrapper/hook for dialect-specific tweaks."""
+
+ # TODO: a Column that creates a Sequence automatically for PK columns,
+ # which would help Oracle tests
+ return schema.Column(*args, **kw)
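
A sketch of the wrapper in use, assuming a configured test environment; the table itself is invented. Any test_*-prefixed keyword is stripped before reaching schema.Table, and when testbase.db is MySQL the test_needs_acid/test_needs_fk hints turn into mysql_engine='InnoDB'; on other databases they are no-ops:

    from sqlalchemy import MetaData, Integer, String
    from testlib.schema import Table, Column

    metadata = MetaData()
    accounts = Table('accounts', metadata,
        Column('account_id', Integer, primary_key=True),
        Column('owner', String(40)),
        test_needs_acid=True)        # InnoDB on MySQL, ignored elsewhere
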
diff --git a/test/tables.py b/test/testlib/tables.py
index 8e337999c..69c84c5b3 100644
--- a/test/tables.py
+++ b/test/testlib/tables.py
@@ -1,24 +1,24 @@
-
-from sqlalchemy import *
-import os
import testbase
+from sqlalchemy import *
+from testlib.schema import Table, Column
-ECHO = testbase.echo
-db = testbase.db
-metadata = MetaData(db)
+# these are older test fixtures, used primarily by test/orm/mapper.py and
+# test/orm/unitofwork.py. newer unit tests make use of test/orm/fixtures.py.
+
+metadata = MetaData()
users = Table('users', metadata,
Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
Column('user_name', String(40)),
- mysql_engine='innodb'
+ test_needs_acid=True,
+ test_needs_fk=True,
)
addresses = Table('email_addresses', metadata,
Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True),
Column('user_id', Integer, ForeignKey(users.c.user_id)),
Column('email_address', String(40)),
-
)
orders = Table('orders', metadata,
@@ -26,20 +26,17 @@ orders = Table('orders', metadata,
Column('user_id', Integer, ForeignKey(users.c.user_id)),
Column('description', String(50)),
Column('isopen', Integer),
-
)
orderitems = Table('items', metadata,
Column('item_id', INT, Sequence('items_id_seq', optional=True), primary_key = True),
Column('order_id', INT, ForeignKey("orders")),
Column('item_name', VARCHAR(50)),
-
)
keywords = Table('keywords', metadata,
Column('keyword_id', Integer, Sequence('keyword_id_seq', optional=True), primary_key = True),
Column('name', VARCHAR(50)),
-
)
userkeywords = Table('userkeywords', metadata,
@@ -54,13 +51,19 @@ itemkeywords = Table('itemkeywords', metadata,
)
def create():
+ if not metadata.bind:
+ metadata.bind = testbase.db
metadata.create_all()
def drop():
+ if not metadata.bind:
+ metadata.bind = testbase.db
metadata.drop_all()
def delete():
for t in metadata.table_iterator(reverse=True):
t.delete().execute()
def user_data():
+ if not metadata.bind:
+ metadata.bind = testbase.db
users.insert().execute(
dict(user_id = 7, user_name = 'jack'),
dict(user_id = 8, user_name = 'ed'),
@@ -212,4 +215,4 @@ order_result = [
{'order_id' : 4, 'items':(Item, [])},
{'order_id' : 5, 'items':(Item, [])},
]
-#db.echo = True
+
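
The create()/drop()/user_data() helpers above all repeat the same late-binding check: the module-level MetaData starts unbound so that importing the fixtures never requires a configured engine, and testbase.db is attached on first use. Factored out, the idiom is just this (illustrative only; the module keeps the checks inline):

    def _ensure_bound():
        # bind lazily so module import order doesn't matter
        if not metadata.bind:
            metadata.bind = testbase.db
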
diff --git a/test/testlib/testing.py b/test/testlib/testing.py
new file mode 100644
index 000000000..213772e9e
--- /dev/null
+++ b/test/testlib/testing.py
@@ -0,0 +1,363 @@
+"""TestCase and TestSuite artifacts and testing decorators."""
+
+# monkeypatches unittest.TestLoader.suiteClass at import time
+
+import unittest, re, sys, os
+from cStringIO import StringIO
+from sqlalchemy import MetaData, sql
+from sqlalchemy.orm import clear_mappers
+import testlib.config as config
+
+__all__ = ('PersistTest', 'AssertMixin', 'ORMTest')
+
+def unsupported(*dbs):
+ """Mark a test as unsupported by one or more database implementations"""
+
+ def decorate(fn):
+ fn_name = fn.__name__
+ def maybe(*args, **kw):
+ if config.db.name in dbs:
+ print "'%s' unsupported on DB implementation '%s'" % (
+ fn_name, config.db.name)
+ return True
+ else:
+ return fn(*args, **kw)
+ try:
+ maybe.__name__ = fn_name
+ except:
+ pass
+ return maybe
+ return decorate
+
+def supported(*dbs):
+ """Mark a test as supported by one or more database implementations"""
+
+ def decorate(fn):
+ fn_name = fn.__name__
+ def maybe(*args, **kw):
+ if config.db.name in dbs:
+ return fn(*args, **kw)
+ else:
+ print "'%s' unsupported on DB implementation '%s'" % (
+ fn_name, config.db.name)
+ return True
+ try:
+ maybe.__name__ = fn_name
+ except:
+ pass
+ return maybe
+ return decorate
+
+class TestData(object):
+ """Tracks SQL expressions as they are executed via an instrumented ExecutionContext."""
+
+ def __init__(self):
+ self.set_assert_list(None, None)
+ self.sql_count = 0
+ self.buffer = None
+
+ def set_assert_list(self, unittest, list):
+ self.unittest = unittest
+ self.assert_list = list
+ if list is not None:
+ self.assert_list.reverse()
+
+testdata = TestData()
+
+
+class ExecutionContextWrapper(object):
+ """instruments the ExecutionContext created by the Engine so that SQL expressions
+ can be tracked."""
+
+ def __init__(self, ctx):
+ self.__dict__['ctx'] = ctx
+ def __getattr__(self, key):
+ return getattr(self.ctx, key)
+ def __setattr__(self, key, value):
+ setattr(self.ctx, key, value)
+
+ def post_execution(self):
+ ctx = self.ctx
+ statement = re.sub(r'\n', '', ctx.statement)
+ if testdata.buffer is not None:
+ testdata.buffer.write(statement + "\n")
+
+ if testdata.assert_list is not None:
+ assert len(testdata.assert_list), "Received query but no more assertions: %s" % statement
+ item = testdata.assert_list[-1]
+ if not isinstance(item, dict):
+ item = testdata.assert_list.pop()
+ else:
+ # asserting a dictionary of statements->parameters
+ # this is to specify query assertions where the queries can be in
+ # multiple orderings
+ if not item.has_key('_converted'):
+ for key in item.keys():
+ ckey = self.convert_statement(key)
+ item[ckey] = item[key]
+ if ckey != key:
+ del item[key]
+ item['_converted'] = True
+ try:
+ entry = item.pop(statement)
+ if len(item) == 1:
+ testdata.assert_list.pop()
+ item = (statement, entry)
+ except KeyError:
+ assert False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement)
+
+ (query, params) = item
+ if callable(params):
+ params = params(ctx)
+ if params is not None and isinstance(params, list) and len(params) == 1:
+ params = params[0]
+
+ if isinstance(ctx.compiled_parameters, sql.ClauseParameters):
+ parameters = ctx.compiled_parameters.get_original_dict()
+ elif isinstance(ctx.compiled_parameters, list):
+ parameters = [p.get_original_dict() for p in ctx.compiled_parameters]
+
+ query = self.convert_statement(query)
+ if config.db.name == 'mssql' and statement.endswith('; select scope_identity()'):
+ statement = statement[:-25]
+ testdata.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
+ testdata.sql_count += 1
+ self.ctx.post_execution()
+
+ def convert_statement(self, query):
+ paramstyle = self.ctx.dialect.paramstyle
+ if paramstyle == 'named':
+ pass
+ elif paramstyle == 'pyformat':
+ query = re.sub(r':([\w_]+)', r"%(\1)s", query)
+ else:
+ # positional params
+ repl = None
+ if paramstyle == 'qmark':
+ repl = "?"
+ elif paramstyle == 'format':
+ repl = r"%s"
+ elif paramstyle == 'numeric':
+ repl = None
+ query = re.sub(r':([\w_]+)', repl, query)
+ return query
+
+class PersistTest(unittest.TestCase):
+
+ def __init__(self, *args, **params):
+ unittest.TestCase.__init__(self, *args, **params)
+
+ def setUpAll(self):
+ pass
+
+ def tearDownAll(self):
+ pass
+
+ def shortDescription(self):
+ """overridden to not return docstrings"""
+ return None
+
+class AssertMixin(PersistTest):
+ """given a list-based structure of keys/properties which represent information within an object structure, and
+ a list of actual objects, asserts that the list of objects corresponds to the structure."""
+
+ def assert_result(self, result, class_, *objects):
+ result = list(result)
+ print repr(result)
+ self.assert_list(result, class_, objects)
+
+ def assert_list(self, result, class_, list):
+ self.assert_(len(result) == len(list),
+ "result list is not the same size as test list, " +
+ "for class " + class_.__name__)
+ for i in range(0, len(list)):
+ self.assert_row(class_, result[i], list[i])
+
+ def assert_row(self, class_, rowobj, desc):
+ self.assert_(rowobj.__class__ is class_,
+ "item class is not " + repr(class_))
+ for key, value in desc.iteritems():
+ if isinstance(value, tuple):
+ if isinstance(value[1], list):
+ self.assert_list(getattr(rowobj, key), value[0], value[1])
+ else:
+ self.assert_row(value[0], getattr(rowobj, key), value[1])
+ else:
+ self.assert_(getattr(rowobj, key) == value,
+ "attribute %s value %s does not match %s" % (
+ key, getattr(rowobj, key), value))
+
+ def assert_sql(self, db, callable_, list, with_sequences=None):
+ global testdata
+ testdata = TestData()
+ if with_sequences is not None and (config.db.name == 'postgres' or
+ config.db.name == 'oracle'):
+ testdata.set_assert_list(self, with_sequences)
+ else:
+ testdata.set_assert_list(self, list)
+ try:
+ callable_()
+ finally:
+ testdata.set_assert_list(None, None)
+
+ def assert_sql_count(self, db, callable_, count):
+ global testdata
+ testdata = TestData()
+ try:
+ callable_()
+ finally:
+ self.assert_(testdata.sql_count == count,
+ "desired statement count %d does not match %d" % (
+ count, testdata.sql_count))
+
+ def capture_sql(self, db, callable_):
+ global testdata
+ testdata = TestData()
+ buffer = StringIO()
+ testdata.buffer = buffer
+ try:
+ callable_()
+ return buffer.getvalue()
+ finally:
+ testdata.buffer = None
+
+_otest_metadata = None
+class ORMTest(AssertMixin):
+ keep_mappers = False
+ keep_data = False
+
+ def setUpAll(self):
+ global _otest_metadata
+ _otest_metadata = MetaData(config.db)
+ self.define_tables(_otest_metadata)
+ _otest_metadata.create_all()
+ self.insert_data()
+
+ def define_tables(self, _otest_metadata):
+ raise NotImplementedError()
+
+ def insert_data(self):
+ pass
+
+ def get_metadata(self):
+ return _otest_metadata
+
+ def tearDownAll(self):
+ clear_mappers()
+ _otest_metadata.drop_all()
+
+ def tearDown(self):
+ if not self.keep_mappers:
+ clear_mappers()
+ if not self.keep_data:
+ for t in _otest_metadata.table_iterator(reverse=True):
+ t.delete().execute().close()
+
+
+class TTestSuite(unittest.TestSuite):
+ """A TestSuite with once per TestCase setUpAll() and tearDownAll()"""
+
+ def __init__(self, tests=()):
+ if len(tests) > 0 and isinstance(tests[0], PersistTest):
+ self._initTest = tests[0]
+ else:
+ self._initTest = None
+ unittest.TestSuite.__init__(self, tests)
+
+ def do_run(self, result):
+ # nice job, unittest! __call__ and run() were switched between py2.3
+ # and py2.4, making straight subclassing impossible.
+ for test in self._tests:
+ if result.shouldStop:
+ break
+ test(result)
+ return result
+
+ def run(self, result):
+ return self(result)
+
+ def __call__(self, result):
+ try:
+ if self._initTest is not None:
+ self._initTest.setUpAll()
+ except:
+ result.addError(self._initTest, self.__exc_info())
+ try:
+ return self.do_run(result)
+ finally:
+ try:
+ if self._initTest is not None:
+ self._initTest.tearDownAll()
+ except:
+ result.addError(self._initTest, self.__exc_info())
+
+ def __exc_info(self):
+ """Return sys.exc_info() for error reporting.
+ Lifted from the unittest module, since its double-underscore name
+ mangling makes the original inaccessible to subclasses.
+ """
+ exctype, excvalue, tb = sys.exc_info()
+ # the unittest original minimises the traceback here; this copy
+ # returns it unmodified (Jython tracebacks look different, but no
+ # special handling is needed either way)
+ return (exctype, excvalue, tb)
+
+unittest.TestLoader.suiteClass = TTestSuite
+
+def _iter_covered_files():
+ import sqlalchemy
+ for rec in os.walk(os.path.dirname(sqlalchemy.__file__)):
+ for x in rec[2]:
+ if x.endswith('.py'):
+ yield os.path.join(rec[0], x)
+
+def cover(callable_, file_=None):
+ from testlib import coverage
+ coverage_client = coverage.the_coverage
+ coverage_client.get_ready()
+ coverage_client.exclude('#pragma[: ]+[nN][oO] [cC][oO][vV][eE][rR]')
+ coverage_client.erase()
+ coverage_client.start()
+ try:
+ return callable_()
+ finally:
+ coverage_client.stop()
+ coverage_client.save()
+ coverage_client.report(list(_iter_covered_files()),
+ show_missing=False, ignore_errors=False,
+ file=file_)
+
+class DevNullWriter(object):
+ def write(self, msg):
+ pass
+ def flush(self):
+ pass
+
+def runTests(suite):
+ verbose = config.options.verbose
+ quiet = config.options.quiet
+ orig_stdout = sys.stdout
+
+ try:
+ if not verbose or quiet:
+ sys.stdout = DevNullWriter()
+ runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2)
+ return runner.run(suite)
+ finally:
+ if not verbose or quiet:
+ sys.stdout = orig_stdout
+
+def main(suite=None):
+ if not suite:
+ if len(sys.argv[1:]):
+ suite = unittest.TestLoader().loadTestsFromNames(
+ sys.argv[1:], __import__('__main__'))
+ else:
+ suite = unittest.TestLoader().loadTestsFromModule(
+ __import__('__main__'))
+
+ result = runTests(suite)
+ sys.exit(not result.wasSuccessful())
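
Putting the pieces above together, a hypothetical test module might look like the following; the class and test names are invented. ORMTest drives define_tables() and create_all() once per class via TTestSuite, @unsupported skips (with a notice) on the named dialects, and assert_sql_count depends on the instrumented ExecutionContext installed by testlib.config:

    import testbase
    from testlib import testing

    class FixtureTest(testing.ORMTest):
        def define_tables(self, metadata):
            # build Table objects against the supplied MetaData here;
            # ORMTest handles create_all()/drop_all() and mapper cleanup
            pass

        @testing.unsupported('sqlite')      # skipped on SQLite
        def test_statement_count(self):
            def go():
                pass                        # the ORM operations under test
            self.assert_sql_count(testbase.db, go, 0)

    if __name__ == '__main__':
        testbase.main()
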
diff --git a/test/zblog/mappers.py b/test/zblog/mappers.py
index 244a53d0e..11eaf4fd0 100644
--- a/test/zblog/mappers.py
+++ b/test/zblog/mappers.py
@@ -4,6 +4,7 @@ import zblog.tables as tables
import zblog.user as user
from zblog.blog import *
from sqlalchemy import *
+from sqlalchemy.orm import *
import sqlalchemy.util as util
def zblog_mappers():
diff --git a/test/zblog/tables.py b/test/zblog/tables.py
index f01f18921..5b4054a19 100644
--- a/test/zblog/tables.py
+++ b/test/zblog/tables.py
@@ -1,13 +1,16 @@
+"""application table metadata objects are described here."""
+
from sqlalchemy import *
+from testlib import *
+
metadata = MetaData()
-"""application table metadata objects are described here."""
users = Table('users', metadata,
Column('user_id', Integer, primary_key=True),
Column('user_name', String(30), nullable=False),
Column('fullname', String(100), nullable=False),
- Column('password', String(30), nullable=False),
+ Column('password', String(42), nullable=False), # 2-char salt + 40-char sha hexdigest
Column('groupname', String(20), nullable=False),
)
diff --git a/test/zblog/tests.py b/test/zblog/tests.py
index e538cff9d..ad6876937 100644
--- a/test/zblog/tests.py
+++ b/test/zblog/tests.py
@@ -1,20 +1,20 @@
-from testbase import AssertMixin
import testbase
-import unittest
-db = testbase.db
from sqlalchemy import *
-
+from sqlalchemy.orm import *
+from testlib import *
from zblog import mappers, tables
from zblog.user import *
from zblog.blog import *
+
class ZBlogTest(AssertMixin):
def create_tables(self):
- tables.metadata.create_all(connectable=db)
+ tables.metadata.drop_all(bind=testbase.db)
+ tables.metadata.create_all(bind=testbase.db)
def drop_tables(self):
- tables.metadata.drop_all(connectable=db)
+ tables.metadata.drop_all(bind=testbase.db)
def setUpAll(self):
self.create_tables()
@@ -31,7 +31,7 @@ class SavePostTest(ZBlogTest):
super(SavePostTest, self).setUpAll()
mappers.zblog_mappers()
global blog_id, user_id
- s = create_session(bind_to=db)
+ s = create_session(bind=testbase.db)
user = User('zbloguser', "Zblog User", "hello", group=administrator)
blog = Blog(owner=user)
blog.name = "this is a blog"
@@ -50,9 +50,9 @@ class SavePostTest(ZBlogTest):
"""test that a transient/pending instance has proper bi-directional behavior.
this requires that lazy loaders do not fire off for a transient/pending instance."""
- s = create_session(bind_to=db)
+ s = create_session(bind=testbase.db)
- trans = s.create_transaction()
+ s.begin()
try:
blog = s.query(Blog).get(blog_id)
post = Post(headline="asdf asdf", summary="asdfasfd")
@@ -61,14 +61,14 @@ class SavePostTest(ZBlogTest):
post.blog = blog
assert post in blog.posts
finally:
- trans.rollback()
+ s.rollback()
def testoptimisticorphans(self):
"""test that instances in the session with un-loaded parents will not
get marked as "orphans" and then deleted """
- s = create_session(bind_to=db)
+ s = create_session(bind=testbase.db)
- trans = s.create_transaction()
+ s.begin()
try:
blog = s.query(Blog).get(blog_id)
post = Post(headline="asdf asdf", summary="asdfasfd")
@@ -90,10 +90,10 @@ class SavePostTest(ZBlogTest):
assert s.query(Post).get(post.id) is not None
finally:
- trans.rollback()
+ s.rollback()
if __name__ == "__main__":
testbase.main()
- \ No newline at end of file
+
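
The mechanical edits in this file summarize the 0.3 to 0.4 session API shift: create_session's bind_to keyword becomes bind, and the freestanding transaction object from create_transaction() is replaced by begin()/rollback() on the session itself. Side by side, with work() standing in for arbitrary ORM operations:

    s = create_session(bind=testbase.db)  # 0.3: create_session(bind_to=db)
    s.begin()                             # 0.3: trans = s.create_transaction()
    try:
        work(s)                           # hypothetical unit of work
    finally:
        s.rollback()                      # 0.3: trans.rollback()
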
diff --git a/test/zblog/user.py b/test/zblog/user.py
index 1dca0328e..3e77fa842 100644
--- a/test/zblog/user.py
+++ b/test/zblog/user.py
@@ -1,13 +1,7 @@
"""user.py - handles user login and validation"""
import random, string
-try:
- from crypt import crypt
-except:
- try:
- from fcrypt import crypt
- except:
- raise "Need fcrypt module on non-Unix platform: http://home.clear.net.nz/pages/c.evans/sw/"
+from sha import sha
administrator = 'admin'
user = 'user'
@@ -16,7 +10,7 @@ groups = [user, administrator]
def cryptpw(password, salt=None):
if salt is None:
salt = string.join([chr(random.randint(ord('a'), ord('z'))), chr(random.randint(ord('a'), ord('z')))],'')
- return crypt(password, salt)
+ # prepend the salt, crypt()-style, so checkpw() can recover it from
+ # the stored value as dbpw[:2]
+ return salt + sha(password + salt).hexdigest()
def checkpw(password, dbpw):
return cryptpw(password, dbpw[:2]) == dbpw
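
With the salt prepended crypt()-style as above, the stored value carries everything checkpw() needs: dbpw[:2] recovers the two-character salt and the remaining 40 characters are the sha hexdigest. A quick round-trip, passwords invented:

    stored = cryptpw('tiger')            # e.g. 'qx' + 40-char hexdigest
    assert checkpw('tiger', stored)      # recomputes using salt 'qx'
    assert not checkpw('lion', stored)
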