author     jonathan vanasco <jonathan@2xlp.com>    2015-04-02 13:30:26 -0400
committer  jonathan vanasco <jonathan@2xlp.com>    2015-04-02 13:30:26 -0400
commit     6de3d490a2adb0fff43f98e15a53407b46668b61 (patch)
tree       d5e0e2077dfe7dc69ce30e9d0a8c89ceff78e3fe /lib/sqlalchemy
parent     efca4af93603faa7abfeacbab264cad85ee4105c (diff)
parent     5e04995a82c00e801a99765cde7726f5e73e18c2 (diff)
download   sqlalchemy-6de3d490a2adb0fff43f98e15a53407b46668b61.tar.gz
Merge branch 'master' of bitbucket.org:zzzeek/sqlalchemy
Diffstat (limited to 'lib/sqlalchemy')
173 files changed, 7410 insertions, 4403 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py
index d184e1fbf..709ba3246 100644
--- a/lib/sqlalchemy/__init__.py
+++ b/lib/sqlalchemy/__init__.py
@@ -1,5 +1,5 @@
 # sqlalchemy/__init__.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
@@ -120,7 +120,7 @@ from .schema import (
 from .inspection import inspect
 from .engine import create_engine, engine_from_config
 
-__version__ = '1.0.0'
+__version__ = '1.0.0b5'
 
 
 def __go(lcls):
diff --git a/lib/sqlalchemy/cextension/processors.c b/lib/sqlalchemy/cextension/processors.c
index d56817763..59eb2648c 100644
--- a/lib/sqlalchemy/cextension/processors.c
+++ b/lib/sqlalchemy/cextension/processors.c
@@ -1,6 +1,6 @@
 /*
 processors.c
-Copyright (C) 2010-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
+Copyright (C) 2010-2015 the SQLAlchemy authors and contributors <see AUTHORS file>
 Copyright (C) 2010-2011 Gaetan de Menten gdementen@gmail.com
 
 This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/cextension/resultproxy.c b/lib/sqlalchemy/cextension/resultproxy.c
index 218c7b807..ae2a059cf 100644
--- a/lib/sqlalchemy/cextension/resultproxy.c
+++ b/lib/sqlalchemy/cextension/resultproxy.c
@@ -1,6 +1,6 @@
 /*
 resultproxy.c
-Copyright (C) 2010-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
+Copyright (C) 2010-2015 the SQLAlchemy authors and contributors <see AUTHORS file>
 Copyright (C) 2010-2011 Gaetan de Menten gdementen@gmail.com
 
 This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/cextension/utils.c b/lib/sqlalchemy/cextension/utils.c
index 377ba8a8d..6e00eb816 100644
--- a/lib/sqlalchemy/cextension/utils.c
+++ b/lib/sqlalchemy/cextension/utils.c
@@ -1,6 +1,6 @@
 /*
 utils.c
-Copyright (C) 2012-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
+Copyright (C) 2012-2015 the SQLAlchemy authors and contributors <see AUTHORS file>
 
 This module is part of SQLAlchemy and is released under
 the MIT License: http://www.opensource.org/licenses/mit-license.php
diff --git a/lib/sqlalchemy/connectors/__init__.py b/lib/sqlalchemy/connectors/__init__.py
index 9253a21d5..5f65b9306 100644
--- a/lib/sqlalchemy/connectors/__init__.py
+++ b/lib/sqlalchemy/connectors/__init__.py
@@ -1,5 +1,5 @@
 # connectors/__init__.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/connectors/mxodbc.py b/lib/sqlalchemy/connectors/mxodbc.py
index 851dc11e8..1bbf899c4 100644
--- a/lib/sqlalchemy/connectors/mxodbc.py
+++ b/lib/sqlalchemy/connectors/mxodbc.py
@@ -1,5 +1,5 @@
 # connectors/mxodbc.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py
index 907e4d353..84bc92bee 100644
--- a/lib/sqlalchemy/connectors/pyodbc.py
+++ b/lib/sqlalchemy/connectors/pyodbc.py
@@ -1,5 +1,5 @@
 # connectors/pyodbc.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/connectors/zxJDBC.py b/lib/sqlalchemy/connectors/zxJDBC.py
index c0af742fa..8219a06eb 100644
--- a/lib/sqlalchemy/connectors/zxJDBC.py
+++ b/lib/sqlalchemy/connectors/zxJDBC.py
@@ -1,5 +1,5 @@
 # connectors/zxJDBC.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/databases/__init__.py b/lib/sqlalchemy/databases/__init__.py
index 356fbec59..321ff999b 100644
--- a/lib/sqlalchemy/databases/__init__.py
+++ b/lib/sqlalchemy/databases/__init__.py
@@ -1,5 +1,5 @@
 # databases/__init__.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py
index 74c48820d..d90a83809 100644
--- a/lib/sqlalchemy/dialects/__init__.py
+++ b/lib/sqlalchemy/dialects/__init__.py
@@ -1,5 +1,5 @@
 # dialects/__init__.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/dialects/firebird/__init__.py b/lib/sqlalchemy/dialects/firebird/__init__.py
index 9e8a88245..b2fb57a63 100644
--- a/lib/sqlalchemy/dialects/firebird/__init__.py
+++ b/lib/sqlalchemy/dialects/firebird/__init__.py
@@ -1,5 +1,5 @@
 # firebird/__init__.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py
index 36229a105..9d8630d3c 100644
--- a/lib/sqlalchemy/dialects/firebird/base.py
+++ b/lib/sqlalchemy/dialects/firebird/base.py
@@ -1,5 +1,5 @@
 # firebird/base.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
@@ -180,16 +180,16 @@ ischema_names = {
 # _FBDate, etc. as bind/result functionality is required)
 class FBTypeCompiler(compiler.GenericTypeCompiler):
-    def visit_boolean(self, type_):
-        return self.visit_SMALLINT(type_)
+    def visit_boolean(self, type_, **kw):
+        return self.visit_SMALLINT(type_, **kw)
 
-    def visit_datetime(self, type_):
-        return self.visit_TIMESTAMP(type_)
+    def visit_datetime(self, type_, **kw):
+        return self.visit_TIMESTAMP(type_, **kw)
 
-    def visit_TEXT(self, type_):
+    def visit_TEXT(self, type_, **kw):
         return "BLOB SUB_TYPE 1"
 
-    def visit_BLOB(self, type_):
+    def visit_BLOB(self, type_, **kw):
         return "BLOB SUB_TYPE 0"
 
     def _extend_string(self, type_, basic):
@@ -199,16 +199,16 @@ class FBTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return '%s CHARACTER SET %s' % (basic, charset)
 
-    def visit_CHAR(self, type_):
-        basic = super(FBTypeCompiler, self).visit_CHAR(type_)
+    def visit_CHAR(self, type_, **kw):
+        basic = super(FBTypeCompiler, self).visit_CHAR(type_, **kw)
         return self._extend_string(type_, basic)
 
-    def visit_VARCHAR(self, type_):
+    def visit_VARCHAR(self, type_, **kw):
         if not type_.length:
             raise exc.CompileError(
                 "VARCHAR requires a length on dialect %s" %
                 self.dialect.name)
-        basic = super(FBTypeCompiler, self).visit_VARCHAR(type_)
+        basic = super(FBTypeCompiler, self).visit_VARCHAR(type_, **kw)
         return self._extend_string(type_, basic)
 
@@ -394,6 +394,8 @@ class FBDialect(default.DefaultDialect):
     requires_name_normalize = True
     supports_empty_insert = False
 
+    supports_simple_order_by_label = False
+
     statement_compiler = FBCompiler
     ddl_compiler = FBDDLCompiler
     preparer = FBIdentifierPreparer
diff --git a/lib/sqlalchemy/dialects/firebird/fdb.py b/lib/sqlalchemy/dialects/firebird/fdb.py
index ddffc80f5..0ab07498b 100644
--- a/lib/sqlalchemy/dialects/firebird/fdb.py
+++ b/lib/sqlalchemy/dialects/firebird/fdb.py
@@ -1,5 +1,5 @@
 # firebird/fdb.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py
index 6bd7887f7..7d1a834b8 100644
--- a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py
+++ b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py
@@ -1,5 +1,5 @@
 # firebird/kinterbasdb.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/dialects/mssql/__init__.py b/lib/sqlalchemy/dialects/mssql/__init__.py
index d0047765e..898b40cd5 100644
--- a/lib/sqlalchemy/dialects/mssql/__init__.py
+++ b/lib/sqlalchemy/dialects/mssql/__init__.py
@@ -1,5 +1,5 @@
 # mssql/__init__.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/dialects/mssql/adodbapi.py b/lib/sqlalchemy/dialects/mssql/adodbapi.py
index e9927f8ed..6e3f348fc 100644
--- a/lib/sqlalchemy/dialects/mssql/adodbapi.py
+++ b/lib/sqlalchemy/dialects/mssql/adodbapi.py
@@ -1,5 +1,5 @@
 # mssql/adodbapi.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
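The ``supports_simple_order_by_label = False`` flag added to ``FBDialect`` above (and to ``MSDialect`` further down) controls whether the compiler may refer to a SELECT column by its label name inside ORDER BY. A minimal sketch of the observable difference; the table and names here are illustrative only::

    from sqlalchemy import Column, Integer, MetaData, Table, select

    m = MetaData()
    t = Table('t', m, Column('x', Integer))
    expr = (t.c.x + 1).label('x1')
    stmt = select([expr]).order_by(expr)

    # where supports_simple_order_by_label is True (the common default):
    #   SELECT t.x + :x_1 AS x1 FROM t ORDER BY x1
    # where it is False, as on Firebird here, the full expression is
    # re-rendered instead:
    #   SELECT t.x + :x_1 AS x1 FROM t ORDER BY t.x + :x_1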
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index dad02ee0f..26b794712 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -1,5 +1,5 @@
 # mssql/base.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
@@ -226,6 +226,53 @@ The DATE and TIME types are not available for MSSQL 2005 and
 previous - if a server version below 2008 is detected, DDL
 for these types will be issued as DATETIME.
 
+.. _mssql_large_type_deprecation:
+
+Large Text/Binary Type Deprecation
+----------------------------------
+
+Per `SQL Server 2012/2014 Documentation <http://technet.microsoft.com/en-us/library/ms187993.aspx>`_,
+the ``NTEXT``, ``TEXT`` and ``IMAGE`` datatypes are to be removed from SQL Server
+in a future release.   SQLAlchemy normally relates these types to the
+:class:`.UnicodeText`, :class:`.Text` and :class:`.LargeBinary` datatypes.
+
+In order to accommodate this change, a new flag ``deprecate_large_types``
+is added to the dialect, which will be automatically set based on detection
+of the server version in use, if not otherwise set by the user.  The
+behavior of this flag is as follows:
+
+* When this flag is ``True``, the :class:`.UnicodeText`, :class:`.Text` and
+  :class:`.LargeBinary` datatypes, when used to render DDL, will render the
+  types ``NVARCHAR(max)``, ``VARCHAR(max)``, and ``VARBINARY(max)``,
+  respectively.  This is a new behavior as of the addition of this flag.
+
+* When this flag is ``False``, the :class:`.UnicodeText`, :class:`.Text` and
+  :class:`.LargeBinary` datatypes, when used to render DDL, will render the
+  types ``NTEXT``, ``TEXT``, and ``IMAGE``,
+  respectively.  This is the long-standing behavior of these types.
+
+* The flag begins with the value ``None``, before a database connection is
+  established.   If the dialect is used to render DDL without the flag being
+  set, it is interpreted the same as ``False``.
+
+* On first connection, the dialect detects if SQL Server version 2012 or greater
+  is in use; if the flag is still at ``None``, it sets it to ``True`` or
+  ``False`` based on whether 2012 or greater is detected.
+
+* The flag can be set to either ``True`` or ``False`` when the dialect
+  is created, typically via :func:`.create_engine`::
+
+        eng = create_engine("mssql+pymssql://user:pass@host/db",
+                            deprecate_large_types=True)
+
+* Complete control over whether the "old" or "new" types are rendered is
+  available in all SQLAlchemy versions by using the UPPERCASE type objects
+  instead: :class:`.NVARCHAR`, :class:`.VARCHAR`, :class:`.types.VARBINARY`,
+  :class:`.TEXT`, :class:`.mssql.NTEXT`, :class:`.mssql.IMAGE` will always remain
+  fixed and always output exactly that type.
+
+.. versionadded:: 1.0.0
+
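As a hedged illustration of the flag documented above, the following sketch shows the two DDL outcomes; the engine URL and table are hypothetical::

    from sqlalchemy import (Column, Integer, LargeBinary, MetaData,
                            Table, Text, create_engine)

    m = MetaData()
    t = Table('doc', m,
              Column('id', Integer, primary_key=True),
              Column('body', Text),
              Column('data', LargeBinary))

    eng = create_engine("mssql+pymssql://user:pass@host/db",
                        deprecate_large_types=True)

    # with deprecate_large_types=True, CREATE TABLE renders
    #   body VARCHAR(max), data VARBINARY(max)
    # with the flag False (or left at None before any connection is made),
    #   body TEXT, data IMAGE
    m.create_all(eng)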
 .. _mssql_indexes:
 
 Clustered Index Support
@@ -367,19 +414,20 @@ import operator
 import re
 
 from ... import sql, schema as sa_schema, exc, util
-from ...sql import compiler, expression, \
-    util as sql_util, cast
+from ...sql import compiler, expression, util as sql_util
 from ... import engine
 from ...engine import reflection, default
 from ... import types as sqltypes
 from ...types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \
     FLOAT, TIMESTAMP, DATETIME, DATE, BINARY,\
-    VARBINARY, TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR
+    TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR
 
 from ...util import update_wrapper
 from . import information_schema as ischema
 
+# http://sqlserverbuilds.blogspot.com/
+MS_2012_VERSION = (11,)
 MS_2008_VERSION = (10,)
 MS_2005_VERSION = (9,)
 MS_2000_VERSION = (8,)
@@ -545,6 +593,26 @@ class NTEXT(sqltypes.UnicodeText):
     __visit_name__ = 'NTEXT'
 
 
+class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary):
+    """The MSSQL VARBINARY type.
+
+    This type extends both :class:`.types.VARBINARY` and
+    :class:`.types.LargeBinary`.  In "deprecate_large_types" mode,
+    the :class:`.types.LargeBinary` type will produce ``VARBINARY(max)``
+    on SQL Server.
+
+    .. versionadded:: 1.0.0
+
+    .. seealso::
+
+        :ref:`mssql_large_type_deprecation`
+
+
+
+    """
+    __visit_name__ = 'VARBINARY'
+
+
 class IMAGE(sqltypes.LargeBinary):
     __visit_name__ = 'IMAGE'
 
@@ -626,7 +694,6 @@ ischema_names = {
 class MSTypeCompiler(compiler.GenericTypeCompiler):
-
     def _extend(self, spec, type_, length=None):
         """Extend a string-type declaration with standard SQL
         COLLATE annotations.
 
         """
@@ -647,103 +714,115 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
         return ' '.join([c for c in (spec, collation) if c is not None])
 
-    def visit_FLOAT(self, type_):
+    def visit_FLOAT(self, type_, **kw):
         precision = getattr(type_, 'precision', None)
         if precision is None:
             return "FLOAT"
         else:
             return "FLOAT(%(precision)s)" % {'precision': precision}
 
-    def visit_TINYINT(self, type_):
+    def visit_TINYINT(self, type_, **kw):
         return "TINYINT"
 
-    def visit_DATETIMEOFFSET(self, type_):
+    def visit_DATETIMEOFFSET(self, type_, **kw):
         if type_.precision:
             return "DATETIMEOFFSET(%s)" % type_.precision
         else:
             return "DATETIMEOFFSET"
 
-    def visit_TIME(self, type_):
+    def visit_TIME(self, type_, **kw):
         precision = getattr(type_, 'precision', None)
         if precision:
             return "TIME(%s)" % precision
         else:
             return "TIME"
 
-    def visit_DATETIME2(self, type_):
+    def visit_DATETIME2(self, type_, **kw):
         precision = getattr(type_, 'precision', None)
         if precision:
             return "DATETIME2(%s)" % precision
         else:
             return "DATETIME2"
 
-    def visit_SMALLDATETIME(self, type_):
+    def visit_SMALLDATETIME(self, type_, **kw):
         return "SMALLDATETIME"
 
-    def visit_unicode(self, type_):
-        return self.visit_NVARCHAR(type_)
+    def visit_unicode(self, type_, **kw):
+        return self.visit_NVARCHAR(type_, **kw)
+
+    def visit_text(self, type_, **kw):
+        if self.dialect.deprecate_large_types:
+            return self.visit_VARCHAR(type_, **kw)
+        else:
+            return self.visit_TEXT(type_, **kw)
 
-    def visit_unicode_text(self, type_):
-        return self.visit_NTEXT(type_)
+    def visit_unicode_text(self, type_, **kw):
+        if self.dialect.deprecate_large_types:
+            return self.visit_NVARCHAR(type_, **kw)
+        else:
+            return self.visit_NTEXT(type_, **kw)
 
-    def visit_NTEXT(self, type_):
+    def visit_NTEXT(self, type_, **kw):
         return self._extend("NTEXT", type_)
 
-    def visit_TEXT(self, type_):
+    def visit_TEXT(self, type_, **kw):
         return self._extend("TEXT", type_)
 
-    def visit_VARCHAR(self, type_):
+    def visit_VARCHAR(self, type_, **kw):
         return self._extend("VARCHAR", type_, length=type_.length or 'max')
 
-    def visit_CHAR(self, type_):
+    def visit_CHAR(self, type_, **kw):
         return self._extend("CHAR", type_)
 
-    def visit_NCHAR(self, type_):
+    def visit_NCHAR(self, type_, **kw):
         return self._extend("NCHAR", type_)
 
-    def visit_NVARCHAR(self, type_):
+    def visit_NVARCHAR(self, type_, **kw):
         return self._extend("NVARCHAR", type_, length=type_.length or 'max')
 
-    def visit_date(self, type_):
+    def visit_date(self, type_, **kw):
         if self.dialect.server_version_info < MS_2008_VERSION:
-            return self.visit_DATETIME(type_)
+            return self.visit_DATETIME(type_, **kw)
         else:
-            return self.visit_DATE(type_)
+            return self.visit_DATE(type_, **kw)
 
-    def visit_time(self, type_):
+    def visit_time(self, type_, **kw):
         if self.dialect.server_version_info < MS_2008_VERSION:
-            return self.visit_DATETIME(type_)
+            return self.visit_DATETIME(type_, **kw)
         else:
-            return self.visit_TIME(type_)
+            return self.visit_TIME(type_, **kw)
 
-    def visit_large_binary(self, type_):
-        return self.visit_IMAGE(type_)
+    def visit_large_binary(self, type_, **kw):
+        if self.dialect.deprecate_large_types:
+            return self.visit_VARBINARY(type_, **kw)
+        else:
+            return self.visit_IMAGE(type_, **kw)
 
-    def visit_IMAGE(self, type_):
+    def visit_IMAGE(self, type_, **kw):
         return "IMAGE"
 
-    def visit_VARBINARY(self, type_):
+    def visit_VARBINARY(self, type_, **kw):
         return self._extend(
             "VARBINARY",
             type_,
            length=type_.length or 'max')
 
-    def visit_boolean(self, type_):
+    def visit_boolean(self, type_, **kw):
         return self.visit_BIT(type_)
 
-    def visit_BIT(self, type_):
+    def visit_BIT(self, type_, **kw):
         return "BIT"
 
-    def visit_MONEY(self, type_):
+    def visit_MONEY(self, type_, **kw):
         return "MONEY"
 
-    def visit_SMALLMONEY(self, type_):
+    def visit_SMALLMONEY(self, type_, **kw):
         return 'SMALLMONEY'
 
-    def visit_UNIQUEIDENTIFIER(self, type_):
+    def visit_UNIQUEIDENTIFIER(self, type_, **kw):
         return "UNIQUEIDENTIFIER"
 
-    def visit_SQL_VARIANT(self, type_):
+    def visit_SQL_VARIANT(self, type_, **kw):
         return 'SQL_VARIANT'
@@ -952,6 +1031,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
             _order_by_clauses = select._order_by_clause.clauses
             limit_clause = select._limit_clause
             offset_clause = select._offset_clause
+            kwargs['select_wraps_for'] = select
             select = select._generate()
             select._mssql_visit = True
             select = select.column(
@@ -969,7 +1049,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
             else:
                 limitselect.append_whereclause(
                     mssql_rn <= (limit_clause))
-            return self.process(limitselect, iswrapper=True, **kwargs)
+            return self.process(limitselect, **kwargs)
         else:
             return compiler.SQLCompiler.visit_select(self, select, **kwargs)
@@ -1160,8 +1240,11 @@ class MSSQLStrictCompiler(MSSQLCompiler):
 
 class MSDDLCompiler(compiler.DDLCompiler):
     def get_column_specification(self, column, **kwargs):
-        colspec = (self.preparer.format_column(column) + " "
-                   + self.dialect.type_compiler.process(column.type))
+        colspec = (
+            self.preparer.format_column(column) + " "
+            + self.dialect.type_compiler.process(
+                column.type, type_expression=column)
+        )
 
         if column.nullable is not None:
             if not column.nullable or column.primary_key or \
@@ -1333,6 +1416,7 @@ class MSDialect(default.DefaultDialect):
     use_scope_identity = True
     max_identifier_length = 128
     schema_name = "dbo"
+    supports_simple_order_by_label = False
 
     colspecs = {
         sqltypes.DateTime: _MSDateTime,
@@ -1370,13 +1454,15 @@ class MSDialect(default.DefaultDialect):
                  query_timeout=None,
                  use_scope_identity=True,
                  max_identifier_length=None,
-                 schema_name="dbo", **opts):
+                 schema_name="dbo",
+                 deprecate_large_types=None, **opts):
         self.query_timeout = int(query_timeout or 0)
         self.schema_name = schema_name
 
         self.use_scope_identity = use_scope_identity
         self.max_identifier_length = int(max_identifier_length or 0) or \
             self.max_identifier_length
+        self.deprecate_large_types = deprecate_large_types
         super(MSDialect, self).__init__(**opts)
 
     def do_savepoint(self, connection, name):
@@ -1390,6 +1476,9 @@ class MSDialect(default.DefaultDialect):
 
     def initialize(self, connection):
         super(MSDialect, self).initialize(connection)
+        self._setup_version_attributes()
+
+    def _setup_version_attributes(self):
         if self.server_version_info[0] not in list(range(8, 17)):
             # FreeTDS with version 4.2 seems to report here
             # a number like "95.10.255".  Don't know what
@@ -1405,6 +1494,9 @@ class MSDialect(default.DefaultDialect):
             self.implicit_returning = True
         if self.server_version_info >= MS_2008_VERSION:
             self.supports_multivalues_insert = True
+        if self.deprecate_large_types is None:
+            self.deprecate_large_types = \
+                self.server_version_info >= MS_2012_VERSION
 
     def _get_default_schema_name(self, connection):
         if self.server_version_info < MS_2005_VERSION:
@@ -1592,12 +1684,11 @@ class MSDialect(default.DefaultDialect):
         if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText,
                        MSNText, MSBinary, MSVarBinary,
                       sqltypes.LargeBinary):
+            if charlen == -1:
+                charlen = 'max'
             kwargs['length'] = charlen
             if collation:
                 kwargs['collation'] = collation
-            if coltype == MSText or \
-                    (coltype in (MSString, MSNVarchar) and charlen == -1):
-                kwargs.pop('length')
 
         if coltype is None:
             util.warn(
diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py
index 371a1edcc..a6faa7bca 100644
--- a/lib/sqlalchemy/dialects/mssql/information_schema.py
+++ b/lib/sqlalchemy/dialects/mssql/information_schema.py
@@ -1,5 +1,5 @@
 # mssql/information_schema.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/dialects/mssql/mxodbc.py b/lib/sqlalchemy/dialects/mssql/mxodbc.py
index ffe38d8dd..ac87c67a9 100644
--- a/lib/sqlalchemy/dialects/mssql/mxodbc.py
+++ b/lib/sqlalchemy/dialects/mssql/mxodbc.py
@@ -1,5 +1,5 @@
 # mssql/mxodbc.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py
index b5a1bc566..2214d18d1 100644
--- a/lib/sqlalchemy/dialects/mssql/pymssql.py
+++ b/lib/sqlalchemy/dialects/mssql/pymssql.py
@@ -1,5 +1,5 @@
 # mssql/pymssql.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py
index 445584d24..ad1e7ae37 100644
--- a/lib/sqlalchemy/dialects/mssql/pyodbc.py
+++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py
@@ -1,5 +1,5 @@
 # mssql/pyodbc.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/dialects/mssql/zxjdbc.py b/lib/sqlalchemy/dialects/mssql/zxjdbc.py
index b23a010e7..85539817e 100644
--- a/lib/sqlalchemy/dialects/mssql/zxjdbc.py
+++ b/lib/sqlalchemy/dialects/mssql/zxjdbc.py
@@ -1,5 +1,5 @@
 # mssql/zxjdbc.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
@@ -13,6 +13,8 @@
     [?key=value&key=value...]
     :driverurl: http://jtds.sourceforge.net/
 
+    .. note:: Jython is not supported by current versions of SQLAlchemy.  The
+       zxjdbc dialect should be considered as experimental.
 
 """
 from ...connectors.zxJDBC import ZxJDBCConnector
diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py
index 498603cf7..c1f78bd1d 100644
--- a/lib/sqlalchemy/dialects/mysql/__init__.py
+++ b/lib/sqlalchemy/dialects/mysql/__init__.py
@@ -1,5 +1,5 @@
 # mysql/__init__.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 2fb054d0c..8460ff92a 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1,5 +1,5 @@
 # mysql/base.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
@@ -106,7 +106,7 @@ to be used.
 Transaction Isolation Level
 ---------------------------
 
-:func:`.create_engine` accepts an ``isolation_level``
+:func:`.create_engine` accepts an :paramref:`.create_engine.isolation_level`
 parameter which results in the command ``SET SESSION
 TRANSACTION ISOLATION LEVEL <level>`` being invoked for
 every new connection. Valid values for this parameter are
@@ -146,6 +146,90 @@ multi-column key for some storage engines::
         Column('id', Integer, primary_key=True)
     )
 
+.. _mysql_unicode:
+
+Unicode
+-------
+
+Charset Selection
+~~~~~~~~~~~~~~~~~
+
+Most MySQL DBAPIs offer the option to set the client character set for
+a connection.  This is typically delivered using the ``charset`` parameter
+in the URL, such as::
+
+    e = create_engine("mysql+pymysql://scott:tiger@localhost/\
+test?charset=utf8")
+
+This charset is the **client character set** for the connection.  Some
+MySQL DBAPIs will default this to a value such as ``latin1``, and some
+will make use of the ``default-character-set`` setting in the ``my.cnf``
+file as well.  Documentation for the DBAPI in use should be consulted
+for specific behavior.
+
+The encoding used for Unicode has traditionally been ``'utf8'``.  However,
+for MySQL versions 5.5.3 on forward, a new MySQL-specific encoding
+``'utf8mb4'`` has been introduced.   The rationale for this new encoding
+is due to the fact that MySQL's utf-8 encoding only supports
+codepoints up to three bytes instead of four.  Therefore,
+when communicating with a MySQL database
+that includes codepoints more than three bytes in size,
+this new charset is preferred, if supported by both the database as well
+as the client DBAPI, as in::
+
+    e = create_engine("mysql+pymysql://scott:tiger@localhost/\
+test?charset=utf8mb4")
+
+At the moment, up-to-date versions of MySQLdb and PyMySQL support the
+``utf8mb4`` charset.   Other DBAPIs such as MySQL-Connector and OurSQL
+may **not** support it as of yet.
+
+In order to use ``utf8mb4`` encoding, changes to
+the MySQL schema and/or server configuration may be required.
+
+.. seealso::
+
+    `The utf8mb4 Character Set \
+<http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html>`_ - \
+in the MySQL documentation
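A short aside on why ``utf8mb4`` matters: MySQL's legacy ``utf8`` stores at most three bytes per codepoint, so any character whose UTF-8 form is four bytes long needs ``utf8mb4`` on both the server and the client. The width check below is plain Python; the URL repeats the form shown above::

    from sqlalchemy import create_engine

    assert len(u'\U0001F600'.encode('utf-8')) == 4  # emoji: needs utf8mb4
    assert len(u'\u20ac'.encode('utf-8')) == 3      # euro sign: fits legacy utf8

    e = create_engine(
        "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4")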
+
+Unicode Encoding / Decoding
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+All modern MySQL DBAPIs all offer the service of handling the encoding and
+decoding of unicode data between the Python application space and the database.
+As this was not always the case, SQLAlchemy also includes a comprehensive system
+of performing the encode/decode task as well.   As only one of these systems
+should be in use at at time, SQLAlchemy has long included functionality
+to automatically detect upon first connection whether or not the DBAPI is
+automatically handling unicode.
+
+Whether or not the MySQL DBAPI will handle encoding can usually be configured
+using a DBAPI flag ``use_unicode``, which is known to be supported at least
+by MySQLdb, PyMySQL, and MySQL-Connector.  Setting this value to ``0``
+in the "connect args" or query string will have the effect of disabling the
+DBAPI's handling of unicode, such that it instead will return data of the
+``str`` type or ``bytes`` type, with data in the configured charset::
+
+    # connect while disabling the DBAPI's unicode encoding/decoding
+    e = create_engine("mysql+mysqldb://scott:tiger@localhost/test?charset=utf8&use_unicode=0")
+
+Current recommendations for modern DBAPIs are as follows:
+
+* It is generally always safe to leave the ``use_unicode`` flag set at
+  its default; that is, don't use it at all.
+* Under Python 3, the ``use_unicode=0`` flag should **never be used**.
+  SQLAlchemy under Python 3 generally assumes the DBAPI receives and returns
+  string values as Python 3 strings, which are inherently unicode objects.
+* Under Python 2 with MySQLdb, the ``use_unicode=0`` flag will **offer
+  superior performance**, as MySQLdb's unicode converters under Python 2 only
+  have been observed to have unusually slow performance compared to SQLAlchemy's
+  fast C-based encoders/decoders.
+
+In short:  don't specify ``use_unicode`` *at all*, with the possible
+exception of ``use_unicode=0`` on MySQLdb with Python 2 **only** for a
+potential performance gain.
+
 Ansi Quoting Style
 ------------------
@@ -370,10 +454,11 @@ collection.
 TIMESTAMP Columns and NULL
 --------------------------
 
-MySQL enforces that a column which specifies the TIMESTAMP datatype implicitly
-includes a default value of CURRENT_TIMESTAMP, even though this is not
-stated, and additionally sets the column as NOT NULL, the opposite behavior
-vs. that of all other datatypes::
+MySQL historically enforces that a column which specifies the
+TIMESTAMP datatype implicitly includes a default value of
+CURRENT_TIMESTAMP, even though this is not stated, and additionally
+sets the column as NOT NULL, the opposite behavior vs. that of all
+other datatypes::
 
     mysql> CREATE TABLE ts_test (
         -> a INTEGER,
@@ -400,22 +485,29 @@ with NOT NULL.   But when the column is of type TIMESTAMP, an implicit
 default of CURRENT_TIMESTAMP is generated which also coerces the column
 to be a NOT NULL, even though we did not specify it as such.
 
-Therefore, the usual "NOT NULL" clause *does not apply* to a TIMESTAMP
-column; MySQL selects this implicitly.  SQLAlchemy therefore does not render
-NOT NULL for a TIMESTAMP column on MySQL.  However, it *does* render
-NULL when we specify nullable=True, or if we leave nullable absent, as it
-also defaults to True.  This is to accommodate the essentially
-reverse behavior of the NULL flag for TIMESTAMP::
+This behavior of MySQL can be changed on the MySQL side using the
+`explicit_defaults_for_timestamp
+<http://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html
+#sysvar_explicit_defaults_for_timestamp>`_ configuration flag introduced in
+MySQL 5.6.  With this server setting enabled, TIMESTAMP columns behave like
+any other datatype on the MySQL side with regards to defaults and nullability.
+
+However, to accommodate the vast majority of MySQL databases that do not
+specify this new flag, SQLAlchemy emits the "NULL" specifier explicitly with
+any TIMESTAMP column that does not specify ``nullable=False``.   In order
+to accommodate newer databases that specify ``explicit_defaults_for_timestamp``,
+SQLAlchemy also emits NOT NULL for TIMESTAMP columns that do specify
+``nullable=False``.   The following example illustrates::
 
-    from sqlalchemy import MetaData, TIMESTAMP, Integer, Table, Column, text
+    from sqlalchemy import MetaData, Integer, Table, Column, text
+    from sqlalchemy.dialects.mysql import TIMESTAMP
 
     m = MetaData()
     t = Table('ts_test', m,
             Column('a', Integer),
             Column('b', Integer, nullable=False),
             Column('c', TIMESTAMP),
-            Column('d', TIMESTAMP, nullable=False),
-            Column('e', TIMESTAMP, nullable=True)
+            Column('d', TIMESTAMP, nullable=False)
     )
 
@@ -423,35 +515,19 @@ reverse behavior of the NULL flag for TIMESTAMP::
     e = create_engine("mysql://scott:tiger@localhost/test", echo=True)
     m.create_all(e)
 
-In the output, we can see that the TIMESTAMP column receives a different
-treatment for NULL / NOT NULL vs. that of the INTEGER::
+output::
 
     CREATE TABLE ts_test (
         a INTEGER,
         b INTEGER NOT NULL,
         c TIMESTAMP NULL,
-        d TIMESTAMP,
-        e TIMESTAMP NULL
+        d TIMESTAMP NOT NULL
     )
 
-MySQL above receives the NULL/NOT NULL constraint as is stated in our
-original :class:`.Table`::
-
-    mysql> SHOW CREATE TABLE ts_test;
-    +---------+---------------------------
-    | Table   | Create Table
-    +---------+---------------------------
-    | ts_test | CREATE TABLE `ts_test` (
-      `a` int(11) DEFAULT NULL,
-      `b` int(11) NOT NULL,
-      `c` timestamp NULL DEFAULT NULL,
-      `d` timestamp NOT NULL DEFAULT '0000-00-00 00:00:00',
-      `e` timestamp NULL DEFAULT NULL
-    ) ENGINE=MyISAM DEFAULT CHARSET=latin1
-
-Be sure to always favor the ``SHOW CREATE TABLE`` output over the
-SQLAlchemy-emitted DDL when checking table definitions, as MySQL's
-rules can be hard to predict.
+.. versionchanged:: 1.0.0 - SQLAlchemy now renders NULL or NOT NULL in all
+   cases for TIMESTAMP columns, to accommodate
+   ``explicit_defaults_for_timestamp``.  Prior to this version, it will
+   not render "NOT NULL" for a TIMESTAMP column that is ``nullable=False``.
 
 """
 
@@ -602,6 +678,14 @@ class _StringType(sqltypes.String):
             to_inspect=[_StringType, sqltypes.String])
 
 
+class _MatchType(sqltypes.Float, sqltypes.MatchType):
+    def __init__(self, **kw):
+        # TODO: float arguments?
+        sqltypes.Float.__init__(self)
+        sqltypes.MatchType.__init__(self)
+
+
 class NUMERIC(_NumericType, sqltypes.NUMERIC):
     """MySQL NUMERIC type."""
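The ``_MatchType`` added above gives MySQL's ``MATCH ... AGAINST`` its own result type, since on MySQL the operator returns a floating-point relevance score rather than a boolean. A sketch of the expression it applies to; the table here is hypothetical::

    from sqlalchemy import Column, Integer, MetaData, String, Table, select

    m = MetaData()
    docs = Table('docs', m,
                 Column('id', Integer, primary_key=True),
                 Column('body', String(200)))

    # on MySQL this renders roughly:
    #   MATCH (docs.body) AGAINST (%s IN BOOLEAN MODE)
    stmt = select([docs.c.id]).where(docs.c.body.match('sqlalchemy'))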
@@ -881,7 +965,9 @@ class BIT(sqltypes.TypeEngine):
         def process(value):
             if value is not None:
                 v = 0
-                for i in map(ord, value):
+                for i in value:
+                    if not isinstance(i, int):
+                        i = ord(i)  # convert byte to int on Python 2
                     v = v << 8 | i
                 return v
             return value
@@ -1382,6 +1468,7 @@ class ENUM(sqltypes.Enum, _EnumeratedValues):
         kw.pop('quote', None)
         kw.pop('native_enum', None)
         kw.pop('inherit_schema', None)
+        kw.pop('_create_events', None)
         _StringType.__init__(self, length=length, **kw)
         sqltypes.Enum.__init__(self, *values)
 
@@ -1420,32 +1507,28 @@ class SET(_EnumeratedValues):
 
             Column('myset', SET("foo", "bar", "baz"))
 
-        :param values: The range of valid values for this SET. Values will be
-          quoted when generating the schema according to the quoting flag (see
-          below).
 
-          .. versionchanged:: 0.9.0 quoting is applied automatically to
-             :class:`.mysql.SET` in the same way as for :class:`.mysql.ENUM`.
+        The list of potential values is required in the case that this
+        set will be used to generate DDL for a table, or if the
+        :paramref:`.SET.retrieve_as_bitwise` flag is set to True.
 
-        :param charset: Optional, a column-level character set for this string
-          value. Takes precedence to 'ascii' or 'unicode' short-hand.
+        :param values: The range of valid values for this SET.
 
-        :param collation: Optional, a column-level collation for this string
-          value. Takes precedence to 'binary' short-hand.
+        :param convert_unicode: Same flag as that of
+         :paramref:`.String.convert_unicode`.
 
-        :param ascii: Defaults to False: short-hand for the ``latin1``
-          character set, generates ASCII in schema.
+        :param collation: same as that of :paramref:`.String.collation`
 
-        :param unicode: Defaults to False: short-hand for the ``ucs2``
-          character set, generates UNICODE in schema.
+        :param charset: same as that of :paramref:`.VARCHAR.charset`.
 
-        :param binary: Defaults to False: short-hand, pick the binary
-          collation type that matches the column's character set. Generates
-          BINARY in schema. This does not affect the type of data stored,
-          only the collation of character data.
+        :param ascii: same as that of :paramref:`.VARCHAR.ascii`.
 
-        :param quoting: Defaults to 'auto': automatically determine enum value
-          quoting.  If all enum values are surrounded by the same quoting
+        :param unicode: same as that of :paramref:`.VARCHAR.unicode`.
+
+        :param binary: same as that of :paramref:`.VARCHAR.binary`.
+
+        :param quoting: Defaults to 'auto': automatically determine set value
+          quoting.  If all values are surrounded by the same quoting
           character, then use 'quoted' mode.  Otherwise, use 'unquoted' mode.
 
           'quoted': values in enums are already quoted, they will be used
@@ -1460,50 +1543,117 @@ class SET(_EnumeratedValues):
 
           .. versionadded:: 0.9.0
 
+        :param retrieve_as_bitwise: if True, the data for the set type will be
+          persisted and selected using an integer value, where a set is coerced
+          into a bitwise mask for persistence.  MySQL allows this mode which
+          has the advantage of being able to store values unambiguously,
+          such as the blank string ``''``.   The datatype will appear
+          as the expression ``col + 0`` in a SELECT statement, so that the
+          value is coerced into an integer value in result sets.
+          This flag is required if one wishes
+          to persist a set that can store the blank string ``''`` as a value.
+
+          .. warning::
+
+            When using :paramref:`.mysql.SET.retrieve_as_bitwise`, it is
+            essential that the list of set values is expressed in the
+            **exact same order** as exists on the MySQL database.
+
+          .. versionadded:: 1.0.0
+
+        """
+        self.retrieve_as_bitwise = kw.pop('retrieve_as_bitwise', False)
         values, length = self._init_values(values, kw)
         self.values = tuple(values)
-
+        if not self.retrieve_as_bitwise and '' in values:
+            raise exc.ArgumentError(
+                "Can't use the blank value '' in a SET without "
+                "setting retrieve_as_bitwise=True")
+        if self.retrieve_as_bitwise:
+            self._bitmap = dict(
+                (value, 2 ** idx)
+                for idx, value in enumerate(self.values)
+            )
+            self._bitmap.update(
+                (2 ** idx, value)
+                for idx, value in enumerate(self.values)
+            )
         kw.setdefault('length', length)
         super(SET, self).__init__(**kw)
 
+    def column_expression(self, colexpr):
+        if self.retrieve_as_bitwise:
+            return colexpr + 0
+        else:
+            return colexpr
+
     def result_processor(self, dialect, coltype):
-        def process(value):
-            # The good news:
-            #   No ',' quoting issues- commas aren't allowed in SET values
-            # The bad news:
-            #   Plenty of driver inconsistencies here.
-            if isinstance(value, set):
-                # ..some versions convert '' to an empty set
-                if not value:
-                    value.add('')
-                return value
-            # ...and some versions return strings
-            if value is not None:
-                return set(value.split(','))
-            else:
-                return value
+        if self.retrieve_as_bitwise:
+            def process(value):
+                if value is not None:
+                    value = int(value)
+
+                    return set(
+                        util.map_bits(self._bitmap.__getitem__, value)
+                    )
+                else:
+                    return None
+        else:
+            super_convert = super(SET, self).result_processor(dialect, coltype)
+
+            def process(value):
+                if isinstance(value, util.string_types):
+                    # MySQLdb returns a string, let's parse
+                    if super_convert:
+                        value = super_convert(value)
+                    return set(re.findall(r'[^,]+', value))
+                else:
+                    # mysql-connector-python does a naive
+                    # split(",") which throws in an empty string
+                    if value is not None:
+                        value.discard('')
+                    return value
         return process
 
     def bind_processor(self, dialect):
         super_convert = super(SET, self).bind_processor(dialect)
+        if self.retrieve_as_bitwise:
+            def process(value):
+                if value is None:
+                    return None
+                elif isinstance(value, util.int_types + util.string_types):
+                    if super_convert:
+                        return super_convert(value)
+                    else:
+                        return value
+                else:
+                    int_value = 0
+                    for v in value:
+                        int_value |= self._bitmap[v]
+                    return int_value
+        else:
 
-        def process(value):
-            if value is None or isinstance(
-                    value, util.int_types + util.string_types):
-                pass
-            else:
-                if None in value:
-                    value = set(value)
-                    value.remove(None)
-                    value.add('')
-                value = ','.join(value)
-            if super_convert:
-                return super_convert(value)
-            else:
-                return value
+            def process(value):
+                # accept strings and int (actually bitflag) values directly
+                if value is not None and not isinstance(
+                        value, util.int_types + util.string_types):
+                    value = ",".join(value)
+
+                if super_convert:
+                    return super_convert(value)
+                else:
+                    return value
         return process
 
+    def adapt(self, impltype, **kw):
+        kw['retrieve_as_bitwise'] = self.retrieve_as_bitwise
+        return util.constructor_copy(
+            self, impltype,
+            *self.values,
+            **kw
+        )
+
 # old names
 MSTime = TIME
 MSSet = SET
@@ -1544,6 +1694,7 @@ colspecs = {
     sqltypes.Float: FLOAT,
     sqltypes.Time: TIME,
     sqltypes.Enum: ENUM,
+    sqltypes.MatchType: _MatchType
 }
 
 # Everything 3.23 through 5.1 excepting OpenGIS types.
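A brief usage sketch for the ``retrieve_as_bitwise`` mode implemented above; the table and values are illustrative::

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.dialects.mysql import SET

    m = MetaData()
    t = Table('subscriber', m,
              Column('id', Integer, primary_key=True),
              # persisted as an integer bitmask and selected as ``col + 0``;
              # this is what makes the blank string '' storable unambiguously
              Column('prefs', SET('', 'mail', 'sms',
                                  retrieve_as_bitwise=True)))

    # values round-trip as Python sets, e.g. {'mail', 'sms'} or {''}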
@@ -1619,9 +1770,12 @@ class MySQLCompiler(compiler.SQLCompiler):
     def get_from_hint_text(self, table, text):
         return text
 
-    def visit_typeclause(self, typeclause):
-        type_ = typeclause.type.dialect_impl(self.dialect)
-        if isinstance(type_, sqltypes.Integer):
+    def visit_typeclause(self, typeclause, type_=None):
+        if type_ is None:
+            type_ = typeclause.type.dialect_impl(self.dialect)
+        if isinstance(type_, sqltypes.TypeDecorator):
+            return self.visit_typeclause(typeclause, type_.impl)
+        elif isinstance(type_, sqltypes.Integer):
             if getattr(type_, 'unsigned', False):
                 return 'UNSIGNED INTEGER'
             else:
@@ -1646,10 +1800,17 @@ class MySQLCompiler(compiler.SQLCompiler):
     def visit_cast(self, cast, **kwargs):
         # No cast until 4, no decimals until 5.
         if not self.dialect._supports_cast:
+            util.warn(
+                "Current MySQL version does not support "
+                "CAST; the CAST will be skipped.")
             return self.process(cast.clause.self_group())
 
         type_ = self.process(cast.typeclause)
         if type_ is None:
+            util.warn(
+                "Datatype %s does not support CAST on MySQL; "
+                "the CAST will be skipped." %
+                self.dialect.type_compiler.process(cast.typeclause.type))
            return self.process(cast.clause.self_group())
 
         return 'CAST(%s AS %s)' % (self.process(cast.clause), type_)
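Given the warnings added above, an unsupported CAST on MySQL now announces that it is being skipped instead of silently rendering the bare expression. A sketch of the second warning path, assuming ``Float`` remains outside MySQL's supported CAST target types::

    from sqlalchemy import Float, cast
    from sqlalchemy.sql import column
    from sqlalchemy.dialects import mysql

    expr = cast(column('x'), Float)
    # compiling against the MySQL dialect emits "Datatype FLOAT does not
    # support CAST on MySQL; the CAST will be skipped." and renders just "x"
    print(expr.compile(dialect=mysql.dialect()))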
@@ -1758,10 +1919,10 @@ class MySQLCompiler(compiler.SQLCompiler):
 # creation of foreign key constraints fails."
 
 class MySQLDDLCompiler(compiler.DDLCompiler):
-    def create_table_constraints(self, table):
+    def create_table_constraints(self, table, **kw):
         """Get table constraints."""
         constraint_string = super(
-            MySQLDDLCompiler, self).create_table_constraints(table)
+            MySQLDDLCompiler, self).create_table_constraints(table, **kw)
 
         # why self.dialect.name and not 'mysql'? because of drizzle
         is_innodb = 'engine' in table.dialect_options[self.dialect.name] and \
@@ -1787,24 +1948,28 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
     def get_column_specification(self, column, **kw):
         """Builds column DDL."""
 
-        colspec = [self.preparer.format_column(column),
-                   self.dialect.type_compiler.process(column.type)
-                   ]
-
-        default = self.get_column_default_string(column)
-        if default is not None:
-            colspec.append('DEFAULT ' + default)
+        colspec = [
+            self.preparer.format_column(column),
+            self.dialect.type_compiler.process(
+                column.type, type_expression=column)
+        ]
 
         is_timestamp = isinstance(column.type, sqltypes.TIMESTAMP)
-        if not column.nullable and not is_timestamp:
+
+        if not column.nullable:
             colspec.append('NOT NULL')
 
         # see: http://docs.sqlalchemy.org/en/latest/dialects/
         #   mysql.html#mysql_timestamp_null
-        elif column.nullable and is_timestamp and default is None:
+        elif column.nullable and is_timestamp:
             colspec.append('NULL')
 
-        if column is column.table._autoincrement_column and \
+        default = self.get_column_default_string(column)
+        if default is not None:
+            colspec.append('DEFAULT ' + default)
+
+        if column.table is not None \
+                and column is column.table._autoincrement_column and \
                 column.server_default is None:
             colspec.append('AUTO_INCREMENT')
@@ -1987,7 +2152,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
     def _mysql_type(self, type_):
         return isinstance(type_, (_StringType, _NumericType))
 
-    def visit_NUMERIC(self, type_):
+    def visit_NUMERIC(self, type_, **kw):
         if type_.precision is None:
             return self._extend_numeric(type_, "NUMERIC")
         elif type_.scale is None:
@@ -2000,7 +2165,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
                 {'precision': type_.precision,
                  'scale': type_.scale})
 
-    def visit_DECIMAL(self, type_):
+    def visit_DECIMAL(self, type_, **kw):
         if type_.precision is None:
             return self._extend_numeric(type_, "DECIMAL")
         elif type_.scale is None:
@@ -2013,7 +2178,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
                 {'precision': type_.precision,
                  'scale': type_.scale})
 
-    def visit_DOUBLE(self, type_):
+    def visit_DOUBLE(self, type_, **kw):
         if type_.precision is not None and type_.scale is not None:
             return self._extend_numeric(type_,
                                         "DOUBLE(%(precision)s, %(scale)s)" %
@@ -2022,7 +2187,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, 'DOUBLE')
 
-    def visit_REAL(self, type_):
+    def visit_REAL(self, type_, **kw):
         if type_.precision is not None and type_.scale is not None:
             return self._extend_numeric(type_,
                                         "REAL(%(precision)s, %(scale)s)" %
@@ -2031,7 +2196,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, 'REAL')
 
-    def visit_FLOAT(self, type_):
+    def visit_FLOAT(self, type_, **kw):
         if self._mysql_type(type_) and \
                 type_.scale is not None and \
                 type_.precision is not None:
@@ -2043,7 +2208,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, "FLOAT")
 
-    def visit_INTEGER(self, type_):
+    def visit_INTEGER(self, type_, **kw):
         if self._mysql_type(type_) and type_.display_width is not None:
             return self._extend_numeric(
                 type_, "INTEGER(%(display_width)s)" %
@@ -2051,7 +2216,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, "INTEGER")
 
-    def visit_BIGINT(self, type_):
+    def visit_BIGINT(self, type_, **kw):
         if self._mysql_type(type_) and type_.display_width is not None:
             return self._extend_numeric(
                 type_, "BIGINT(%(display_width)s)" %
@@ -2059,7 +2224,7 @@
         else:
             return self._extend_numeric(type_, "BIGINT")
 
-    def visit_MEDIUMINT(self, type_):
+    def visit_MEDIUMINT(self, type_, **kw):
         if self._mysql_type(type_) and type_.display_width is not None:
             return self._extend_numeric(
                 type_, "MEDIUMINT(%(display_width)s)" %
@@ -2067,14 +2232,14 @@
         else:
             return self._extend_numeric(type_, "MEDIUMINT")
 
-    def visit_TINYINT(self, type_):
+    def visit_TINYINT(self, type_, **kw):
         if self._mysql_type(type_) and type_.display_width is not None:
             return self._extend_numeric(type_,
                                         "TINYINT(%s)" % type_.display_width)
         else:
             return self._extend_numeric(type_, "TINYINT")
 
-    def visit_SMALLINT(self, type_):
+    def visit_SMALLINT(self, type_, **kw):
         if self._mysql_type(type_) and type_.display_width is not None:
             return self._extend_numeric(type_,
                                         "SMALLINT(%(display_width)s)" %
@@ -2083,55 +2248,55 @@
         else:
             return self._extend_numeric(type_, "SMALLINT")
 
-    def visit_BIT(self, type_):
+    def visit_BIT(self, type_, **kw):
         if type_.length is not None:
             return "BIT(%s)" % type_.length
         else:
             return "BIT"
 
-    def visit_DATETIME(self, type_):
+    def visit_DATETIME(self, type_, **kw):
         if getattr(type_, 'fsp', None):
             return "DATETIME(%d)" % type_.fsp
         else:
             return "DATETIME"
 
-    def visit_DATE(self, type_):
+    def visit_DATE(self, type_, **kw):
         return "DATE"
 
-    def visit_TIME(self, type_):
+    def visit_TIME(self, type_, **kw):
         if getattr(type_, 'fsp', None):
             return "TIME(%d)" % type_.fsp
         else:
             return "TIME"
 
-    def visit_TIMESTAMP(self, type_):
+    def visit_TIMESTAMP(self, type_, **kw):
         if getattr(type_, 'fsp', None):
             return "TIMESTAMP(%d)" % type_.fsp
         else:
             return "TIMESTAMP"
 
-    def visit_YEAR(self, type_):
+    def visit_YEAR(self, type_, **kw):
         if type_.display_width is None:
             return "YEAR"
         else:
             return "YEAR(%s)" % type_.display_width
 
-    def visit_TEXT(self, type_):
+    def visit_TEXT(self, type_, **kw):
         if type_.length:
             return self._extend_string(type_, {}, "TEXT(%d)" % type_.length)
         else:
             return self._extend_string(type_, {}, "TEXT")
 
-    def visit_TINYTEXT(self, type_):
+    def visit_TINYTEXT(self, type_, **kw):
         return self._extend_string(type_, {}, "TINYTEXT")
 
-    def visit_MEDIUMTEXT(self, type_):
+    def visit_MEDIUMTEXT(self, type_, **kw):
         return self._extend_string(type_, {}, "MEDIUMTEXT")
 
-    def visit_LONGTEXT(self, type_):
+    def visit_LONGTEXT(self, type_, **kw):
         return self._extend_string(type_, {}, "LONGTEXT")
 
-    def visit_VARCHAR(self, type_):
+    def visit_VARCHAR(self, type_, **kw):
         if type_.length:
             return self._extend_string(
                 type_, {}, "VARCHAR(%d)" % type_.length)
@@ -2140,14 +2305,14 @@
                 "VARCHAR requires a length on dialect %s" %
                 self.dialect.name)
 
-    def visit_CHAR(self, type_):
+    def visit_CHAR(self, type_, **kw):
         if type_.length:
             return self._extend_string(type_, {}, "CHAR(%(length)s)" %
                                        {'length': type_.length})
         else:
             return self._extend_string(type_, {}, "CHAR")
 
-    def visit_NVARCHAR(self, type_):
+    def visit_NVARCHAR(self, type_, **kw):
         # We'll actually generate the equiv. "NATIONAL VARCHAR" instead
         # of "NVARCHAR".
         if type_.length:
@@ -2159,7 +2324,7 @@
                 "NVARCHAR requires a length on dialect %s" %
                 self.dialect.name)
 
-    def visit_NCHAR(self, type_):
+    def visit_NCHAR(self, type_, **kw):
         # We'll actually generate the equiv.
         # "NATIONAL CHAR" instead of "NCHAR".
         if type_.length:
@@ -2169,31 +2334,31 @@
         else:
             return self._extend_string(type_, {'national': True}, "CHAR")
 
-    def visit_VARBINARY(self, type_):
+    def visit_VARBINARY(self, type_, **kw):
         return "VARBINARY(%d)" % type_.length
 
-    def visit_large_binary(self, type_):
+    def visit_large_binary(self, type_, **kw):
         return self.visit_BLOB(type_)
 
-    def visit_enum(self, type_):
+    def visit_enum(self, type_, **kw):
         if not type_.native_enum:
             return super(MySQLTypeCompiler, self).visit_enum(type_)
         else:
             return self._visit_enumerated_values("ENUM", type_, type_.enums)
 
-    def visit_BLOB(self, type_):
+    def visit_BLOB(self, type_, **kw):
         if type_.length:
             return "BLOB(%d)" % type_.length
         else:
             return "BLOB"
 
-    def visit_TINYBLOB(self, type_):
+    def visit_TINYBLOB(self, type_, **kw):
         return "TINYBLOB"
 
-    def visit_MEDIUMBLOB(self, type_):
+    def visit_MEDIUMBLOB(self, type_, **kw):
         return "MEDIUMBLOB"
 
-    def visit_LONGBLOB(self, type_):
+    def visit_LONGBLOB(self, type_, **kw):
         return "LONGBLOB"
 
     def _visit_enumerated_values(self, name, type_, enumerated_values):
@@ -2204,15 +2369,15 @@
             name, ",".join(quoted_enums))
         )
 
-    def visit_ENUM(self, type_):
+    def visit_ENUM(self, type_, **kw):
         return self._visit_enumerated_values("ENUM", type_,
                                              type_._enumerated_values)
 
-    def visit_SET(self, type_):
+    def visit_SET(self, type_, **kw):
         return self._visit_enumerated_values("SET", type_,
                                              type_._enumerated_values)
 
-    def visit_BOOLEAN(self, type):
+    def visit_BOOLEAN(self, type, **kw):
         return "BOOL"
 
@@ -2593,7 +2758,7 @@ class MySQLDialect(default.DefaultDialect):
                 pass
             else:
                 self.logger.info(
-                    "Converting unknown KEY type %s to a plain KEY" % flavor)
+                    "Converting unknown KEY type %s to a plain KEY", flavor)
                 pass
             index_d = {}
             index_d['name'] = spec['name']
@@ -2933,8 +3098,7 @@ class MySQLTableDefinitionParser(object):
         if not spec['full']:
             util.warn("Incomplete reflection of column definition %r" % line)
 
-        name, type_, args, notnull = \
-            spec['name'], spec['coltype'], spec['arg'], spec['notnull']
+        name, type_, args = spec['name'], spec['coltype'], spec['arg']
 
         try:
             col_type = self.dialect.ischema_names[type_]
@@ -2959,17 +3123,20 @@ class MySQLTableDefinitionParser(object):
         for kw in ('charset', 'collate'):
             if spec.get(kw, False):
                 type_kw[kw] = spec[kw]
-
         if issubclass(col_type, _EnumeratedValues):
             type_args = _EnumeratedValues._strip_values(type_args)
 
+            if issubclass(col_type, SET) and '' in type_args:
+                type_kw['retrieve_as_bitwise'] = True
+
         type_instance = col_type(*type_args, **type_kw)
 
-        col_args, col_kw = [], {}
+        col_kw = {}
 
         # NOT NULL
         col_kw['nullable'] = True
-        if spec.get('notnull', False):
+        # this can be "NULL" in the case of TIMESTAMP
+        if spec.get('notnull', False) == 'NOT NULL':
            col_kw['nullable'] = False
 
         # AUTO_INCREMENT
@@ -3088,7 +3255,7 @@ class MySQLTableDefinitionParser(object):
             r'(?: +(?P<zerofill>ZEROFILL))?'
             r'(?: +CHARACTER SET +(?P<charset>[\w_]+))?'
             r'(?: +COLLATE +(?P<collate>[\w_]+))?'
-            r'(?: +(?P<notnull>NOT NULL))?'
+            r'(?: +(?P<notnull>(?:NOT )?NULL))?'
             r'(?: +DEFAULT +(?P<default>'
             r'(?:NULL|\x27(?:\x27\x27|[^\x27])*\x27|\w+'
             r'(?: +ON UPDATE \w+)?)'
@@ -3108,7 +3275,7 @@ class MySQLTableDefinitionParser(object):
             r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +'
             r'(?P<coltype>\w+)'
             r'(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?'
-            r'.*?(?P<notnull>NOT NULL)?'
+            r'.*?(?P<notnull>(?:NOT )NULL)?'
             % quotes
         )
 
@@ -3215,9 +3382,17 @@ class _DecodingRowProxy(object):
     # sets.Set(['value'])   (seriously) but thankfully that doesn't
     # seem to come up in DDL queries.
 
+    _encoding_compat = {
+        'koi8r': 'koi8_r',
+        'koi8u': 'koi8_u',
+        'utf16': 'utf-16-be',  # MySQL's uft16 is always bigendian
+        'utf8mb4': 'utf8',  # real utf8
+        'eucjpms': 'ujis',
+    }
+
     def __init__(self, rowproxy, charset):
         self.rowproxy = rowproxy
-        self.charset = charset
+        self.charset = self._encoding_compat.get(charset, charset)
 
     def __getitem__(self, index):
         item = self.rowproxy[index]
diff --git a/lib/sqlalchemy/dialects/mysql/cymysql.py b/lib/sqlalchemy/dialects/mysql/cymysql.py
index 51b63044e..6d8466ab1 100644
--- a/lib/sqlalchemy/dialects/mysql/cymysql.py
+++ b/lib/sqlalchemy/dialects/mysql/cymysql.py
@@ -1,5 +1,5 @@
 # mysql/cymysql.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/dialects/mysql/gaerdbms.py b/lib/sqlalchemy/dialects/mysql/gaerdbms.py
index 0059f5a65..58b70737f 100644
--- a/lib/sqlalchemy/dialects/mysql/gaerdbms.py
+++ b/lib/sqlalchemy/dialects/mysql/gaerdbms.py
@@ -1,5 +1,5 @@
 # mysql/gaerdbms.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
@@ -17,6 +17,13 @@ developers-guide
 
 .. versionadded:: 0.7.8
 
+.. deprecated:: 1.0 This dialect is **no longer necessary** for
+    Google Cloud SQL; the MySQLdb dialect can be used directly.
+    Cloud SQL now recommends creating connections via the
+    mysql dialect using the URL format
+
+    ``mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename>``
+
 Pooling
 -------
 
@@ -33,6 +40,7 @@ import os
 from .mysqldb import MySQLDialect_mysqldb
 from ...pool import NullPool
 import re
+from sqlalchemy.util import warn_deprecated
 
 
 def _is_dev_environment():
@@ -43,6 +51,14 @@ class MySQLDialect_gaerdbms(MySQLDialect_mysqldb):
 
     @classmethod
     def dbapi(cls):
+
+        warn_deprecated(
+            "Google Cloud SQL now recommends creating connections via the "
+            "MySQLdb dialect directly, using the URL format "
+            "mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/"
+            "<projectid>:<instancename>"
+        )
+
         # from django:
         # http://code.google.com/p/googleappengine/source/
         #     browse/trunk/python/google/storage/speckle/
diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
index 417e1ad6f..3a4eeec05 100644
--- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
+++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
@@ -1,5 +1,5 @@
 # mysql/mysqlconnector.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
@@ -14,6 +14,12 @@
 
     :url: http://dev.mysql.com/downloads/connector/python/
 
+Unicode
+-------
+
+Please see :ref:`mysql_unicode` for current recommendations on unicode
+handling.
+
 """
 
 from .base import (MySQLDialect, MySQLExecutionContext,
diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py
index 73210d67a..4a7ba7e1d 100644
--- a/lib/sqlalchemy/dialects/mysql/mysqldb.py
+++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py
@@ -1,5 +1,5 @@
 # mysql/mysqldb.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
@@ -13,31 +13,30 @@
     :connectstring: mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
     :url: http://sourceforge.net/projects/mysql-python
 
+.. _mysqldb_unicode:
 
 Unicode
 -------
 
-MySQLdb requires a "charset" parameter to be passed in order for it
-to handle non-ASCII characters correctly. When this parameter is passed,
-MySQLdb will also implicitly set the "use_unicode" flag to true, which means
-that it will return Python unicode objects instead of bytestrings.
-However, SQLAlchemy's decode process, when C extensions are enabled,
-is orders of magnitude faster than that of MySQLdb as it does not call into
-Python functions to do so.  Therefore, the **recommended URL to use for
-unicode** will include both charset and use_unicode=0::
+Please see :ref:`mysql_unicode` for current recommendations on unicode
+handling.
 
-    create_engine("mysql+mysqldb://user:pass@host/dbname?charset=utf8&use_unicode=0")
+Py3K Support
+------------
 
-As of this writing, MySQLdb only runs on Python 2. It is not known how
-MySQLdb behaves on Python 3 as far as unicode decoding.
+Currently, MySQLdb only runs on Python 2 and development has been stopped.
+`mysqlclient`_ is fork of MySQLdb and provides Python 3 support as well
+as some bugfixes.
 
_mysqlclient: https://github.com/PyMySQL/mysqlclient-python -Known Issues -------------- +Using MySQLdb with Google Cloud SQL +----------------------------------- -MySQL-python version 1.2.2 has a serious memory leak related -to unicode conversion, a feature which is disabled via ``use_unicode=0``. -It is strongly advised to use the latest version of MySQL-Python. +Google Cloud SQL now recommends use of the MySQLdb dialect. Connect +using a URL like the following:: + + mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename> """ @@ -77,7 +76,7 @@ class MySQLIdentifierPreparer_mysqldb(MySQLIdentifierPreparer): class MySQLDialect_mysqldb(MySQLDialect): driver = 'mysqldb' - supports_unicode_statements = False + supports_unicode_statements = True supports_sane_rowcount = True supports_sane_multi_rowcount = True @@ -102,12 +101,13 @@ class MySQLDialect_mysqldb(MySQLDialect): # https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8 # specific issue w/ the utf8_bin collation and unicode returns - has_utf8_bin = connection.scalar( - "show collation where %s = 'utf8' and %s = 'utf8_bin'" - % ( - self.identifier_preparer.quote("Charset"), - self.identifier_preparer.quote("Collation") - )) + has_utf8_bin = self.server_version_info > (5, ) and \ + connection.scalar( + "show collation where %s = 'utf8' and %s = 'utf8_bin'" + % ( + self.identifier_preparer.quote("Charset"), + self.identifier_preparer.quote("Collation") + )) if has_utf8_bin: additional_tests = [ sql.collate(sql.cast( diff --git a/lib/sqlalchemy/dialects/mysql/oursql.py b/lib/sqlalchemy/dialects/mysql/oursql.py index fa127f3b0..ae8abc321 100644 --- a/lib/sqlalchemy/dialects/mysql/oursql.py +++ b/lib/sqlalchemy/dialects/mysql/oursql.py @@ -1,5 +1,5 @@ # mysql/oursql.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -16,22 +16,10 @@ Unicode ------- -oursql defaults to using ``utf8`` as the connection charset, but other -encodings may be used instead. Like the MySQL-Python driver, unicode support -can be completely disabled:: +Please see :ref:`mysql_unicode` for current recommendations on unicode +handling. 
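As a minimal illustrative sketch of the Google Cloud SQL URL format shown above (``my-project``, ``my-instance`` and ``mydb`` are placeholder names to substitute)::

    from sqlalchemy import create_engine

    # placeholders: use the actual Cloud SQL project, instance and database
    engine = create_engine(
        "mysql+mysqldb://root@/mydb"
        "?unix_socket=/cloudsql/my-project:my-instance"
    )

The engine is constructed lazily; no connection to the instance is attempted until it is first used.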
- # oursql sets the connection charset to utf8 automatically; all strings come - # back as utf8 str - create_engine('mysql+oursql:///mydb?use_unicode=0') -To not automatically use ``utf8`` and instead use whatever the connection -defaults to, there is a separate parameter:: - - # use the default connection charset; all strings come back as unicode - create_engine('mysql+oursql:///mydb?default_charset=1') - - # use latin1 as the connection charset; all strings come back as unicode - create_engine('mysql+oursql:///mydb?charset=latin1') """ import re diff --git a/lib/sqlalchemy/dialects/mysql/pymysql.py b/lib/sqlalchemy/dialects/mysql/pymysql.py index 31226cea0..87159b561 100644 --- a/lib/sqlalchemy/dialects/mysql/pymysql.py +++ b/lib/sqlalchemy/dialects/mysql/pymysql.py @@ -1,5 +1,5 @@ # mysql/pymysql.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -12,7 +12,13 @@ :dbapi: pymysql :connectstring: mysql+pymysql://<username>:<password>@<host>/<dbname>\ [?<options>] - :url: http://code.google.com/p/pymysql/ + :url: http://www.pymysql.org/ + +Unicode +------- + +Please see :ref:`mysql_unicode` for current recommendations on unicode +handling. MySQL-Python Compatibility -------------------------- @@ -31,8 +37,12 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): driver = 'pymysql' description_encoding = None - if py3k: - supports_unicode_statements = True + + # generally, these two values should be both True + # or both False. PyMySQL unicode tests pass all the way back + # to 0.4 either way. See [ticket:3337] + supports_unicode_statements = True + supports_unicode_binds = True @classmethod def dbapi(cls): diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py index 58e8b30fe..b544f0584 100644 --- a/lib/sqlalchemy/dialects/mysql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -1,5 +1,5 @@ # mysql/pyodbc.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -14,14 +14,11 @@ :connectstring: mysql+pyodbc://<username>:<password>@<dsnname> :url: http://pypi.python.org/pypi/pyodbc/ - -Limitations ------------ - -The mysql-pyodbc dialect is subject to unresolved character encoding issues -which exist within the current ODBC drivers available. -(see http://code.google.com/p/pyodbc/issues/detail?id=25). Consider usage -of OurSQL, MySQLdb, or MySQL-connector/Python. + .. note:: The PyODBC for MySQL dialect is not well supported, and + is subject to unresolved character encoding issues + which exist within the current ODBC drivers available. + (see http://code.google.com/p/pyodbc/issues/detail?id=25). + Other dialects for MySQL are recommended. """ diff --git a/lib/sqlalchemy/dialects/mysql/zxjdbc.py b/lib/sqlalchemy/dialects/mysql/zxjdbc.py index 0cf92cd13..37b0b6309 100644 --- a/lib/sqlalchemy/dialects/mysql/zxjdbc.py +++ b/lib/sqlalchemy/dialects/mysql/zxjdbc.py @@ -1,5 +1,5 @@ # mysql/zxjdbc.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -14,6 +14,9 @@ <database> :driverurl: http://dev.mysql.com/downloads/connector/j/ + .. 
note:: Jython is not supported by current versions of SQLAlchemy. The + zxjdbc dialect should be considered as experimental. + Character Sets -------------- diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py index fd32f2235..b055b0b16 100644 --- a/lib/sqlalchemy/dialects/oracle/__init__.py +++ b/lib/sqlalchemy/dialects/oracle/__init__.py @@ -1,5 +1,5 @@ # oracle/__init__.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 6df38e57e..c605bd510 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -1,5 +1,5 @@ # oracle/base.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -213,6 +213,8 @@ is reflected and the type is reported as ``DATE``, the time-supporting examining the type of column for use in special Python translations or for migrating schemas to other database backends. +.. _oracle_table_options: + Oracle Table Options ------------------------- @@ -228,15 +230,63 @@ in conjunction with the :class:`.Table` construct: .. versionadded:: 1.0.0 +* ``COMPRESS``:: + + Table('mytable', metadata, Column('data', String(32)), + oracle_compress=True) + + Table('mytable', metadata, Column('data', String(32)), + oracle_compress=6) + + The ``oracle_compress`` parameter accepts either an integer compression + level, or ``True`` to use the default compression level. + +.. versionadded:: 1.0.0 + +.. _oracle_index_options: + +Oracle Specific Index Options +----------------------------- + +Bitmap Indexes +~~~~~~~~~~~~~~ + +You can specify the ``oracle_bitmap`` parameter to create a bitmap index +instead of a B-tree index:: + + Index('my_index', my_table.c.data, oracle_bitmap=True) + +Bitmap indexes cannot be unique and cannot be compressed. SQLAlchemy does not +check for these limitations; only the database will enforce them. + +.. versionadded:: 1.0.0 + +Index Compression +~~~~~~~~~~~~~~~~~ + +Oracle has a more efficient storage mode for indexes containing lots of +repeated values. Use the ``oracle_compress`` parameter to turn on key +compression:: + + Index('my_index', my_table.c.data, oracle_compress=True) + + Index('my_index', my_table.c.data1, my_table.c.data2, unique=True, + oracle_compress=1) + +The ``oracle_compress`` parameter accepts either an integer specifying the +number of prefix columns to compress, or ``True`` to use the default (all +columns for non-unique indexes, all but the last column for unique indexes). + +.. 
versionadded:: 1.0.0 + """ import re from sqlalchemy import util, sql -from sqlalchemy.engine import default, base, reflection +from sqlalchemy.engine import default, reflection from sqlalchemy.sql import compiler, visitors, expression -from sqlalchemy.sql import (operators as sql_operators, - functions as sql_functions) +from sqlalchemy.sql import operators as sql_operators from sqlalchemy import types as sqltypes, schema as sa_schema from sqlalchemy.types import VARCHAR, NVARCHAR, CHAR, \ BLOB, CLOB, TIMESTAMP, FLOAT @@ -407,19 +457,19 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): # Oracle does not allow milliseconds in DATE # Oracle does not support TIME columns - def visit_datetime(self, type_): - return self.visit_DATE(type_) + def visit_datetime(self, type_, **kw): + return self.visit_DATE(type_, **kw) - def visit_float(self, type_): - return self.visit_FLOAT(type_) + def visit_float(self, type_, **kw): + return self.visit_FLOAT(type_, **kw) - def visit_unicode(self, type_): + def visit_unicode(self, type_, **kw): if self.dialect._supports_nchar: - return self.visit_NVARCHAR2(type_) + return self.visit_NVARCHAR2(type_, **kw) else: - return self.visit_VARCHAR2(type_) + return self.visit_VARCHAR2(type_, **kw) - def visit_INTERVAL(self, type_): + def visit_INTERVAL(self, type_, **kw): return "INTERVAL DAY%s TO SECOND%s" % ( type_.day_precision is not None and "(%d)" % type_.day_precision or @@ -429,22 +479,22 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): "", ) - def visit_LONG(self, type_): + def visit_LONG(self, type_, **kw): return "LONG" - def visit_TIMESTAMP(self, type_): + def visit_TIMESTAMP(self, type_, **kw): if type_.timezone: return "TIMESTAMP WITH TIME ZONE" else: return "TIMESTAMP" - def visit_DOUBLE_PRECISION(self, type_): - return self._generate_numeric(type_, "DOUBLE PRECISION") + def visit_DOUBLE_PRECISION(self, type_, **kw): + return self._generate_numeric(type_, "DOUBLE PRECISION", **kw) def visit_NUMBER(self, type_, **kw): return self._generate_numeric(type_, "NUMBER", **kw) - def _generate_numeric(self, type_, name, precision=None, scale=None): + def _generate_numeric(self, type_, name, precision=None, scale=None, **kw): if precision is None: precision = type_.precision @@ -460,17 +510,17 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): n = "%(name)s(%(precision)s, %(scale)s)" return n % {'name': name, 'precision': precision, 'scale': scale} - def visit_string(self, type_): - return self.visit_VARCHAR2(type_) + def visit_string(self, type_, **kw): + return self.visit_VARCHAR2(type_, **kw) - def visit_VARCHAR2(self, type_): + def visit_VARCHAR2(self, type_, **kw): return self._visit_varchar(type_, '', '2') - def visit_NVARCHAR2(self, type_): + def visit_NVARCHAR2(self, type_, **kw): return self._visit_varchar(type_, 'N', '2') visit_NVARCHAR = visit_NVARCHAR2 - def visit_VARCHAR(self, type_): + def visit_VARCHAR(self, type_, **kw): return self._visit_varchar(type_, '', '') def _visit_varchar(self, type_, n, num): @@ -483,31 +533,31 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): varchar = "%(n)sVARCHAR%(two)s(%(length)s)" return varchar % {'length': type_.length, 'two': num, 'n': n} - def visit_text(self, type_): - return self.visit_CLOB(type_) + def visit_text(self, type_, **kw): + return self.visit_CLOB(type_, **kw) - def visit_unicode_text(self, type_): + def visit_unicode_text(self, type_, **kw): if self.dialect._supports_nchar: - return self.visit_NCLOB(type_) + return self.visit_NCLOB(type_, **kw) else: - return 
self.visit_CLOB(type_) + return self.visit_CLOB(type_, **kw) - def visit_large_binary(self, type_): - return self.visit_BLOB(type_) + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_, **kw) - def visit_big_integer(self, type_): - return self.visit_NUMBER(type_, precision=19) + def visit_big_integer(self, type_, **kw): + return self.visit_NUMBER(type_, precision=19, **kw) - def visit_boolean(self, type_): - return self.visit_SMALLINT(type_) + def visit_boolean(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) - def visit_RAW(self, type_): + def visit_RAW(self, type_, **kw): if type_.length: return "RAW(%(length)s)" % {'length': type_.length} else: return "RAW" - def visit_ROWID(self, type_): + def visit_ROWID(self, type_, **kw): return "ROWID" @@ -549,6 +599,9 @@ class OracleCompiler(compiler.SQLCompiler): def visit_false(self, expr, **kw): return '0' + def get_cte_preamble(self, recursive): + return "WITH" + def get_select_hint_text(self, byfroms): return " ".join( "/*+ %s */" % text for table, text in byfroms.items() @@ -612,29 +665,17 @@ class OracleCompiler(compiler.SQLCompiler): else: return sql.and_(*clauses) - def visit_outer_join_column(self, vc): - return self.process(vc.column) + "(+)" + def visit_outer_join_column(self, vc, **kw): + return self.process(vc.column, **kw) + "(+)" def visit_sequence(self, seq): return (self.dialect.identifier_preparer.format_sequence(seq) + ".nextval") - def visit_alias(self, alias, asfrom=False, ashint=False, **kwargs): - """Oracle doesn't like ``FROM table AS alias``. Is the AS standard - SQL?? - """ - - if asfrom or ashint: - alias_name = isinstance(alias.name, expression._truncated_label) and \ - self._truncated_identifier("alias", alias.name) or alias.name + def get_render_as_alias_suffix(self, alias_name_text): + """Oracle doesn't like ``FROM table AS alias``""" - if ashint: - return alias_name - elif asfrom: - return self.process(alias.original, asfrom=asfrom, **kwargs) + \ - " " + self.preparer.format_alias(alias, alias_name) - else: - return self.process(alias.original, **kwargs) + return " " + alias_name_text def returning_clause(self, stmt, returning_cols): columns = [] @@ -651,8 +692,9 @@ class OracleCompiler(compiler.SQLCompiler): self.bindparam_string(self._truncate_bindparam(outparam))) columns.append( self.process(col_expr, within_columns_clause=False)) - self.result_map[outparam.key] = ( - outparam.key, + + self._add_to_result_map( + outparam.key, outparam.key, (column, getattr(column, 'name', None), getattr(column, 'key', None)), column.type @@ -695,7 +737,7 @@ class OracleCompiler(compiler.SQLCompiler): # Outer select and "ROWNUM as ora_rn" can be dropped if # limit=0 - # TODO: use annotations instead of clone + attr set ? 
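# As a rough sketch of the LIMIT/OFFSET emulation performed below (the
# "wrapped" label is illustrative only; SQLAlchemy generates anonymous
# names), a statement such as select([t]).limit(10).offset(5) is
# rendered approximately as:
#
#     SELECT * FROM (
#         SELECT wrapped.*, ROWNUM AS ora_rn
#         FROM (SELECT ... FROM t) wrapped
#         WHERE ROWNUM <= 10 + 5
#     ) WHERE ora_rn > 5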
+ kwargs['select_wraps_for'] = select select = select._generate() select._oracle_visit = True @@ -752,7 +794,6 @@ class OracleCompiler(compiler.SQLCompiler): offsetselect._for_update_arg = select._for_update_arg select = offsetselect - kwargs['iswrapper'] = getattr(select, '_is_wrapper', False) return compiler.SQLCompiler.visit_select(self, select, **kwargs) def limit_clause(self, select, **kw): @@ -795,9 +836,32 @@ class OracleDDLCompiler(compiler.DDLCompiler): return text - def visit_create_index(self, create, **kw): - return super(OracleDDLCompiler, self).\ - visit_create_index(create, include_schema=True) + def visit_create_index(self, create): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + if index.dialect_options['oracle']['bitmap']: + text += "BITMAP " + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=True), + preparer.format_table(index.table, use_schema=True), + ', '.join( + self.sql_compiler.process( + expr, + include_table=False, literal_binds=True) + for expr in index.expressions) + ) + if index.dialect_options['oracle']['compress'] is not False: + if index.dialect_options['oracle']['compress'] is True: + text += " COMPRESS" + else: + text += " COMPRESS %d" % ( + index.dialect_options['oracle']['compress'] + ) + return text def post_create_table(self, table): table_opts = [] @@ -807,6 +871,14 @@ class OracleDDLCompiler(compiler.DDLCompiler): on_commit_options = opts['on_commit'].replace("_", " ").upper() table_opts.append('\n ON COMMIT %s' % on_commit_options) + if opts['compress']: + if opts['compress'] is True: + table_opts.append("\n COMPRESS") + else: + table_opts.append("\n COMPRESS FOR %s" % ( + opts['compress'] + )) + return ''.join(table_opts) @@ -847,6 +919,8 @@ class OracleDialect(default.DefaultDialect): supports_sane_rowcount = True supports_sane_multi_rowcount = False + supports_simple_order_by_label = False + supports_sequences = True sequences_optional = False postfetch_lastrowid = False @@ -870,7 +944,12 @@ class OracleDialect(default.DefaultDialect): construct_arguments = [ (sa_schema.Table, { "resolve_synonyms": False, - "on_commit": None + "on_commit": None, + "compress": False + }), + (sa_schema.Index, { + "bitmap": False, + "compress": False }) ] @@ -902,6 +981,16 @@ class OracleDialect(default.DefaultDialect): self.server_version_info < (9, ) @property + def _supports_table_compression(self): + return self.server_version_info and \ + self.server_version_info >= (9, 2, ) + + @property + def _supports_table_compress_for(self): + return self.server_version_info and \ + self.server_version_info >= (11, ) + + @property def _supports_char_length(self): return not self._is_oracle_8 @@ -1084,6 +1173,50 @@ class OracleDialect(default.DefaultDialect): return [self.normalize_name(row[0]) for row in cursor] @reflection.cache + def get_table_options(self, connection, table_name, schema=None, **kw): + options = {} + + resolve_synonyms = kw.get('oracle_resolve_synonyms', False) + dblink = kw.get('dblink', '') + info_cache = kw.get('info_cache') + + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + + params = {"table_name": table_name} + + columns = ["table_name"] + if self._supports_table_compression: + columns.append("compression") + if self._supports_table_compress_for: + columns.append("compress_for") + + text = "SELECT 
%(columns)s "\ + "FROM ALL_TABLES%(dblink)s "\ + "WHERE table_name = :table_name" + + if schema is not None: + params['owner'] = schema + text += " AND owner = :owner " + text = text % {'dblink': dblink, 'columns': ", ".join(columns)} + + result = connection.execute(sql.text(text), **params) + + enabled = dict(DISABLED=False, ENABLED=True) + + row = result.first() + if row: + if "compression" in row and enabled.get(row.compression, False): + if "compress_for" in row: + options['oracle_compress'] = row.compress_for + else: + options['oracle_compress'] = True + + return options + + @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): """ @@ -1168,7 +1301,8 @@ class OracleDialect(default.DefaultDialect): params = {'table_name': table_name} text = \ - "SELECT a.index_name, a.column_name, b.uniqueness "\ + "SELECT a.index_name, a.column_name, "\ + "\nb.index_type, b.uniqueness, b.compression, b.prefix_length "\ "\nFROM ALL_IND_COLUMNS%(dblink)s a, "\ "\nALL_INDEXES%(dblink)s b "\ "\nWHERE "\ @@ -1194,6 +1328,7 @@ class OracleDialect(default.DefaultDialect): dblink=dblink, info_cache=kw.get('info_cache')) pkeys = pk_constraint['constrained_columns'] uniqueness = dict(NONUNIQUE=False, UNIQUE=True) + enabled = dict(DISABLED=False, ENABLED=True) oracle_sys_col = re.compile(r'SYS_NC\d+\$', re.IGNORECASE) @@ -1213,10 +1348,15 @@ class OracleDialect(default.DefaultDialect): if rset.index_name != last_index_name: remove_if_primary_key(index) index = dict(name=self.normalize_name(rset.index_name), - column_names=[]) + column_names=[], dialect_options={}) indexes.append(index) index['unique'] = uniqueness.get(rset.uniqueness, False) + if rset.index_type in ('BITMAP', 'FUNCTION-BASED BITMAP'): + index['dialect_options']['oracle_bitmap'] = True + if enabled.get(rset.compression, False): + index['dialect_options']['oracle_compress'] = rset.prefix_length + # filter out Oracle SYS_NC names. could also do an outer join # to the all_tab_columns table and check for real col names there. if not oracle_sys_col.match(rset.column_name): diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 4a1ceecb1..4aed45c14 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -1,5 +1,5 @@ # oracle/cx_oracle.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -61,6 +61,14 @@ on the URL, or as keyword arguments to :func:`.create_engine()` are: Defaults to ``True``. Note that this is the opposite default of the cx_Oracle DBAPI itself. +* ``service_name`` - An option to use connection string (DSN) with + ``SERVICE_NAME`` instead of ``SID``. It can't be passed when a ``database`` + part is given. + E.g. ``oracle+cx_oracle://scott:tiger@host:1521/?service_name=hr`` + is a valid url. This value is only available as a URL query string argument. + + .. versionadded:: 1.0.0 + .. 
_cx_oracle_unicode: Unicode @@ -862,14 +870,26 @@ class OracleDialect_cx_oracle(OracleDialect): util.coerce_kw_type(dialect_opts, opt, bool) setattr(self, opt, dialect_opts[opt]) - if url.database: + database = url.database + service_name = dialect_opts.get('service_name', None) + if database or service_name: # if we have a database, then we have a remote host port = url.port if port: port = int(port) else: port = 1521 - dsn = self.dbapi.makedsn(url.host, port, url.database) + + if database and service_name: + raise exc.InvalidRequestError( + '"service_name" option shouldn\'t ' + 'be used with a "database" part of the url') + if database: + makedsn_kwargs = {'sid': database} + if service_name: + makedsn_kwargs = {'service_name': service_name} + + dsn = self.dbapi.makedsn(url.host, port, **makedsn_kwargs) else: # we have a local tnsname dsn = url.host diff --git a/lib/sqlalchemy/dialects/oracle/zxjdbc.py b/lib/sqlalchemy/dialects/oracle/zxjdbc.py index 82c8e2f0f..ab1ade047 100644 --- a/lib/sqlalchemy/dialects/oracle/zxjdbc.py +++ b/lib/sqlalchemy/dialects/oracle/zxjdbc.py @@ -1,5 +1,5 @@ # oracle/zxjdbc.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -13,6 +13,9 @@ :driverurl: http://www.oracle.com/technology/software/tech/java/\ sqlj_jdbc/index.html. + .. note:: Jython is not supported by current versions of SQLAlchemy. The + zxjdbc dialect should be considered as experimental. + """ import decimal import re diff --git a/lib/sqlalchemy/dialects/postgres.py b/lib/sqlalchemy/dialects/postgres.py index f813e0003..3335333e5 100644 --- a/lib/sqlalchemy/dialects/postgres.py +++ b/lib/sqlalchemy/dialects/postgres.py @@ -1,5 +1,5 @@ # dialects/postgres.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 1cff8e3a0..98fe6f085 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -1,11 +1,11 @@ # postgresql/__init__.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from . import base, psycopg2, pg8000, pypostgresql, zxjdbc +from . 
import base, psycopg2, pg8000, pypostgresql, zxjdbc, psycopg2cffi base.dialect = psycopg2.dialect diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index baa640eaa..c1c0ab08e 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1,5 +1,5 @@ # postgresql/base.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -48,7 +48,7 @@ Transaction Isolation Level --------------------------- All Postgresql dialects support setting of transaction isolation level -both via a dialect-specific parameter ``isolation_level`` +both via a dialect-specific parameter :paramref:`.create_engine.isolation_level` accepted by :func:`.create_engine`, as well as the ``isolation_level`` argument as passed to :meth:`.Connection.execution_options`. When using a non-psycopg2 dialect, @@ -266,7 +266,7 @@ will emit to the database:: The Postgresql text search functions such as ``to_tsquery()`` and ``to_tsvector()`` are available -explicitly using the standard :attr:`.func` construct. For example:: +explicitly using the standard :data:`.func` construct. For example:: select([ func.to_tsvector('fat cats ate rats').match('cat & rat') @@ -299,7 +299,7 @@ not re-compute the column on demand. In order to provide for this explicit query planning, or to use different search strategies, the ``match`` method accepts a ``postgresql_regconfig`` -keyword argument. +keyword argument:: select([mytable.c.id]).where( mytable.c.title.match('somestring', postgresql_regconfig='english') @@ -311,7 +311,7 @@ Emits the equivalent of:: WHERE mytable.title @@ to_tsquery('english', 'somestring') One can also specifically pass in a `'regconfig'` value to the -``to_tsvector()`` command as the initial argument. +``to_tsvector()`` command as the initial argument:: select([mytable.c.id]).where( func.to_tsvector('english', mytable.c.title )\ @@ -402,6 +402,24 @@ underlying CREATE INDEX command, so it *must* be a valid index type for your version of PostgreSQL. +.. _postgresql_index_concurrently: + +Indexes with CONCURRENTLY +^^^^^^^^^^^^^^^^^^^^^^^^^ + +The Postgresql index option CONCURRENTLY is supported by passing the +flag ``postgresql_concurrently`` to the :class:`.Index` construct:: + + tbl = Table('testtbl', m, Column('data', Integer)) + + idx1 = Index('test_idx1', tbl.c.data, postgresql_concurrently=True) + +The above index construct will render SQL as:: + + CREATE INDEX CONCURRENTLY test_idx1 ON testtbl (data) + +.. versionadded:: 0.9.9 + .. _postgresql_index_reflection: Postgresql Index Reflection @@ -477,13 +495,29 @@ dialect in conjunction with the :class:`.Table` construct: `Postgresql CREATE TABLE options <http://www.postgresql.org/docs/9.3/static/sql-createtable.html>`_ +ENUM Types +---------- + +Postgresql has an independently creatable TYPE structure which is used +to implement an enumerated type. This approach introduces significant +complexity on the SQLAlchemy side in terms of when this type should be +CREATED and DROPPED. The type object is also an independently reflectable +entity. The following sections should be consulted: + +* :class:`.postgresql.ENUM` - DDL and typing support for ENUM. 
+ +* :meth:`.PGInspector.get_enums` - retrieve a listing of current ENUM types + +* :meth:`.postgresql.ENUM.create`, :meth:`.postgresql.ENUM.drop` - individual + CREATE and DROP commands for ENUM. + """ from collections import defaultdict import re from ... import sql, schema, exc, util from ...engine import default, reflection -from ...sql import compiler, expression, operators +from ...sql import compiler, expression, operators, default_comparator from ... import types as sqltypes try: @@ -680,10 +714,10 @@ class _Slice(expression.ColumnElement): type = sqltypes.NULLTYPE def __init__(self, slice_, source_comparator): - self.start = source_comparator._check_literal( + self.start = default_comparator._check_literal( source_comparator.expr, operators.getitem, slice_.start) - self.stop = source_comparator._check_literal( + self.stop = default_comparator._check_literal( source_comparator.expr, operators.getitem, slice_.stop) @@ -876,8 +910,9 @@ class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine): index += shift_indexes return_type = self.type.item_type - return self._binary_operate(self.expr, operators.getitem, index, - result_type=return_type) + return default_comparator._binary_operate( + self.expr, operators.getitem, index, + result_type=return_type) def any(self, other, operator=operators.eq): """Return ``other operator ANY (array)`` clause. @@ -1080,21 +1115,76 @@ class ENUM(sqltypes.Enum): """Postgresql ENUM type. This is a subclass of :class:`.types.Enum` which includes - support for PG's ``CREATE TYPE``. - - :class:`~.postgresql.ENUM` is used automatically when - using the :class:`.types.Enum` type on PG assuming - the ``native_enum`` is left as ``True``. However, the - :class:`~.postgresql.ENUM` class can also be instantiated - directly in order to access some additional Postgresql-specific - options, namely finer control over whether or not - ``CREATE TYPE`` should be emitted. - - Note that both :class:`.types.Enum` as well as - :class:`~.postgresql.ENUM` feature create/drop - methods; the base :class:`.types.Enum` type ultimately - delegates to the :meth:`~.postgresql.ENUM.create` and - :meth:`~.postgresql.ENUM.drop` methods present here. + support for PG's ``CREATE TYPE`` and ``DROP TYPE``. + + When the builtin type :class:`.types.Enum` is used and the + :paramref:`.Enum.native_enum` flag is left at its default of + True, the Postgresql backend will use a :class:`.postgresql.ENUM` + type as the implementation, so the special create/drop rules + will be used. + + The create/drop behavior of ENUM is necessarily intricate, due to the + awkward relationship the ENUM type has with the + parent table, in that it may be "owned" by just a single table, or + may be shared among many tables. 
+ + When using :class:`.types.Enum` or :class:`.postgresql.ENUM` + in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` are emitted + corresponding to when the :meth:`.Table.create` and :meth:`.Table.drop` + methods are called:: + + table = Table('sometable', metadata, + Column('some_enum', ENUM('a', 'b', 'c', name='myenum')) + ) + + table.create(engine) # will emit CREATE TYPE and CREATE TABLE + table.drop(engine) # will emit DROP TABLE and DROP TYPE + + To use a common enumerated type between multiple tables, the best + practice is to declare the :class:`.types.Enum` or + :class:`.postgresql.ENUM` independently, and associate it with the + :class:`.MetaData` object itself:: + + my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata) + + t1 = Table('sometable_one', metadata, + Column('some_enum', myenum) + ) + + t2 = Table('sometable_two', metadata, + Column('some_enum', myenum) + ) + + When this pattern is used, care must still be taken at the level + of individual table creates. Emitting CREATE TABLE without also + specifying ``checkfirst=True`` will still cause issues:: + + t1.create(engine) # will fail: no such type 'myenum' + + If we specify ``checkfirst=True``, the individual table-level create + operation will check for the ``ENUM`` and create if not exists:: + + # will check if enum exists, and emit CREATE TYPE if not + t1.create(engine, checkfirst=True) + + When using a metadata-level ENUM type, the type will always be created + and dropped whenever the metadata-wide create or drop is called:: + + metadata.create_all(engine) # will emit CREATE TYPE + metadata.drop_all(engine) # will emit DROP TYPE + + The type can also be created and dropped directly:: + + my_enum.create(engine) + my_enum.drop(engine) + + .. versionchanged:: 1.0.0 The Postgresql :class:`.postgresql.ENUM` type + now behaves more strictly with regard to CREATE/DROP. A metadata-level + ENUM type will only be created and dropped at the metadata level, + not the table level, with the exception of + ``table.create(checkfirst=True)``. + The ``table.drop()`` call will now emit a DROP TYPE for a table-level + enumerated type. 
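As a further sketch, enumerated types already present in the database can be listed using the :meth:`.PGInspector.get_enums` method referenced earlier in this module's documentation; the connection URL below is a placeholder::

    from sqlalchemy import create_engine, inspect

    engine = create_engine("postgresql://scott:tiger@localhost/test")
    insp = inspect(engine)
    for rec in insp.get_enums():
        print(rec['name'], rec['labels'])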
""" @@ -1200,9 +1290,18 @@ class ENUM(sqltypes.Enum): return False def _on_table_create(self, target, bind, checkfirst, **kw): - if not self._check_for_name_in_memos(checkfirst, kw): + if checkfirst or ( + not self.metadata and + not kw.get('_is_metadata_operation', False)) and \ + not self._check_for_name_in_memos(checkfirst, kw): self.create(bind=bind, checkfirst=checkfirst) + def _on_table_drop(self, target, bind, checkfirst, **kw): + if not self.metadata and \ + not kw.get('_is_metadata_operation', False) and \ + not self._check_for_name_in_memos(checkfirst, kw): + self.drop(bind=bind, checkfirst=checkfirst) + def _on_metadata_create(self, target, bind, checkfirst, **kw): if not self._check_for_name_in_memos(checkfirst, kw): self.create(bind=bind, checkfirst=checkfirst) @@ -1424,7 +1523,8 @@ class PGDDLCompiler(compiler.DDLCompiler): else: colspec += " SERIAL" else: - colspec += " " + self.dialect.type_compiler.process(column.type) + colspec += " " + self.dialect.type_compiler.process(column.type, + type_expression=column) default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -1457,7 +1557,13 @@ class PGDDLCompiler(compiler.DDLCompiler): text = "CREATE " if index.unique: text += "UNIQUE " - text += "INDEX %s ON %s " % ( + text += "INDEX " + + concurrently = index.dialect_options['postgresql']['concurrently'] + if concurrently: + text += "CONCURRENTLY " + + text += "%s ON %s " % ( self._prepared_index_name(index, include_schema=False), preparer.format_table(index.table) @@ -1476,8 +1582,13 @@ class PGDDLCompiler(compiler.DDLCompiler): if not isinstance(expr, expression.ColumnClause) else expr, include_table=False, literal_binds=True) + - (c.key in ops and (' ' + ops[c.key]) or '') - for expr, c in zip(index.expressions, index.columns)]) + ( + (' ' + ops[expr.key]) + if hasattr(expr, 'key') + and expr.key in ops else '' + ) + for expr in index.expressions + ]) ) whereclause = index.dialect_options["postgresql"]["where"] @@ -1539,94 +1650,93 @@ class PGDDLCompiler(compiler.DDLCompiler): class PGTypeCompiler(compiler.GenericTypeCompiler): - - def visit_TSVECTOR(self, type): + def visit_TSVECTOR(self, type, **kw): return "TSVECTOR" - def visit_INET(self, type_): + def visit_INET(self, type_, **kw): return "INET" - def visit_CIDR(self, type_): + def visit_CIDR(self, type_, **kw): return "CIDR" - def visit_MACADDR(self, type_): + def visit_MACADDR(self, type_, **kw): return "MACADDR" - def visit_OID(self, type_): + def visit_OID(self, type_, **kw): return "OID" - def visit_FLOAT(self, type_): + def visit_FLOAT(self, type_, **kw): if not type_.precision: return "FLOAT" else: return "FLOAT(%(precision)s)" % {'precision': type_.precision} - def visit_DOUBLE_PRECISION(self, type_): + def visit_DOUBLE_PRECISION(self, type_, **kw): return "DOUBLE PRECISION" - def visit_BIGINT(self, type_): + def visit_BIGINT(self, type_, **kw): return "BIGINT" - def visit_HSTORE(self, type_): + def visit_HSTORE(self, type_, **kw): return "HSTORE" - def visit_JSON(self, type_): + def visit_JSON(self, type_, **kw): return "JSON" - def visit_JSONB(self, type_): + def visit_JSONB(self, type_, **kw): return "JSONB" - def visit_INT4RANGE(self, type_): + def visit_INT4RANGE(self, type_, **kw): return "INT4RANGE" - def visit_INT8RANGE(self, type_): + def visit_INT8RANGE(self, type_, **kw): return "INT8RANGE" - def visit_NUMRANGE(self, type_): + def visit_NUMRANGE(self, type_, **kw): return "NUMRANGE" - def visit_DATERANGE(self, type_): + def visit_DATERANGE(self, 
type_, **kw): return "DATERANGE" - def visit_TSRANGE(self, type_): + def visit_TSRANGE(self, type_, **kw): return "TSRANGE" - def visit_TSTZRANGE(self, type_): + def visit_TSTZRANGE(self, type_, **kw): return "TSTZRANGE" - def visit_datetime(self, type_): - return self.visit_TIMESTAMP(type_) + def visit_datetime(self, type_, **kw): + return self.visit_TIMESTAMP(type_, **kw) - def visit_enum(self, type_): + def visit_enum(self, type_, **kw): if not type_.native_enum or not self.dialect.supports_native_enum: - return super(PGTypeCompiler, self).visit_enum(type_) + return super(PGTypeCompiler, self).visit_enum(type_, **kw) else: - return self.visit_ENUM(type_) + return self.visit_ENUM(type_, **kw) - def visit_ENUM(self, type_): + def visit_ENUM(self, type_, **kw): return self.dialect.identifier_preparer.format_type(type_) - def visit_TIMESTAMP(self, type_): + def visit_TIMESTAMP(self, type_, **kw): return "TIMESTAMP%s %s" % ( getattr(type_, 'precision', None) and "(%d)" % type_.precision or "", (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" ) - def visit_TIME(self, type_): + def visit_TIME(self, type_, **kw): return "TIME%s %s" % ( getattr(type_, 'precision', None) and "(%d)" % type_.precision or "", (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" ) - def visit_INTERVAL(self, type_): + def visit_INTERVAL(self, type_, **kw): if type_.precision is not None: return "INTERVAL(%d)" % type_.precision else: return "INTERVAL" - def visit_BIT(self, type_): + def visit_BIT(self, type_, **kw): if type_.varying: compiled = "BIT VARYING" if type_.length is not None: @@ -1635,16 +1745,16 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): compiled = "BIT(%d)" % type_.length return compiled - def visit_UUID(self, type_): + def visit_UUID(self, type_, **kw): return "UUID" - def visit_large_binary(self, type_): - return self.visit_BYTEA(type_) + def visit_large_binary(self, type_, **kw): + return self.visit_BYTEA(type_, **kw) - def visit_BYTEA(self, type_): + def visit_BYTEA(self, type_, **kw): return "BYTEA" - def visit_ARRAY(self, type_): + def visit_ARRAY(self, type_, **kw): return self.process(type_.item_type) + ('[]' * (type_.dimensions if type_.dimensions is not None else 1)) @@ -1806,7 +1916,8 @@ class PGDialect(default.DefaultDialect): (schema.Index, { "using": False, "where": None, - "ops": {} + "ops": {}, + "concurrently": False, }), (schema.Table, { "ignore_search_path": False, @@ -1942,7 +2053,8 @@ class PGDialect(default.DefaultDialect): cursor = connection.execute( sql.text( "select relname from pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where n.nspname=current_schema() " + "n.oid=c.relnamespace where " + "pg_catalog.pg_table_is_visible(c.oid) " "and relname=:name", bindparams=[ sql.bindparam('name', util.text_type(table_name), @@ -2489,37 +2601,59 @@ class PGDialect(default.DefaultDialect): # cast indkey as varchar since it's an int2vector, # returned as a list by some drivers such as pypostgresql - IDX_SQL = """ - SELECT - i.relname as relname, - ix.indisunique, ix.indexprs, ix.indpred, - a.attname, a.attnum, c.conrelid, ix.indkey%s - FROM - pg_class t - join pg_index ix on t.oid = ix.indrelid - join pg_class i on i.oid = ix.indexrelid - left outer join - pg_attribute a - on t.oid = a.attrelid and %s - left outer join - pg_constraint c - on (ix.indrelid = c.conrelid and - ix.indexrelid = c.conindid and - c.contype in ('p', 'u', 'x')) - WHERE - t.relkind IN ('r', 'v', 'f', 'm') - and t.oid = :table_oid - and ix.indisprimary = 'f' - ORDER BY - t.relname, - 
i.relname - """ % ( - # version 8.3 here was based on observing the - # cast does not work in PG 8.2.4, does work in 8.3.0. - # nothing in PG changelogs regarding this. - "::varchar" if self.server_version_info >= (8, 3) else "", - self._pg_index_any("a.attnum", "ix.indkey") - ) + if self.server_version_info < (8, 5): + IDX_SQL = """ + SELECT + i.relname as relname, + ix.indisunique, ix.indexprs, ix.indpred, + a.attname, a.attnum, NULL, ix.indkey%s + FROM + pg_class t + join pg_index ix on t.oid = ix.indrelid + join pg_class i on i.oid = ix.indexrelid + left outer join + pg_attribute a + on t.oid = a.attrelid and %s + WHERE + t.relkind IN ('r', 'v', 'f', 'm') + and t.oid = :table_oid + and ix.indisprimary = 'f' + ORDER BY + t.relname, + i.relname + """ % ( + # version 8.3 here was based on observing the + # cast does not work in PG 8.2.4, does work in 8.3.0. + # nothing in PG changelogs regarding this. + "::varchar" if self.server_version_info >= (8, 3) else "", + self._pg_index_any("a.attnum", "ix.indkey") + ) + else: + IDX_SQL = """ + SELECT + i.relname as relname, + ix.indisunique, ix.indexprs, ix.indpred, + a.attname, a.attnum, c.conrelid, ix.indkey::varchar + FROM + pg_class t + join pg_index ix on t.oid = ix.indrelid + join pg_class i on i.oid = ix.indexrelid + left outer join + pg_attribute a + on t.oid = a.attrelid and a.attnum = ANY(ix.indkey) + left outer join + pg_constraint c + on (ix.indrelid = c.conrelid and + ix.indexrelid = c.conindid and + c.contype in ('p', 'u', 'x')) + WHERE + t.relkind IN ('r', 'v', 'f', 'm') + and t.oid = :table_oid + and ix.indisprimary = 'f' + ORDER BY + t.relname, + i.relname + """ t = sql.text(IDX_SQL, typemap={'attname': sqltypes.Unicode}) c = connection.execute(t, table_oid=table_oid) diff --git a/lib/sqlalchemy/dialects/postgresql/constraints.py b/lib/sqlalchemy/dialects/postgresql/constraints.py index e8ebc75dd..0371daf3d 100644 --- a/lib/sqlalchemy/dialects/postgresql/constraints.py +++ b/lib/sqlalchemy/dialects/postgresql/constraints.py @@ -1,4 +1,4 @@ -# Copyright (C) 2013-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2013-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py index 9601edc41..9f369cb5b 100644 --- a/lib/sqlalchemy/dialects/postgresql/hstore.py +++ b/lib/sqlalchemy/dialects/postgresql/hstore.py @@ -1,5 +1,5 @@ # postgresql/hstore.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index 250bf5e9d..13ebc4afe 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -1,5 +1,5 @@ # postgresql/json.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -12,7 +12,7 @@ from .base import ischema_names from ... import types as sqltypes from ...sql.operators import custom_op from ... import sql -from ...sql import elements +from ...sql import elements, default_comparator from ... 
import util __all__ = ('JSON', 'JSONElement', 'JSONB') @@ -46,7 +46,8 @@ class JSONElement(elements.BinaryExpression): self._json_opstring = opstring operator = custom_op(opstring, precedence=5) - right = left._check_literal(left, operator, right) + right = default_comparator._check_literal( + left, operator, right) super(JSONElement, self).__init__( left, right, operator, type_=result_type) @@ -77,7 +78,7 @@ class JSONElement(elements.BinaryExpression): def cast(self, type_): """Convert this :class:`.JSONElement` to apply both the 'astext' operator - as well as an explicit type cast when evaulated. + as well as an explicit type cast when evaluated. E.g.:: diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index 4ccc90208..c71f689a3 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -1,5 +1,5 @@ # postgresql/pg8000.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors <see AUTHORS # file> # # This module is part of SQLAlchemy and is released under @@ -13,17 +13,30 @@ postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...] :url: https://pythonhosted.org/pg8000/ + +.. _pg8000_unicode: + Unicode ------- -When communicating with the server, pg8000 **always uses the server-side -character set**. SQLAlchemy has no ability to modify what character set -pg8000 chooses to use, and additionally SQLAlchemy does no unicode conversion -of any kind with the pg8000 backend. The origin of the client encoding setting -is ultimately the CLIENT_ENCODING setting in postgresql.conf. +pg8000 will encode / decode string values between it and the server using the +PostgreSQL ``client_encoding`` parameter; by default this is the value in +the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``. +Typically, this can be changed to ``utf-8``, as a more useful default:: + + #client_encoding = sql_ascii # actually, defaults to database + # encoding + client_encoding = utf8 + +The ``client_encoding`` can be overridden for a session by executing the SQL: + +SET CLIENT_ENCODING TO 'utf8'; + +SQLAlchemy will execute this SQL on all new connections based on the value +passed to :func:`.create_engine` using the ``client_encoding`` parameter:: -It is not necessary, though is also harmless, to pass the "encoding" parameter -to :func:`.create_engine` when using pg8000. + engine = create_engine( + "postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8') .. _pg8000_isolation_level: @@ -58,6 +71,8 @@ from ... 
import types as sqltypes from .base import ( PGDialect, PGCompiler, PGIdentifierPreparer, PGExecutionContext, _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES) +import re +from sqlalchemy.dialects.postgresql.json import JSON class _PGNumeric(sqltypes.Numeric): @@ -88,6 +103,15 @@ class _PGNumericNoBind(_PGNumeric): return None +class _PGJSON(JSON): + + def result_processor(self, dialect, coltype): + if dialect._dbapi_version > (1, 10, 1): + return None # Has native JSON + else: + return super(_PGJSON, self).result_processor(dialect, coltype) + + class PGExecutionContext_pg8000(PGExecutionContext): pass @@ -129,20 +153,29 @@ class PGDialect_pg8000(PGDialect): PGDialect.colspecs, { sqltypes.Numeric: _PGNumericNoBind, - sqltypes.Float: _PGNumeric + sqltypes.Float: _PGNumeric, + JSON: _PGJSON, } ) + def __init__(self, client_encoding=None, **kwargs): + PGDialect.__init__(self, **kwargs) + self.client_encoding = client_encoding + def initialize(self, connection): - if self.dbapi and hasattr(self.dbapi, '__version__'): - self._dbapi_version = tuple([ - int(x) for x in - self.dbapi.__version__.split(".")]) - else: - self._dbapi_version = (99, 99, 99) self.supports_sane_multi_rowcount = self._dbapi_version >= (1, 9, 14) super(PGDialect_pg8000, self).initialize(connection) + @util.memoized_property + def _dbapi_version(self): + if self.dbapi and hasattr(self.dbapi, '__version__'): + return tuple( + [ + int(x) for x in re.findall( + r'(\d+)(?:[-\.]?|$)', self.dbapi.__version__)]) + else: + return (99, 99, 99) + @classmethod def dbapi(cls): return __import__('pg8000') @@ -181,6 +214,16 @@ class PGDialect_pg8000(PGDialect): (level, self.name, ", ".join(self._isolation_lookup)) ) + def set_client_encoding(self, connection, client_encoding): + # adjust for ConnectionFairy possibly being present + if hasattr(connection, 'connection'): + connection = connection.connection + + cursor = connection.cursor() + cursor.execute("SET CLIENT_ENCODING TO '" + client_encoding + "'") + cursor.execute("COMMIT") + cursor.close() + def do_begin_twophase(self, connection, xid): connection.connection.tpc_begin((0, xid, '')) @@ -198,4 +241,24 @@ class PGDialect_pg8000(PGDialect): def do_recover_twophase(self, connection): return [row[1] for row in connection.connection.tpc_recover()] + def on_connect(self): + fns = [] + if self.client_encoding is not None: + def on_connect(conn): + self.set_client_encoding(conn, self.client_encoding) + fns.append(on_connect) + + if self.isolation_level is not None: + def on_connect(conn): + self.set_isolation_level(conn, self.isolation_level) + fns.append(on_connect) + + if len(fns) > 0: + def on_connect(conn): + for fn in fns: + fn(conn) + return on_connect + else: + return None + dialect = PGDialect_pg8000 diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 1a2a1ffe4..46228ac15 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -1,5 +1,5 @@ # postgresql/psycopg2.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -66,12 +66,13 @@ in ``/tmp``, or whatever socket directory was specified when PostgreSQL was built. 
This value can be overridden by passing a pathname to psycopg2, using ``host`` as an additional keyword argument:: - create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql") + create_engine("postgresql+psycopg2://user:password@/dbname?\ +host=/var/lib/postgresql") See also: -`PQconnectdbParams <http://www.postgresql.org/docs/9.1/static\ -/libpq-connect.html#LIBPQ-PQCONNECTDBPARAMS>`_ +`PQconnectdbParams <http://www.postgresql.org/docs/9.1/static/\ +libpq-connect.html#LIBPQ-PQCONNECTDBPARAMS>`_ Per-Statement/Connection Execution Options ------------------------------------------- @@ -159,6 +160,55 @@ defaults to ``utf-8``. SQLAlchemy's own unicode encode/decode functionality is steadily becoming obsolete as most DBAPIs now support unicode fully. +Bound Parameter Styles +---------------------- + +The default parameter style for the psycopg2 dialect is "pyformat", where +SQL is rendered using ``%(paramname)s`` style. This format has the limitation +that it does not accommodate the unusual case of parameter names that +actually contain percent or parenthesis symbols; as SQLAlchemy in many cases +generates bound parameter names based on the name of a column, the presence +of these characters in a column name can lead to problems. + +There are two solutions to the issue of a :class:`.schema.Column` that contains +one of these characters in its name. One is to specify the +:paramref:`.schema.Column.key` for columns that have such names:: + + measurement = Table('measurement', metadata, + Column('Size (meters)', Integer, key='size_meters') + ) + +Above, an INSERT statement such as ``measurement.insert()`` will use +``size_meters`` as the parameter name, and a SQL expression such as +``measurement.c.size_meters > 10`` will derive the bound parameter name +from the ``size_meters`` key as well. + +.. versionchanged:: 1.0.0 - SQL expressions will use :attr:`.Column.key` + as the source of naming when anonymous bound parameters are created + in SQL expressions; previously, this behavior only applied to + :meth:`.Table.insert` and :meth:`.Table.update` parameter names. + +The other solution is to use a positional format; psycopg2 allows use of the +"format" paramstyle, which can be passed to +:paramref:`.create_engine.paramstyle`:: + + engine = create_engine( + 'postgresql://scott:tiger@localhost:5432/test', paramstyle='format') + +With the above engine, instead of a statement like:: + + INSERT INTO measurement ("Size (meters)") VALUES (%(Size (meters))s) + {'Size (meters)': 1} + +we instead see:: + + INSERT INTO measurement ("Size (meters)") VALUES (%s) + (1, ) + +Where above, the dictionary style is converted into a tuple with positional +style. + + Transactions ------------ @@ -188,7 +238,7 @@ The psycopg2 dialect supports these constants for isolation level: * ``AUTOCOMMIT`` .. versionadded:: 0.8.2 support for AUTOCOMMIT isolation level when using - psycopg2. + psycopg2. .. seealso:: @@ -213,14 +263,17 @@ HSTORE type The ``psycopg2`` DBAPI includes an extension to natively handle marshalling of the HSTORE type. The SQLAlchemy psycopg2 dialect will enable this extension -by default when it is detected that the target database has the HSTORE -type set up for use. In other words, when the dialect makes the first +by default when psycopg2 version 2.4 or greater is used, and +it is detected that the target database has the HSTORE type set up for use. +In other words, when the dialect makes the first connection, a sequence like the following is performed: 1. 
Request the available HSTORE oids using ``psycopg2.extras.HstoreAdapter.get_oids()``. If this function returns a list of HSTORE identifiers, we then determine that the ``HSTORE`` extension is present. + This function is **skipped** if the version of psycopg2 installed is + less than version 2.4. 2. If the ``use_native_hstore`` flag is at its default of ``True``, and we've detected that ``HSTORE`` oids are available, the @@ -259,9 +312,14 @@ from ... import types as sqltypes from .base import PGDialect, PGCompiler, \ PGIdentifierPreparer, PGExecutionContext, \ ENUM, ARRAY, _DECIMAL_TYPES, _FLOAT_TYPES,\ - _INT_TYPES + _INT_TYPES, UUID from .hstore import HSTORE -from .json import JSON +from .json import JSON, JSONB + +try: + from uuid import UUID as _python_UUID +except ImportError: + _python_UUID = None logger = logging.getLogger('sqlalchemy.dialects.postgresql') @@ -326,6 +384,35 @@ class _PGJSON(JSON): else: return super(_PGJSON, self).result_processor(dialect, coltype) + +class _PGJSONB(JSONB): + + def result_processor(self, dialect, coltype): + if dialect._has_native_jsonb: + return None + else: + return super(_PGJSONB, self).result_processor(dialect, coltype) + + +class _PGUUID(UUID): + def bind_processor(self, dialect): + if not self.as_uuid and dialect.use_native_uuid: + nonetype = type(None) + + def process(value): + if value is not None: + value = _python_UUID(value) + return value + return process + + def result_processor(self, dialect, coltype): + if not self.as_uuid and dialect.use_native_uuid: + def process(value): + if value is not None: + value = str(value) + return value + return process + # When we're handed literal SQL, ensure it's a SELECT query. Since # 8.3, combining cursors and "FOR UPDATE" has been fine. SERVER_SIDE_CURSOR_RE = re.compile( @@ -416,6 +503,7 @@ class PGDialect_psycopg2(PGDialect): _has_native_hstore = False _has_native_json = False + _has_native_jsonb = False colspecs = util.update_copy( PGDialect.colspecs, @@ -424,18 +512,21 @@ class PGDialect_psycopg2(PGDialect): ENUM: _PGEnum, # needs force_unicode sqltypes.Enum: _PGEnum, # needs force_unicode HSTORE: _PGHStore, - JSON: _PGJSON + JSON: _PGJSON, + JSONB: _PGJSONB, + UUID: _PGUUID } ) def __init__(self, server_side_cursors=False, use_native_unicode=True, client_encoding=None, - use_native_hstore=True, + use_native_hstore=True, use_native_uuid=True, **kwargs): PGDialect.__init__(self, **kwargs) self.server_side_cursors = server_side_cursors self.use_native_unicode = use_native_unicode self.use_native_hstore = use_native_hstore + self.use_native_uuid = use_native_uuid self.supports_unicode_binds = use_native_unicode self.client_encoding = client_encoding if self.dbapi and hasattr(self.dbapi, '__version__'): @@ -453,6 +544,7 @@ class PGDialect_psycopg2(PGDialect): self._hstore_oids(connection.connection) \ is not None self._has_native_json = self.psycopg2_version >= (2, 5) + self._has_native_jsonb = self.psycopg2_version >= (2, 5, 4) # http://initd.org/psycopg/docs/news.html#what-s-new-in-psycopg-2-0-9 self.supports_sane_multi_rowcount = self.psycopg2_version >= (2, 0, 9) @@ -462,9 +554,19 @@ class PGDialect_psycopg2(PGDialect): import psycopg2 return psycopg2 + @classmethod + def _psycopg2_extensions(cls): + from psycopg2 import extensions + return extensions + + @classmethod + def _psycopg2_extras(cls): + from psycopg2 import extras + return extras + @util.memoized_property def _isolation_lookup(self): - from psycopg2 import extensions + extensions = self._psycopg2_extensions() return { 'AUTOCOMMIT': 
extensions.ISOLATION_LEVEL_AUTOCOMMIT, 'READ COMMITTED': extensions.ISOLATION_LEVEL_READ_COMMITTED, @@ -486,7 +588,8 @@ class PGDialect_psycopg2(PGDialect): connection.set_isolation_level(level) def on_connect(self): - from psycopg2 import extras, extensions + extras = self._psycopg2_extras() + extensions = self._psycopg2_extensions() fns = [] if self.client_encoding is not None: @@ -499,6 +602,11 @@ class PGDialect_psycopg2(PGDialect): self.set_isolation_level(conn, self.isolation_level) fns.append(on_connect) + if self.dbapi and self.use_native_uuid: + def on_connect(conn): + extras.register_uuid(None, conn) + fns.append(on_connect) + if self.dbapi and self.use_native_unicode: def on_connect(conn): extensions.register_type(extensions.UNICODE, conn) @@ -510,19 +618,22 @@ class PGDialect_psycopg2(PGDialect): hstore_oids = self._hstore_oids(conn) if hstore_oids is not None: oid, array_oid = hstore_oids + kw = {'oid': oid} if util.py2k: - extras.register_hstore(conn, oid=oid, - array_oid=array_oid, - unicode=True) - else: - extras.register_hstore(conn, oid=oid, - array_oid=array_oid) + kw['unicode'] = True + if self.psycopg2_version >= (2, 4, 3): + kw['array_oid'] = array_oid + extras.register_hstore(conn, **kw) fns.append(on_connect) if self.dbapi and self._json_deserializer: def on_connect(conn): - extras.register_default_json( - conn, loads=self._json_deserializer) + if self._has_native_json: + extras.register_default_json( + conn, loads=self._json_deserializer) + if self._has_native_jsonb: + extras.register_default_jsonb( + conn, loads=self._json_deserializer) fns.append(on_connect) if fns: @@ -536,7 +647,7 @@ class PGDialect_psycopg2(PGDialect): @util.memoized_instancemethod def _hstore_oids(self, conn): if self.psycopg2_version >= (2, 4): - from psycopg2 import extras + extras = self._psycopg2_extras() oids = extras.HstoreAdapter.get_oids(conn) if oids is not None and oids[0]: return oids[0:2] diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py new file mode 100644 index 000000000..f5c475d90 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py @@ -0,0 +1,49 @@ +# postgresql/psycopg2cffi.py +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +""" +.. dialect:: postgresql+psycopg2cffi + :name: psycopg2cffi + :dbapi: psycopg2cffi + :connectstring: \ +postgresql+psycopg2cffi://user:password@host:port/dbname\ +[?key=value&key=value...] + :url: http://pypi.python.org/pypi/psycopg2cffi/ + +``psycopg2cffi`` is an adaptation of ``psycopg2``, using CFFI for the C +layer. This makes it suitable for use in e.g. PyPy. Documentation +is as per ``psycopg2``. + +.. versionadded:: 1.0.0 + +.. 
seealso:: + + :mod:`sqlalchemy.dialects.postgresql.psycopg2` + +""" +from .psycopg2 import PGDialect_psycopg2 + + +class PGDialect_psycopg2cffi(PGDialect_psycopg2): + driver = 'psycopg2cffi' + supports_unicode_statements = True + + @classmethod + def dbapi(cls): + return __import__('psycopg2cffi') + + @classmethod + def _psycopg2_extensions(cls): + root = __import__('psycopg2cffi', fromlist=['extensions']) + return root.extensions + + @classmethod + def _psycopg2_extras(cls): + root = __import__('psycopg2cffi', fromlist=['extras']) + return root.extras + + +dialect = PGDialect_psycopg2cffi diff --git a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py index 3ebd0135f..00c67d170 100644 --- a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py +++ b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py @@ -1,5 +1,5 @@ # postgresql/pypostgresql.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 28f80d000..59c35c871 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -1,4 +1,4 @@ -# Copyright (C) 2013-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2013-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/zxjdbc.py b/lib/sqlalchemy/dialects/postgresql/zxjdbc.py index 00b428f84..1b542152c 100644 --- a/lib/sqlalchemy/dialects/postgresql/zxjdbc.py +++ b/lib/sqlalchemy/dialects/postgresql/zxjdbc.py @@ -1,5 +1,5 @@ # postgresql/zxjdbc.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py index 0eceaa537..608630a25 100644 --- a/lib/sqlalchemy/dialects/sqlite/__init__.py +++ b/lib/sqlalchemy/dialects/sqlite/__init__.py @@ -1,11 +1,11 @@ # sqlite/__init__.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy.dialects.sqlite import base, pysqlite +from sqlalchemy.dialects.sqlite import base, pysqlite, pysqlcipher # default dialect base.dialect = pysqlite.dialect diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 335b35c94..0254690b4 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1,5 +1,5 @@ # sqlite/base.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -9,6 +9,7 @@ .. dialect:: sqlite :name: SQLite +.. _sqlite_datetime: Date and Time Types ------------------- @@ -23,6 +24,20 @@ These types represent dates and times as ISO formatted strings, which also nicely support ordering. 
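For example, a minimal sketch of the ``storage_format`` and ``regexp`` parameters accepted by the SQLite variants of these types (the particular format shown here is only an illustration)::

    from sqlalchemy.dialects.sqlite import DATETIME

    # store values such as "2015/04/02 13:30:26"; the regexp is used
    # to parse the stored string back apart on the way out
    dt = DATETIME(
        storage_format="%(year)04d/%(month)02d/%(day)02d "
                       "%(hour)02d:%(minute)02d:%(second)02d",
        regexp=r"(\d+)/(\d+)/(\d+) (\d+):(\d+):(\d+)",
    )
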
There's no reliance on typical "libc" internals for these functions so historical dates are fully supported. +Ensuring Text affinity +^^^^^^^^^^^^^^^^^^^^^^ + +The DDL rendered for these types is the standard ``DATE``, ``TIME`` +and ``DATETIME`` indicators. However, custom storage formats can also be +applied to these types. When the +storage format is detected as containing no alpha characters, the DDL for +these types is rendered as ``DATE_CHAR``, ``TIME_CHAR``, and ``DATETIME_CHAR``, +so that the column continues to have textual affinity. + +.. seealso:: + + `Type Affinity <http://www.sqlite.org/datatype3.html#affinity>`_ - in the SQLite documentation + .. _sqlite_autoincrement: SQLite Auto Incrementing Behavior @@ -92,8 +107,10 @@ The following subsections introduce areas that are impacted by SQLite's file-based architecture and additionally will usually require workarounds to work when using the pysqlite driver. +.. _sqlite_isolation_level: + Transaction Isolation Level -=========================== +---------------------------- SQLite supports "transaction isolation" in a non-standard way, along two axes. One is that of the `PRAGMA read_uncommitted <http://www.sqlite.org/pragma.html#pragma_read_uncommitted>`_ @@ -126,7 +143,7 @@ by *not even emitting BEGIN* until the first write operation. for techniques to work around this behavior. SAVEPOINT Support -================= +---------------------------- SQLite supports SAVEPOINTs, which only function once a transaction is begun. SQLAlchemy's SAVEPOINT support is available using the @@ -142,7 +159,7 @@ won't work at all with pysqlite unless workarounds are taken. for techniques to work around this behavior. Transactional DDL -================= +---------------------------- The SQLite database supports transactional :term:`DDL` as well. In this case, the pysqlite driver is not only failing to start transactions, @@ -186,6 +203,15 @@ new connections through the usage of events:: cursor.execute("PRAGMA foreign_keys=ON") cursor.close() +.. warning:: + + When SQLite foreign keys are enabled, it is **not possible** + to emit CREATE or DROP statements for tables that contain + mutually-dependent foreign key constraints; + to emit the DDL for these tables requires that ALTER TABLE be used to + create or drop these constraints separately, for which SQLite has + no support. + .. seealso:: `SQLite Foreign Key Support <http://www.sqlite.org/foreignkeys.html>`_ @@ -193,6 +219,9 @@ new connections through the usage of events:: :ref:`event_toplevel` - SQLAlchemy event API. + :ref:`use_alter` - more information on SQLAlchemy's facilities for handling + mutually-dependent foreign key constraints. + .. _sqlite_type_reflection: Type Reflection @@ -243,6 +272,26 @@ lookup is used instead: .. versionadded:: 0.9.3 Support for SQLite type affinity rules when reflecting columns. + +.. _sqlite_partial_index: + +Partial Indexes +--------------- + +A partial index, e.g. one which uses a WHERE clause, can be specified +with the DDL system using the argument ``sqlite_where``:: + + tbl = Table('testtbl', m, Column('data', Integer)) + idx = Index('test_idx1', tbl.c.data, + sqlite_where=and_(tbl.c.data > 5, tbl.c.data < 10)) + +The index will be rendered at create time as:: + + CREATE INDEX test_idx1 ON testtbl (data) + WHERE data > 5 AND data < 10 + +.. versionadded:: 0.9.9 + """ import datetime @@ -255,7 +304,7 @@ from ... 
import util from ...engine import default, reflection from ...sql import compiler -from ...types import (BLOB, BOOLEAN, CHAR, DATE, DECIMAL, FLOAT, +from ...types import (BLOB, BOOLEAN, CHAR, DECIMAL, FLOAT, INTEGER, REAL, NUMERIC, SMALLINT, TEXT, TIMESTAMP, VARCHAR) @@ -271,6 +320,25 @@ class _DateTimeMixin(object): if storage_format is not None: self._storage_format = storage_format + @property + def format_is_text_affinity(self): + """return True if the storage format will automatically imply + a TEXT affinity. + + If the storage format contains no non-numeric characters, + it will imply a NUMERIC storage format on SQLite; in this case, + the type will generate its DDL as DATE_CHAR, DATETIME_CHAR, + TIME_CHAR. + + .. versionadded:: 1.0.0 + + """ + spec = self._storage_format % { + "year": 0, "month": 0, "day": 0, "hour": 0, + "minute": 0, "second": 0, "microsecond": 0 + } + return bool(re.search(r'[^0-9]', spec)) + def adapt(self, cls, **kw): if issubclass(cls, _DateTimeMixin): if self._storage_format: @@ -526,7 +594,9 @@ ischema_names = { 'BOOLEAN': sqltypes.BOOLEAN, 'CHAR': sqltypes.CHAR, 'DATE': sqltypes.DATE, + 'DATE_CHAR': sqltypes.DATE, 'DATETIME': sqltypes.DATETIME, + 'DATETIME_CHAR': sqltypes.DATETIME, 'DOUBLE': sqltypes.FLOAT, 'DECIMAL': sqltypes.DECIMAL, 'FLOAT': sqltypes.FLOAT, @@ -537,6 +607,7 @@ ischema_names = { 'SMALLINT': sqltypes.SMALLINT, 'TEXT': sqltypes.TEXT, 'TIME': sqltypes.TIME, + 'TIME_CHAR': sqltypes.TIME, 'TIMESTAMP': sqltypes.TIMESTAMP, 'VARCHAR': sqltypes.VARCHAR, 'NVARCHAR': sqltypes.NVARCHAR, @@ -611,7 +682,8 @@ class SQLiteCompiler(compiler.SQLCompiler): class SQLiteDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): - coltype = self.dialect.type_compiler.process(column.type) + coltype = self.dialect.type_compiler.process( + column.type, type_expression=column) colspec = self.preparer.format_column(column) + " " + coltype default = self.get_column_default_string(column) if default is not None: @@ -646,8 +718,8 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): def visit_foreign_key_constraint(self, constraint): - local_table = list(constraint._elements.values())[0].parent.table - remote_table = list(constraint._elements.values())[0].column.table + local_table = constraint.elements[0].parent.table + remote_table = constraint.elements[0].column.table if local_table.schema != remote_table.schema: return None @@ -662,14 +734,46 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return preparer.format_table(table, use_schema=False) def visit_create_index(self, create): - return super(SQLiteDDLCompiler, self).visit_create_index( + index = create.element + + text = super(SQLiteDDLCompiler, self).visit_create_index( create, include_table_schema=False) + whereclause = index.dialect_options["sqlite"]["where"] + if whereclause is not None: + where_compiled = self.sql_compiler.process( + whereclause, include_table=False, + literal_binds=True) + text += " WHERE " + where_compiled + + return text + class SQLiteTypeCompiler(compiler.GenericTypeCompiler): - def visit_large_binary(self, type_): + def visit_large_binary(self, type_, **kw): return self.visit_BLOB(type_) + def visit_DATETIME(self, type_, **kw): + if not isinstance(type_, _DateTimeMixin) or \ + type_.format_is_text_affinity: + return super(SQLiteTypeCompiler, self).visit_DATETIME(type_) + else: + return "DATETIME_CHAR" + + def visit_DATE(self, type_, **kw): + if not isinstance(type_, _DateTimeMixin) or \ + type_.format_is_text_affinity: + return super(SQLiteTypeCompiler, 
self).visit_DATE(type_) + else: + return "DATE_CHAR" + + def visit_TIME(self, type_, **kw): + if not isinstance(type_, _DateTimeMixin) or \ + type_.format_is_text_affinity: + return super(SQLiteTypeCompiler, self).visit_TIME(type_) + else: + return "TIME_CHAR" + class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = set([ @@ -750,7 +854,10 @@ class SQLiteDialect(default.DefaultDialect): construct_arguments = [ (sa_schema.Table, { "autoincrement": False - }) + }), + (sa_schema.Index, { + "where": None, + }), ] _broken_fk_pragma_quotes = False @@ -855,22 +962,9 @@ class SQLiteDialect(default.DefaultDialect): return [row[0] for row in rs] def has_table(self, connection, table_name, schema=None): - quote = self.identifier_preparer.quote_identifier - if schema is not None: - pragma = "PRAGMA %s." % quote(schema) - else: - pragma = "PRAGMA " - qtable = quote(table_name) - statement = "%stable_info(%s)" % (pragma, qtable) - cursor = _pragma_cursor(connection.execute(statement)) - row = cursor.fetchone() - - # consume remaining rows, to work around - # http://www.sqlite.org/cvstrac/tktview?tn=1884 - while not cursor.closed and cursor.fetchone() is not None: - pass - - return row is not None + info = self._get_table_pragma( + connection, "table_info", table_name, schema=schema) + return bool(info) @reflection.cache def get_view_names(self, connection, schema=None, **kw): @@ -912,18 +1006,11 @@ class SQLiteDialect(default.DefaultDialect): @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): - quote = self.identifier_preparer.quote_identifier - if schema is not None: - pragma = "PRAGMA %s." % quote(schema) - else: - pragma = "PRAGMA " - qtable = quote(table_name) - statement = "%stable_info(%s)" % (pragma, qtable) - c = _pragma_cursor(connection.execute(statement)) + info = self._get_table_pragma( + connection, "table_info", table_name, schema=schema) - rows = c.fetchall() columns = [] - for row in rows: + for row in info: (name, type_, nullable, default, primary_key) = ( row[1], row[2].upper(), not row[3], row[4], row[5]) @@ -1010,92 +1097,192 @@ class SQLiteDialect(default.DefaultDialect): @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): - quote = self.identifier_preparer.quote_identifier - if schema is not None: - pragma = "PRAGMA %s." % quote(schema) - else: - pragma = "PRAGMA " - qtable = quote(table_name) - statement = "%sforeign_key_list(%s)" % (pragma, qtable) - c = _pragma_cursor(connection.execute(statement)) - fkeys = [] + # sqlite makes this *extremely difficult*. + # First, use the pragma to get the actual FKs. 
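        # (illustrative note, not part of the original patch:) for a table
        # created as e.g. CREATE TABLE a (id INTEGER PRIMARY KEY, bid INTEGER,
        # FOREIGN KEY(bid) REFERENCES b(id)), "PRAGMA foreign_key_list(a)"
        # returns one row per constrained column, of the form
        # (id, seq, table, from, to, on_update, on_delete, match); rows
        # sharing the same numerical id in row[0] describe one (possibly
        # composite) constraint and are merged into a single record below.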
+ pragma_fks = self._get_table_pragma( + connection, "foreign_key_list", + table_name, schema=schema + ) + fks = {} - while True: - row = c.fetchone() - if row is None: - break + + for row in pragma_fks: (numerical_id, rtbl, lcol, rcol) = ( row[0], row[2], row[3], row[4]) - self._parse_fk(fks, fkeys, numerical_id, rtbl, lcol, rcol) - return fkeys + if rcol is None: + rcol = lcol - def _parse_fk(self, fks, fkeys, numerical_id, rtbl, lcol, rcol): - # sqlite won't return rcol if the table was created with REFERENCES - # <tablename>, no col - if rcol is None: - rcol = lcol + if self._broken_fk_pragma_quotes: + rtbl = re.sub(r'^[\"\[`\']|[\"\]`\']$', '', rtbl) - if self._broken_fk_pragma_quotes: - rtbl = re.sub(r'^[\"\[`\']|[\"\]`\']$', '', rtbl) + if numerical_id in fks: + fk = fks[numerical_id] + else: + fk = fks[numerical_id] = { + 'name': None, + 'constrained_columns': [], + 'referred_schema': None, + 'referred_table': rtbl, + 'referred_columns': [], + } + fks[numerical_id] = fk - try: - fk = fks[numerical_id] - except KeyError: - fk = { - 'name': None, - 'constrained_columns': [], - 'referred_schema': None, - 'referred_table': rtbl, - 'referred_columns': [], - } - fkeys.append(fk) - fks[numerical_id] = fk - - if lcol not in fk['constrained_columns']: fk['constrained_columns'].append(lcol) - if rcol not in fk['referred_columns']: fk['referred_columns'].append(rcol) - return fk + + def fk_sig(constrained_columns, referred_table, referred_columns): + return tuple(constrained_columns) + (referred_table,) + \ + tuple(referred_columns) + + # then, parse the actual SQL and attempt to find DDL that matches + # the names as well. SQLite saves the DDL in whatever format + # it was typed in as, so need to be liberal here. + + keys_by_signature = dict( + ( + fk_sig( + fk['constrained_columns'], + fk['referred_table'], fk['referred_columns']), + fk + ) for fk in fks.values() + ) + + table_data = self._get_table_sql(connection, table_name, schema=schema) + if table_data is None: + # system tables, etc. + return [] + + def parse_fks(): + FK_PATTERN = ( + '(?:CONSTRAINT (\w+) +)?' + 'FOREIGN KEY *\( *(.+?) 
*\) +' + 'REFERENCES +(?:(?:"(.+?)")|([a-z0-9_]+)) *\((.+?)\)' + ) + + for match in re.finditer(FK_PATTERN, table_data, re.I): + ( + constraint_name, constrained_columns, + referred_quoted_name, referred_name, + referred_columns) = match.group(1, 2, 3, 4, 5) + constrained_columns = list( + self._find_cols_in_sig(constrained_columns)) + if not referred_columns: + referred_columns = constrained_columns + else: + referred_columns = list( + self._find_cols_in_sig(referred_columns)) + referred_name = referred_quoted_name or referred_name + yield ( + constraint_name, constrained_columns, + referred_name, referred_columns) + fkeys = [] + + for ( + constraint_name, constrained_columns, + referred_name, referred_columns) in parse_fks(): + sig = fk_sig( + constrained_columns, referred_name, referred_columns) + if sig not in keys_by_signature: + util.warn( + "WARNING: SQL-parsed foreign key constraint " + "'%s' could not be located in PRAGMA " + "foreign_keys for table %s" % ( + sig, + table_name + )) + continue + key = keys_by_signature.pop(sig) + key['name'] = constraint_name + fkeys.append(key) + # assume the remainders are the unnamed, inline constraints, just + # use them as is as it's extremely difficult to parse inline + # constraints + fkeys.extend(keys_by_signature.values()) + return fkeys + + def _find_cols_in_sig(self, sig): + for match in re.finditer(r'(?:"(.+?)")|([a-z0-9_]+)', sig, re.I): + yield match.group(1) or match.group(2) + + @reflection.cache + def get_unique_constraints(self, connection, table_name, + schema=None, **kw): + + auto_index_by_sig = {} + for idx in self.get_indexes( + connection, table_name, schema=schema, + include_auto_indexes=True, **kw): + if not idx['name'].startswith("sqlite_autoindex"): + continue + sig = tuple(idx['column_names']) + auto_index_by_sig[sig] = idx + + table_data = self._get_table_sql( + connection, table_name, schema=schema, **kw) + if not table_data: + return [] + + unique_constraints = [] + + def parse_uqs(): + UNIQUE_PATTERN = '(?:CONSTRAINT (\w+) +)?UNIQUE *\((.+?)\)' + INLINE_UNIQUE_PATTERN = ( + '(?:(".+?")|([a-z0-9]+)) ' + '+[a-z0-9_ ]+? +UNIQUE') + + for match in re.finditer(UNIQUE_PATTERN, table_data, re.I): + name, cols = match.group(1, 2) + yield name, list(self._find_cols_in_sig(cols)) + + # we need to match inlines as well, as we seek to differentiate + # a UNIQUE constraint from a UNIQUE INDEX, even though these + # are kind of the same thing :) + for match in re.finditer(INLINE_UNIQUE_PATTERN, table_data, re.I): + cols = list( + self._find_cols_in_sig(match.group(1) or match.group(2))) + yield None, cols + + for name, cols in parse_uqs(): + sig = tuple(cols) + if sig in auto_index_by_sig: + auto_index_by_sig.pop(sig) + parsed_constraint = { + 'name': name, + 'column_names': cols + } + unique_constraints.append(parsed_constraint) + # NOTE: auto_index_by_sig might not be empty here, + # the PRIMARY KEY may have an entry. + return unique_constraints @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): - quote = self.identifier_preparer.quote_identifier - if schema is not None: - pragma = "PRAGMA %s." 
% quote(schema) - else: - pragma = "PRAGMA " - include_auto_indexes = kw.pop('include_auto_indexes', False) - qtable = quote(table_name) - statement = "%sindex_list(%s)" % (pragma, qtable) - c = _pragma_cursor(connection.execute(statement)) + pragma_indexes = self._get_table_pragma( + connection, "index_list", table_name, schema=schema) indexes = [] - while True: - row = c.fetchone() - if row is None: - break + + include_auto_indexes = kw.pop('include_auto_indexes', False) + for row in pragma_indexes: # ignore implicit primary key index. # http://www.mail-archive.com/sqlite-users@sqlite.org/msg30517.html - elif (not include_auto_indexes and - row[1].startswith('sqlite_autoindex')): + if (not include_auto_indexes and + row[1].startswith('sqlite_autoindex')): continue indexes.append(dict(name=row[1], column_names=[], unique=row[2])) + # loop thru unique indexes to get the column names. for idx in indexes: - statement = "%sindex_info(%s)" % (pragma, quote(idx['name'])) - c = connection.execute(statement) - cols = idx['column_names'] - while True: - row = c.fetchone() - if row is None: - break - cols.append(row[2]) + pragma_index = self._get_table_pragma( + connection, "index_info", idx['name']) + + for row in pragma_index: + idx['column_names'].append(row[2]) return indexes @reflection.cache - def get_unique_constraints(self, connection, table_name, - schema=None, **kw): + def _get_table_sql(self, connection, table_name, schema=None, **kw): try: s = ("SELECT sql FROM " " (SELECT * FROM sqlite_master UNION ALL " @@ -1107,27 +1294,22 @@ class SQLiteDialect(default.DefaultDialect): s = ("SELECT sql FROM sqlite_master WHERE name = '%s' " "AND type = 'table'") % table_name rs = connection.execute(s) - row = rs.fetchone() - if row is None: - # sqlite won't return the schema for the sqlite_master or - # sqlite_temp_master tables from this query. These tables - # don't have any unique constraints anyway. - return [] - table_data = row[0] - - UNIQUE_PATTERN = 'CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)' - return [ - {'name': name, - 'column_names': [col.strip(' "') for col in cols.split(',')]} - for name, cols in re.findall(UNIQUE_PATTERN, table_data) - ] - - -def _pragma_cursor(cursor): - """work around SQLite issue whereby cursor.description - is blank when PRAGMA returns no rows.""" + return rs.scalar() - if cursor.closed: - cursor.fetchone = lambda: None - cursor.fetchall = lambda: [] - return cursor + def _get_table_pragma(self, connection, pragma, table_name, schema=None): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + statement = "PRAGMA %s." % quote(schema) + else: + statement = "PRAGMA " + qtable = quote(table_name) + statement = "%s%s(%s)" % (statement, pragma, qtable) + cursor = connection.execute(statement) + if not cursor._soft_closed: + # work around SQLite issue whereby cursor.description + # is blank when PRAGMA returns no rows: + # http://www.sqlite.org/cvstrac/tktview?tn=1884 + result = cursor.fetchall() + else: + result = [] + return result diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py new file mode 100644 index 000000000..9166e36bc --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py @@ -0,0 +1,116 @@ +# sqlite/pysqlcipher.py +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. 
dialect:: sqlite+pysqlcipher :name: pysqlcipher :dbapi: pysqlcipher :connectstring: sqlite+pysqlcipher://:passphrase/file_path[?kdf_iter=<iter>] :url: https://pypi.python.org/pypi/pysqlcipher + + ``pysqlcipher`` is a fork of the standard ``pysqlite`` driver to make + use of the `SQLCipher <https://www.zetetic.net/sqlcipher>`_ backend. + + .. versionadded:: 0.9.9 + +Driver +------ + +The driver here is the `pysqlcipher <https://pypi.python.org/pypi/pysqlcipher>`_ +driver, which makes use of the SQLCipher engine. This system essentially +introduces new PRAGMA commands to SQLite which allow the setting of a +passphrase and other encryption parameters, allowing the database +file to be encrypted. + +Connect Strings +--------------- + +The format of the connect string is in every way the same as that +of the :mod:`~sqlalchemy.dialects.sqlite.pysqlite` driver, except that the +"password" field is now accepted, which should contain a passphrase:: + + e = create_engine('sqlite+pysqlcipher://:testing@/foo.db') + +For an absolute file path, two leading slashes should be used for the +database name:: + + e = create_engine('sqlite+pysqlcipher://:testing@//path/to/foo.db') + +A selection of additional encryption-related pragmas supported by SQLCipher +as documented at https://www.zetetic.net/sqlcipher/sqlcipher-api/ can be passed +in the query string, and will result in that PRAGMA being called for each +new connection. Currently, ``cipher``, ``kdf_iter``, +``cipher_page_size`` and ``cipher_use_hmac`` are supported:: + + e = create_engine('sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000') + + +Pooling Behavior +---------------- + +The driver makes a change to the default pool behavior of pysqlite +as described in :ref:`pysqlite_threading_pooling`. The pysqlcipher driver +has been observed to be significantly slower on connection than the +pysqlite driver, most likely due to the encryption overhead, so the +dialect here defaults to using the :class:`.SingletonThreadPool` +implementation, +instead of the :class:`.NullPool` pool used by pysqlite. As always, the pool +implementation is entirely configurable using the +:paramref:`.create_engine.poolclass` parameter; the :class:`.StaticPool` may +be more feasible for single-threaded use, or :class:`.NullPool` may be used +to prevent unencrypted connections from being held open for long periods of +time, at the expense of slower startup time for new connections. + + +""" +from __future__ import absolute_import +from .pysqlite import SQLiteDialect_pysqlite +from ...engine import url as _url +from ...
import pool + + +class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite): + driver = 'pysqlcipher' + + pragmas = ('kdf_iter', 'cipher', 'cipher_page_size', 'cipher_use_hmac') + + @classmethod + def dbapi(cls): + from pysqlcipher import dbapi2 as sqlcipher + return sqlcipher + + @classmethod + def get_pool_class(cls, url): + return pool.SingletonThreadPool + + def connect(self, *cargs, **cparams): + passphrase = cparams.pop('passphrase', '') + + pragmas = dict( + (key, cparams.pop(key, None)) for key in + self.pragmas + ) + + conn = super(SQLiteDialect_pysqlcipher, self).\ + connect(*cargs, **cparams) + conn.execute('pragma key="%s"' % passphrase) + for prag, value in pragmas.items(): + if value is not None: + conn.execute('pragma %s=%s' % (prag, value)) + + return conn + + def create_connect_args(self, url): + super_url = _url.URL( + url.drivername, username=url.username, + host=url.host, database=url.database, query=url.query) + c_args, opts = super(SQLiteDialect_pysqlcipher, self).\ + create_connect_args(super_url) + opts['passphrase'] = url.password + return c_args, opts + +dialect = SQLiteDialect_pysqlcipher diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index 62c19d145..e1c443477 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -1,5 +1,5 @@ # sqlite/pysqlite.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/sybase/__init__.py b/lib/sqlalchemy/dialects/sybase/__init__.py index eb313592b..0c55de1d6 100644 --- a/lib/sqlalchemy/dialects/sybase/__init__.py +++ b/lib/sqlalchemy/dialects/sybase/__init__.py @@ -1,5 +1,5 @@ # sybase/__init__.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index f65a76a27..57213382e 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -1,5 +1,5 @@ # sybase/base.py -# Copyright (C) 2010-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2010-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # get_select_precolumns(), limit_clause() implementation # copyright (C) 2007 Fisch Asset Management @@ -146,40 +146,40 @@ class IMAGE(sqltypes.LargeBinary): class SybaseTypeCompiler(compiler.GenericTypeCompiler): - def visit_large_binary(self, type_): + def visit_large_binary(self, type_, **kw): return self.visit_IMAGE(type_) - def visit_boolean(self, type_): + def visit_boolean(self, type_, **kw): return self.visit_BIT(type_) - def visit_unicode(self, type_): + def visit_unicode(self, type_, **kw): return self.visit_NVARCHAR(type_) - def visit_UNICHAR(self, type_): + def visit_UNICHAR(self, type_, **kw): return "UNICHAR(%d)" % type_.length - def visit_UNIVARCHAR(self, type_): + def visit_UNIVARCHAR(self, type_, **kw): return "UNIVARCHAR(%d)" % type_.length - def visit_UNITEXT(self, type_): + def visit_UNITEXT(self, type_, **kw): return "UNITEXT" - def visit_TINYINT(self, type_): + def visit_TINYINT(self, type_, **kw): return "TINYINT" - def visit_IMAGE(self, type_): + def visit_IMAGE(self, type_, **kw): return "IMAGE" - def 
visit_BIT(self, type_): + def visit_BIT(self, type_, **kw): return "BIT" - def visit_MONEY(self, type_): + def visit_MONEY(self, type_, **kw): return "MONEY" - def visit_SMALLMONEY(self, type_): + def visit_SMALLMONEY(self, type_, **kw): return "SMALLMONEY" - def visit_UNIQUEIDENTIFIER(self, type_): + def visit_UNIQUEIDENTIFIER(self, type_, **kw): return "UNIQUEIDENTIFIER" ischema_names = { @@ -377,7 +377,8 @@ class SybaseSQLCompiler(compiler.SQLCompiler): class SybaseDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + \ - self.dialect.type_compiler.process(column.type) + self.dialect.type_compiler.process( + column.type, type_expression=column) if column.table is None: raise exc.CompileError( @@ -434,6 +435,7 @@ class SybaseDialect(default.DefaultDialect): supports_native_boolean = False supports_unicode_binds = False postfetch_lastrowid = True + supports_simple_order_by_label = False colspecs = {} ischema_names = ischema_names diff --git a/lib/sqlalchemy/dialects/sybase/mxodbc.py b/lib/sqlalchemy/dialects/sybase/mxodbc.py index 373bea05d..240b634d4 100644 --- a/lib/sqlalchemy/dialects/sybase/mxodbc.py +++ b/lib/sqlalchemy/dialects/sybase/mxodbc.py @@ -1,5 +1,5 @@ # sybase/mxodbc.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py index cb76d1379..168997074 100644 --- a/lib/sqlalchemy/dialects/sybase/pyodbc.py +++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py @@ -1,5 +1,5 @@ # sybase/pyodbc.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/sybase/pysybase.py b/lib/sqlalchemy/dialects/sybase/pysybase.py index 6843eb480..a30739444 100644 --- a/lib/sqlalchemy/dialects/sybase/pysybase.py +++ b/lib/sqlalchemy/dialects/sybase/pysybase.py @@ -1,5 +1,5 @@ # sybase/pysybase.py -# Copyright (C) 2010-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2010-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index 68145f5cd..0678dd201 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -1,5 +1,5 @@ # engine/__init__.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -72,6 +72,7 @@ from .base import ( ) from .result import ( + BaseRowProxy, BufferedColumnResultProxy, BufferedColumnRow, BufferedRowResultProxy, @@ -256,14 +257,26 @@ def create_engine(*args, **kwargs): Behavior here varies per backend, and individual dialects should be consulted directly. + Note that the isolation level can also be set on a per-:class:`.Connection` + basis as well, using the + :paramref:`.Connection.execution_options.isolation_level` + feature. + .. 
seealso:: - :ref:`SQLite Concurrency <sqlite_concurrency>` + :attr:`.Connection.default_isolation_level` - view default level + + :paramref:`.Connection.execution_options.isolation_level` + - set per :class:`.Connection` isolation level + + :ref:`SQLite Transaction Isolation <sqlite_isolation_level>` :ref:`Postgresql Transaction Isolation <postgresql_isolation_level>` :ref:`MySQL Transaction Isolation <mysql_isolation_level>` + :ref:`session_transaction_isolation` - for the ORM + :param label_length=None: optional integer value which limits the size of dynamically generated column labels to that many characters. If less than 6, labels are generated as @@ -292,6 +305,17 @@ be used instead. Can be used for testing of DBAPIs as well as to inject "mock" DBAPI implementations into the :class:`.Engine`. + :param paramstyle=None: The `paramstyle <http://legacy.python.org/dev/peps/pep-0249/#paramstyle>`_ + to use when rendering bound parameters. This style defaults to the + one recommended by the DBAPI itself, which is retrieved from the + ``.paramstyle`` attribute of the DBAPI. However, most DBAPIs accept + more than one paramstyle, and in particular it may be desirable + to change a "named" paramstyle into a "positional" one, or vice versa. + When this attribute is passed, it should be one of the values + ``"qmark"``, ``"numeric"``, ``"named"``, ``"format"`` or + ``"pyformat"``, and should correspond to a parameter style known + to be supported by the DBAPI in use. + :param pool=None: an already-constructed instance of :class:`~sqlalchemy.pool.Pool`, such as a :class:`~sqlalchemy.pool.QueuePool` instance. If non-None, this diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index dd82be1d1..5921ab9ba 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1,5 +1,5 @@ # engine/base.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -201,14 +201,19 @@ class Connection(Connectable): used by the ORM internally supersedes a cache dictionary specified here. - :param isolation_level: Available on: Connection. + :param isolation_level: Available on: :class:`.Connection`. Set the transaction isolation level for - the lifespan of this connection. Valid values include - those string values accepted by the ``isolation_level`` - parameter passed to :func:`.create_engine`, and are - database specific, including those for :ref:`sqlite_toplevel`, - :ref:`postgresql_toplevel` - see those dialect's documentation - for further info. + the lifespan of this :class:`.Connection` object (*not* the + underlying DBAPI connection, for which the level is reset + to its original setting upon termination of this + :class:`.Connection` object). + + Valid values include + those string values accepted by the + :paramref:`.create_engine.isolation_level` + parameter passed to :func:`.create_engine`. These levels are + semi-database specific; see individual dialect documentation for + valid levels. Note that this option necessarily affects the underlying DBAPI connection for the lifespan of the originating @@ -217,6 +222,41 @@ is returned to the connection pool, i.e. the :meth:`.Connection.close` method is called. + ..
warning:: The ``isolation_level`` execution option should + **not** be used when a transaction is already established, that + is, the :meth:`.Connection.begin` method or similar has been + called. A database cannot change the isolation level on a + transaction in progress, and different DBAPIs and/or + SQLAlchemy dialects may implicitly roll back or commit + the transaction, or not affect the connection at all. + + .. versionchanged:: 0.9.9 A warning is emitted when the + ``isolation_level`` execution option is used after a + transaction has been started with :meth:`.Connection.begin` + or similar. + + .. note:: The ``isolation_level`` execution option is implicitly + reset if the :class:`.Connection` is invalidated, e.g. via + the :meth:`.Connection.invalidate` method, or if a + disconnection error occurs. The new connection produced after + the invalidation will not have the isolation level re-applied + to it automatically. + + .. seealso:: + + :paramref:`.create_engine.isolation_level` + - set per :class:`.Engine` isolation level + + :meth:`.Connection.get_isolation_level` - view current level + + :ref:`SQLite Transaction Isolation <sqlite_isolation_level>` + + :ref:`Postgresql Transaction Isolation <postgresql_isolation_level>` + + :ref:`MySQL Transaction Isolation <mysql_isolation_level>` + + :ref:`session_transaction_isolation` - for the ORM + :param no_parameters: When ``True``, if the final parameter list or dictionary is totally empty, will invoke the statement on the cursor as ``cursor.execute(statement)``, @@ -260,23 +300,97 @@ class Connection(Connectable): @property def connection(self): - "The underlying DB-API connection managed by this Connection." + """The underlying DB-API connection managed by this Connection. + + .. seealso:: + + + :ref:`dbapi_connections` + + """ try: return self.__connection except AttributeError: - return self._revalidate_connection() + try: + return self._revalidate_connection() + except Exception as e: + self._handle_dbapi_exception(e, None, None, None, None) + + def get_isolation_level(self): + """Return the current isolation level assigned to this + :class:`.Connection`. + + This will typically be the default isolation level as determined + by the dialect, unless the + :paramref:`.Connection.execution_options.isolation_level` + feature has been used to alter the isolation level on a + per-:class:`.Connection` basis. + + This attribute will typically perform a live SQL operation in order + to procure the current isolation level, so the value returned is the + actual level on the underlying DBAPI connection regardless of how + this state was set. Compare to the + :attr:`.Connection.default_isolation_level` accessor + which returns the dialect-level setting without performing a SQL + query. + + .. versionadded:: 0.9.9 + + .. seealso:: + + :attr:`.Connection.default_isolation_level` - view default level + + :paramref:`.create_engine.isolation_level` + - set per :class:`.Engine` isolation level + + :paramref:`.Connection.execution_options.isolation_level` + - set per :class:`.Connection` isolation level + + """ + try: + return self.dialect.get_isolation_level(self.connection) + except Exception as e: + self._handle_dbapi_exception(e, None, None, None, None) + + @property + def default_isolation_level(self): + """The default isolation level assigned to this :class:`.Connection`. + + This is the isolation level setting that the :class:`.Connection` + has when first procured via the :meth:`.Engine.connect` method.
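        For example, an illustrative sketch contrasting the two accessors
        (``engine`` stands for any :class:`.Engine` here)::

            conn = engine.connect()

            # no SQL emitted; reports the dialect-level default captured
            # when the dialect first connected
            conn.default_isolation_level

            # emits a query against the underlying DBAPI connection and
            # reports the live setting
            conn.get_isolation_level()
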
+ This level stays in place until the + :paramref:`.Connection.execution_options.isolation_level` is used + to change the setting on a per-:class:`.Connection` basis. + + Unlike :meth:`.Connection.get_isolation_level`, this attribute is set + ahead of time from the first connection procured by the dialect, + so SQL query is not invoked when this accessor is called. + + .. versionadded:: 0.9.9 + + .. seealso:: + + :meth:`.Connection.get_isolation_level` - view current level + + :paramref:`.create_engine.isolation_level` + - set per :class:`.Engine` isolation level + + :paramref:`.Connection.execution_options.isolation_level` + - set per :class:`.Connection` isolation level + + """ + return self.dialect.default_isolation_level def _revalidate_connection(self): if self.__branch_from: return self.__branch_from._revalidate_connection() - if self.__can_reconnect and self.__invalid: if self.__transaction is not None: raise exc.InvalidRequestError( "Can't reconnect until invalid " "transaction is rolled back") - self.__connection = self.engine.raw_connection() + self.__connection = self.engine.raw_connection(_connection=self) self.__invalid = False return self.__connection raise exc.ResourceClosedError("This Connection is closed") @@ -741,7 +855,7 @@ class Connection(Connectable): a subclass of :class:`.Executable`, such as a :func:`~.expression.select` construct * a :class:`.FunctionElement`, such as that generated - by :attr:`.func`, will be automatically wrapped in + by :data:`.func`, will be automatically wrapped in a SELECT statement, which is then executed. * a :class:`.DDLElement` object * a :class:`.DefaultGenerator` object @@ -877,9 +991,8 @@ class Connection(Connectable): dialect = self.dialect if 'compiled_cache' in self._execution_options: key = dialect, elem, tuple(sorted(keys)), len(distilled_params) > 1 - if key in self._execution_options['compiled_cache']: - compiled_sql = self._execution_options['compiled_cache'][key] - else: + compiled_sql = self._execution_options['compiled_cache'].get(key) + if compiled_sql is None: compiled_sql = elem.compile( dialect=dialect, column_keys=keys, inline=len(distilled_params) > 1) @@ -959,9 +1072,10 @@ class Connection(Connectable): context = constructor(dialect, self, conn, *args) except Exception as e: - self._handle_dbapi_exception(e, - util.text_type(statement), parameters, - None, None) + self._handle_dbapi_exception( + e, + util.text_type(statement), parameters, + None, None) if context.compiled: context.pre_exec() @@ -985,36 +1099,39 @@ class Connection(Connectable): "%r", sql_util._repr_params(parameters, batches=10) ) + + evt_handled = False try: if context.executemany: - for fn in () if not self.dialect._has_events \ - else self.dialect.dispatch.do_executemany: - if fn(cursor, statement, parameters, context): - break - else: + if self.dialect._has_events: + for fn in self.dialect.dispatch.do_executemany: + if fn(cursor, statement, parameters, context): + evt_handled = True + break + if not evt_handled: self.dialect.do_executemany( cursor, statement, parameters, context) - elif not parameters and context.no_parameters: - for fn in () if not self.dialect._has_events \ - else self.dialect.dispatch.do_execute_no_params: - if fn(cursor, statement, context): - break - else: + if self.dialect._has_events: + for fn in self.dialect.dispatch.do_execute_no_params: + if fn(cursor, statement, context): + evt_handled = True + break + if not evt_handled: self.dialect.do_execute_no_params( cursor, statement, context) - else: - for fn in () if not 
self.dialect._has_events \ - else self.dialect.dispatch.do_execute: - if fn(cursor, statement, parameters, context): - break - else: + if self.dialect._has_events: + for fn in self.dialect.dispatch.do_execute: + if fn(cursor, statement, parameters, context): + evt_handled = True + break + if not evt_handled: self.dialect.do_execute( cursor, statement, @@ -1038,36 +1155,17 @@ class Connection(Connectable): if context.compiled: context.post_exec() - if context.isinsert and not context.executemany: - context.post_insert() - - # create a resultproxy, get rowcount/implicit RETURNING - # rows, close cursor if no further results pending - result = context.get_result_proxy() - if context.isinsert: - if context._is_implicit_returning: - context._fetch_implicit_returning(result) - result.close(_autoclose_connection=False) - result._metadata = None - elif not context._is_explicit_returning: - result.close(_autoclose_connection=False) - result._metadata = None - elif context.isupdate and context._is_implicit_returning: - context._fetch_implicit_update_returning(result) - result.close(_autoclose_connection=False) - result._metadata = None - - elif result._metadata is None: - # no results, get rowcount - # (which requires open cursor on some drivers - # such as kintersbasdb, mxodbc), - result.rowcount - result.close(_autoclose_connection=False) + if context.is_crud: + result = context._setup_crud_result_proxy() + else: + result = context.get_result_proxy() + if result._metadata is None: + result._soft_close(_autoclose_connection=False) if context.should_autocommit and self._root.__transaction is None: self._root._commit_impl(autocommit=True) - if result.closed and self.should_close_with_result: + if result._soft_closed and self.should_close_with_result: self.close() return result @@ -1149,7 +1247,10 @@ class Connection(Connectable): self._is_disconnect = \ isinstance(e, self.dialect.dbapi.Error) and \ not self.closed and \ - self.dialect.is_disconnect(e, self.__connection, cursor) + self.dialect.is_disconnect( + e, + self.__connection if not self.invalidated else None, + cursor) if context: context.is_disconnect = self._is_disconnect @@ -1194,7 +1295,8 @@ class Connection(Connectable): # new handle_error event ctx = ExceptionContextImpl( - e, sqlalchemy_exception, self, cursor, statement, + e, sqlalchemy_exception, self.engine, + self, cursor, statement, parameters, context, self._is_disconnect) for fn in self.dispatch.handle_error: @@ -1236,12 +1338,65 @@ class Connection(Connectable): del self._reentrant_error if self._is_disconnect: del self._is_disconnect - dbapi_conn_wrapper = self.connection - self.engine.pool._invalidate(dbapi_conn_wrapper, e) - self.invalidate(e) + if not self.invalidated: + dbapi_conn_wrapper = self.__connection + self.engine.pool._invalidate(dbapi_conn_wrapper, e) + self.invalidate(e) if self.should_close_with_result: self.close() + @classmethod + def _handle_dbapi_exception_noconnection(cls, e, dialect, engine): + + exc_info = sys.exc_info() + + is_disconnect = dialect.is_disconnect(e, None, None) + + should_wrap = isinstance(e, dialect.dbapi.Error) + + if should_wrap: + sqlalchemy_exception = exc.DBAPIError.instance( + None, + None, + e, + dialect.dbapi.Error, + connection_invalidated=is_disconnect) + else: + sqlalchemy_exception = None + + newraise = None + + if engine._has_events: + ctx = ExceptionContextImpl( + e, sqlalchemy_exception, engine, None, None, None, + None, None, is_disconnect) + for fn in engine.dispatch.handle_error: + try: + # handler returns an exception; 
+ # call next handler in a chain + per_fn = fn(ctx) + if per_fn is not None: + ctx.chained_exception = newraise = per_fn + except Exception as _raised: + # handler raises an exception - stop processing + newraise = _raised + break + + if sqlalchemy_exception and \ + is_disconnect != ctx.is_disconnect: + sqlalchemy_exception.connection_invalidated = \ + is_disconnect = ctx.is_disconnect + + if newraise: + util.raise_from_cause(newraise, exc_info) + elif should_wrap: + util.raise_from_cause( + sqlalchemy_exception, + exc_info + ) + else: + util.reraise(*exc_info) + def default_schema_name(self): return self.engine.dialect.get_default_schema_name(self) @@ -1320,8 +1475,9 @@ class ExceptionContextImpl(ExceptionContext): """Implement the :class:`.ExceptionContext` interface.""" def __init__(self, exception, sqlalchemy_exception, - connection, cursor, statement, parameters, + engine, connection, cursor, statement, parameters, context, is_disconnect): + self.engine = engine self.connection = connection self.sqlalchemy_exception = sqlalchemy_exception self.original_exception = exception @@ -1865,10 +2021,11 @@ class Engine(Connectable, log.Identified): """ - return self._connection_cls(self, - self.pool.connect(), - close_with_result=close_with_result, - **kwargs) + return self._connection_cls( + self, + self._wrap_pool_connect(self.pool.connect, None), + close_with_result=close_with_result, + **kwargs) def table_names(self, schema=None, connection=None): """Return a list of all table names available in the database. @@ -1898,7 +2055,18 @@ class Engine(Connectable, log.Identified): """ return self.run_callable(self.dialect.has_table, table_name, schema) - def raw_connection(self): + def _wrap_pool_connect(self, fn, connection): + dialect = self.dialect + try: + return fn() + except dialect.dbapi.Error as e: + if connection is None: + Connection._handle_dbapi_exception_noconnection( + e, dialect, self) + else: + util.reraise(*sys.exc_info()) + + def raw_connection(self, _connection=None): """Return a "raw" DBAPI connection from the connection pool. The returned object is a proxied version of the DBAPI @@ -1909,13 +2077,18 @@ class Engine(Connectable, log.Identified): for real. This method provides direct DBAPI connection access for - special situations. In most situations, the :class:`.Connection` - object should be used, which is procured using the - :meth:`.Engine.connect` method. + special situations when the API provided by :class:`.Connection` + is not needed. When a :class:`.Connection` object is already + present, the DBAPI connection is available using + the :attr:`.Connection.connection` accessor. - """ + .. 
seealso:: - return self.pool.unique_connection() + :ref:`dbapi_connections` + + """ + return self._wrap_pool_connect( + self.pool.unique_connection, _connection) class OptionEngine(Engine): diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index a5af6ff19..3eebc6c06 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1,5 +1,5 @@ # engine/default.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -395,6 +395,12 @@ class DefaultDialect(interfaces.Dialect): self._set_connection_isolation(connection, opts['isolation_level']) def _set_connection_isolation(self, connection, level): + if connection.in_transaction(): + util.warn( + "Connection is already established with a Transaction; " + "setting isolation_level may implicitly rollback or commit " + "the existing transaction, or have no effect until " + "next transaction") self.set_isolation_level(connection.connection, level) connection.connection._connection_record.\ finalize_callback.append(self.reset_isolation_level) @@ -452,14 +458,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): isinsert = False isupdate = False isdelete = False + is_crud = False isddl = False executemany = False - result_map = None compiled = None statement = None - postfetch_cols = None - prefetch_cols = None - returning_cols = None + result_column_struct = None _is_implicit_returning = False _is_explicit_returning = False @@ -515,15 +519,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if not compiled.can_execute: raise exc.ArgumentError("Not an executable clause") - self.execution_options = compiled.statement._execution_options - if connection._execution_options: - self.execution_options = dict(self.execution_options) - self.execution_options.update(connection._execution_options) - - # compiled clauseelement. 
process bind params, process table defaults, - # track collections used by ResultProxy to target and process results + self.execution_options = compiled.statement._execution_options.union( + connection._execution_options) - self.result_map = compiled.result_map + self.result_column_struct = ( + compiled._result_columns, compiled._ordered_columns) self.unicode_statement = util.text_type(compiled) if not dialect.supports_unicode_statements: @@ -548,6 +548,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.cursor = self.create_cursor() if self.isinsert or self.isupdate or self.isdelete: + self.is_crud = True self._is_explicit_returning = bool(compiled.statement._returning) self._is_implicit_returning = bool( compiled.returning and not compiled.statement._returning) @@ -681,10 +682,6 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return self.execution_options.get("no_parameters", False) @util.memoized_property - def is_crud(self): - return self.isinsert or self.isupdate or self.isdelete - - @util.memoized_property def should_autocommit(self): autocommit = self.execution_options.get('autocommit', not self.compiled and @@ -799,52 +796,84 @@ class DefaultExecutionContext(interfaces.ExecutionContext): def supports_sane_multi_rowcount(self): return self.dialect.supports_sane_multi_rowcount - def post_insert(self): - + def _setup_crud_result_proxy(self): + if self.isinsert and \ + not self.executemany: + if not self._is_implicit_returning and \ + not self.compiled.inline and \ + self.dialect.postfetch_lastrowid: + + self._setup_ins_pk_from_lastrowid() + + elif not self._is_implicit_returning: + self._setup_ins_pk_from_empty() + + result = self.get_result_proxy() + + if self.isinsert: + if self._is_implicit_returning: + row = result.fetchone() + self.returned_defaults = row + self._setup_ins_pk_from_implicit_returning(row) + result._soft_close(_autoclose_connection=False) + result._metadata = None + elif not self._is_explicit_returning: + result._soft_close(_autoclose_connection=False) + result._metadata = None + elif self.isupdate and self._is_implicit_returning: + row = result.fetchone() + self.returned_defaults = row + result._soft_close(_autoclose_connection=False) + result._metadata = None + + elif result._metadata is None: + # no results, get rowcount + # (which requires open cursor on some drivers + # such as kintersbasdb, mxodbc) + result.rowcount + result._soft_close(_autoclose_connection=False) + return result + + def _setup_ins_pk_from_lastrowid(self): key_getter = self.compiled._key_getters_for_crud_column[2] table = self.compiled.statement.table + compiled_params = self.compiled_parameters[0] + + lastrowid = self.get_lastrowid() + autoinc_col = table._autoincrement_column + if autoinc_col is not None: + # apply type post processors to the lastrowid + proc = autoinc_col.type._cached_result_processor( + self.dialect, None) + if proc is not None: + lastrowid = proc(lastrowid) + self.inserted_primary_key = [ + lastrowid if c is autoinc_col else + compiled_params.get(key_getter(c), None) + for c in table.primary_key + ] - if not self._is_implicit_returning and \ - not self._is_explicit_returning and \ - not self.compiled.inline and \ - self.dialect.postfetch_lastrowid: - - lastrowid = self.get_lastrowid() - autoinc_col = table._autoincrement_column - if autoinc_col is not None: - # apply type post processors to the lastrowid - proc = autoinc_col.type._cached_result_processor( - self.dialect, None) - if proc is not None: - lastrowid = proc(lastrowid) - 
self.inserted_primary_key = [ - lastrowid if c is autoinc_col else - self.compiled_parameters[0].get(key_getter(c), None) - for c in table.primary_key - ] - else: - self.inserted_primary_key = [ - self.compiled_parameters[0].get(key_getter(c), None) - for c in table.primary_key - ] - - def _fetch_implicit_returning(self, resultproxy): + def _setup_ins_pk_from_empty(self): + key_getter = self.compiled._key_getters_for_crud_column[2] table = self.compiled.statement.table - row = resultproxy.fetchone() - - ipk = [] - for c, v in zip(table.primary_key, self.inserted_primary_key): - if v is not None: - ipk.append(v) - else: - ipk.append(row[c]) + compiled_params = self.compiled_parameters[0] + self.inserted_primary_key = [ + compiled_params.get(key_getter(c), None) + for c in table.primary_key + ] - self.inserted_primary_key = ipk - self.returned_defaults = row + def _setup_ins_pk_from_implicit_returning(self, row): + key_getter = self.compiled._key_getters_for_crud_column[2] + table = self.compiled.statement.table + compiled_params = self.compiled_parameters[0] - def _fetch_implicit_update_returning(self, resultproxy): - row = resultproxy.fetchone() - self.returned_defaults = row + self.inserted_primary_key = [ + row[col] if value is None else value + for col, value in [ + (col, compiled_params.get(key_getter(col), None)) + for col in table.primary_key + ] + ] def lastrow_has_defaults(self): return (self.isinsert or self.isupdate) and \ @@ -956,14 +985,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext): def _process_executesingle_defaults(self): key_getter = self.compiled._key_getters_for_crud_column[2] - prefetch = self.compiled.prefetch self.current_parameters = compiled_parameters = \ self.compiled_parameters[0] for c in prefetch: if self.isinsert: - val = self.get_insert_default(c) + if c.default and \ + not c.default.is_sequence and c.default.is_scalar: + val = c.default.arg + else: + val = self.get_insert_default(c) else: val = self.get_update_default(c) @@ -972,6 +1004,4 @@ class DefaultExecutionContext(interfaces.ExecutionContext): del self.current_parameters - - DefaultDialect.execution_ctx_cls = DefaultExecutionContext diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 0ad2efae0..da8fa81eb 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -1,5 +1,5 @@ # engine/interfaces.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -654,17 +654,82 @@ class Dialect(object): return None def reset_isolation_level(self, dbapi_conn): - """Given a DBAPI connection, revert its isolation to the default.""" + """Given a DBAPI connection, revert its isolation to the default. + + Note that this is a dialect-level method which is used as part + of the implementation of the :class:`.Connection` and + :class:`.Engine` + isolation level facilities; these APIs should be preferred for + most typical use cases. + + .. 
seealso:: + + :meth:`.Connection.get_isolation_level` - view current level + + :attr:`.Connection.default_isolation_level` - view default level + + :paramref:`.Connection.execution_options.isolation_level` - + set per :class:`.Connection` isolation level + + :paramref:`.create_engine.isolation_level` - + set per :class:`.Engine` isolation level + + """ raise NotImplementedError() def set_isolation_level(self, dbapi_conn, level): - """Given a DBAPI connection, set its isolation level.""" + """Given a DBAPI connection, set its isolation level. + + Note that this is a dialect-level method which is used as part + of the implementation of the :class:`.Connection` and + :class:`.Engine` + isolation level facilities; these APIs should be preferred for + most typical use cases. + + .. seealso:: + + :meth:`.Connection.get_isolation_level` - view current level + + :attr:`.Connection.default_isolation_level` - view default level + + :paramref:`.Connection.execution_options.isolation_level` - + set per :class:`.Connection` isolation level + + :paramref:`.create_engine.isolation_level` - + set per :class:`.Engine` isolation level + + """ raise NotImplementedError() def get_isolation_level(self, dbapi_conn): - """Given a DBAPI connection, return its isolation level.""" + """Given a DBAPI connection, return its isolation level. + + When working with a :class:`.Connection` object, the corresponding + DBAPI connection may be procured using the + :attr:`.Connection.connection` accessor. + + Note that this is a dialect-level method which is used as part + of the implementation of the :class:`.Connection` and + :class:`.Engine` isolation level facilities; + these APIs should be preferred for most typical use cases. + + + .. seealso:: + + :meth:`.Connection.get_isolation_level` - view current level + + :attr:`.Connection.default_isolation_level` - view default level + + :paramref:`.Connection.execution_options.isolation_level` - + set per :class:`.Connection` isolation level + + :paramref:`.create_engine.isolation_level` - + set per :class:`.Engine` isolation level + + + """ raise NotImplementedError() @@ -917,7 +982,23 @@ class ExceptionContext(object): connection = None """The :class:`.Connection` in use during the exception. - This member is always present. + This member is present, except in the case of a failure when + first connecting. + + .. seealso:: + + :attr:`.ExceptionContext.engine` + + + """ + + engine = None + """The :class:`.Engine` in use during the exception. + + This member should always be present, even in the case of a failure + when first connecting. + + .. versionadded:: 1.0.0 """ diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 2a1def86a..59eed51ec 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -1,5 +1,5 @@ # engine/reflection.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -173,7 +173,14 @@ class Inspector(object): passed as ``None``. For special quoting, use :class:`.quoted_name`. :param order_by: Optional, may be the string "foreign_key" to sort - the result on foreign key dependencies. + the result on foreign key dependencies. Does not automatically + resolve cycles, and will raise :class:`.CircularDependencyError` + if cycles exist. + + .. 
deprecated:: 1.0.0 - see
+ :meth:`.Inspector.get_sorted_table_and_fkc_names` for a version
+ of this which resolves foreign key cycles between tables
+ automatically.
.. versionchanged:: 0.8 the "foreign_key" sorting sorts tables
in order of dependee to dependent; that is, in creation
@@ -183,6 +190,8 @@ class Inspector(object):
.. seealso::
+ :meth:`.Inspector.get_sorted_table_and_fkc_names`
+
:attr:`.MetaData.sorted_tables`
"""
@@ -201,6 +210,64 @@ class Inspector(object):
tnames = list(topological.sort(tuples, tnames))
return tnames
+ def get_sorted_table_and_fkc_names(self, schema=None):
+ """Return dependency-sorted table and foreign key constraint names
+ referred to within a particular schema.
+
+ This will yield 2-tuples of
+ ``(tablename, [(tname, fkname), (tname, fkname), ...])``
+ consisting of table names in CREATE order grouped with the foreign key
+ constraint names that are not detected as belonging to a cycle.
+ The final element
+ will be ``(None, [(tname, fkname), (tname, fkname), ...])``
+ which will consist of remaining
+ foreign key constraint names that would require a separate CREATE
+ step after-the-fact, based on dependencies between tables.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`.Inspector.get_table_names`
+
+ :func:`.sort_tables_and_constraints` - similar method which works
+ with an already-given :class:`.MetaData`.
+
+ """
+ if hasattr(self.dialect, 'get_table_names'):
+ tnames = self.dialect.get_table_names(
+ self.bind, schema, info_cache=self.info_cache)
+ else:
+ tnames = self.engine.table_names(schema)
+
+ tuples = set()
+ remaining_fkcs = set()
+
+ fknames_for_table = {}
+ for tname in tnames:
+ fkeys = self.get_foreign_keys(tname, schema)
+ fknames_for_table[tname] = set(
+ [fk['name'] for fk in fkeys]
+ )
+ for fkey in fkeys:
+ if tname != fkey['referred_table']:
+ tuples.add((fkey['referred_table'], tname))
+ try:
+ candidate_sort = list(topological.sort(tuples, tnames))
+ except exc.CircularDependencyError as err:
+ for edge in err.edges:
+ tuples.remove(edge)
+ remaining_fkcs.update(
+ (edge[1], fkc)
+ for fkc in fknames_for_table[edge[1]]
+ )
+
+ candidate_sort = list(topological.sort(tuples, tnames))
+ return [
+ (tname, fknames_for_table[tname].difference(remaining_fkcs))
+ for tname in candidate_sort
+ ] + [(None, list(remaining_fkcs))]
+
def get_temp_table_names(self):
"""return a list of temporary table names for the current bind.
@@ -394,6 +461,12 @@ class Inspector(object):
unique
boolean
+ dialect_options
+ dict of dialect-specific index options. May not be present
+ for all dialects.
+
+ .. versionadded:: 1.0.0
+
:param table_name: string name of the table. For special
quoting, use :class:`.quoted_name`.
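[editor's note] A brief usage sketch for the new :meth:`.Inspector.get_sorted_table_and_fkc_names` method added above; the database URL is hypothetical, and the loop simply reports the structure the implementation returns::

    from sqlalchemy import create_engine
    from sqlalchemy.engine import reflection

    # hypothetical URL; any dialect supporting foreign key
    # reflection behaves the same way
    engine = create_engine("postgresql://scott:tiger@localhost/test")
    insp = reflection.Inspector.from_engine(engine)

    for tname, fkcs in insp.get_sorted_table_and_fkc_names():
        if tname is not None:
            # a table in CREATE order, along with the foreign key
            # constraint names not detected as part of a cycle
            print("table: %s  fk constraints: %s" % (tname, fkcs))
        else:
            # the final entry: (table name, constraint name) pairs
            # that would require a separate step after all tables exist
            print("leftover fk constraints: %s" % (fkcs,))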
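[editor's note] Likewise, for the isolation-level docstrings added to interfaces.py further up, a minimal sketch of the preferred :class:`.Connection` / :class:`.Engine` facilities those docstrings point to, assuming a PostgreSQL URL purely for illustration; the dialect-level methods themselves are not called directly::

    from sqlalchemy import create_engine

    # engine-wide default, via create_engine.isolation_level
    engine = create_engine(
        "postgresql://scott:tiger@localhost/test",
        isolation_level="READ COMMITTED")

    with engine.connect() as conn:
        # per-connection override, via Connection.execution_options;
        # this returns a Connection sharing the same DBAPI connection
        conn = conn.execution_options(isolation_level="SERIALIZABLE")
        print(conn.get_isolation_level())      # currently active level
        print(conn.default_isolation_level)    # the dialect-level default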
@@ -642,6 +715,8 @@ class Inspector(object): columns = index_d['column_names'] unique = index_d['unique'] flavor = index_d.get('type', 'index') + dialect_options = index_d.get('dialect_options', {}) + duplicates = index_d.get('duplicates_constraint') if include_columns and \ not set(columns).issubset(include_columns): @@ -667,7 +742,10 @@ class Inspector(object): else: idx_cols.append(idx_col) - sa_schema.Index(name, *idx_cols, **dict(unique=unique)) + sa_schema.Index( + name, *idx_cols, + **dict(list(dialect_options.items()) + [('unique', unique)]) + ) def _reflect_unique_constraints( self, table_name, schema, table, cols_by_orig_name, diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 3995942ef..6d19cb6d0 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -1,5 +1,5 @@ # engine/result.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -187,86 +187,162 @@ class ResultMetaData(object): context.""" def __init__(self, parent, metadata): - self._processors = processors = [] - - # We do not strictly need to store the processor in the key mapping, - # though it is faster in the Python version (probably because of the - # saved attribute lookup self._processors) - self._keymap = keymap = {} - self.keys = [] context = parent.context dialect = context.dialect typemap = dialect.dbapi_type_map translate_colname = context._translate_colname - self.case_sensitive = dialect.case_sensitive - - # high precedence key values. - primary_keymap = {} - - for i, rec in enumerate(metadata): - colname = rec[0] - coltype = rec[1] - - if dialect.description_encoding: - colname = dialect._description_decoder(colname) + self.case_sensitive = case_sensitive = dialect.case_sensitive + if context.result_column_struct: + result_columns, cols_are_ordered = context.result_column_struct + num_ctx_cols = len(result_columns) + else: + num_ctx_cols = None + + if num_ctx_cols and \ + cols_are_ordered and \ + num_ctx_cols == len(metadata): + # case 1 - SQL expression statement, number of columns + # in result matches number of cols in compiled. This is the + # vast majority case for SQL expression constructs. In this + # case we don't bother trying to parse or match up to + # the colnames in the result description. + raw = [ + ( + idx, + key, + name.lower() if not case_sensitive else name, + context.get_result_processor( + type_, key, metadata[idx][1] + ), + obj, + None + ) for idx, (key, name, obj, type_) + in enumerate(result_columns) + ] + self.keys = [ + elem[1] for elem in result_columns + ] + else: + # case 2 - raw string, or number of columns in result does + # not match number of cols in compiled. The raw string case + # is very common. The latter can happen + # when text() is used with only a partial typemap, or + # in the extremely unlikely cases where the compiled construct + # has a single element with multiple col expressions in it + # (e.g. has commas embedded) or there's some kind of statement + # that is adding extra columns. + # In all these cases we fall back to the "named" approach + # that SQLAlchemy has used up through 0.9. 
+ + if num_ctx_cols: + result_map = self._create_result_map(result_columns) + + raw = [] + self.keys = [] + untranslated = None + for idx, rec in enumerate(metadata): + colname = rec[0] + coltype = rec[1] + + if dialect.description_encoding: + colname = dialect._description_decoder(colname) + + if translate_colname: + colname, untranslated = translate_colname(colname) + + if dialect.requires_name_normalize: + colname = dialect.normalize_name(colname) + + self.keys.append(colname) + if not case_sensitive: + colname = colname.lower() + + if num_ctx_cols: + try: + ctx_rec = result_map[colname] + except KeyError: + mapped_type = typemap.get(coltype, sqltypes.NULLTYPE) + obj = None + else: + obj = ctx_rec[1] + mapped_type = ctx_rec[2] + else: + mapped_type = typemap.get(coltype, sqltypes.NULLTYPE) + obj = None + processor = context.get_result_processor( + mapped_type, colname, coltype) + + raw.append( + (idx, colname, colname, processor, obj, untranslated) + ) + + # keymap indexes by integer index... + self._keymap = dict([ + (elem[0], (elem[3], elem[4], elem[0])) + for elem in raw + ]) + + # processors in key order for certain per-row + # views like __iter__ and slices + self._processors = [elem[3] for elem in raw] + + if num_ctx_cols: + # keymap by primary string... + by_key = dict([ + (elem[2], (elem[3], elem[4], elem[0])) + for elem in raw + ]) + + # if by-primary-string dictionary smaller (or bigger?!) than + # number of columns, assume we have dupes, rewrite + # dupe records with "None" for index which results in + # ambiguous column exception when accessed. + if len(by_key) != num_ctx_cols: + seen = set() + for rec in raw: + key = rec[1] + if key in seen: + by_key[key] = (None, by_key[key][1], None) + seen.add(key) + + # update keymap with secondary "object"-based keys + self._keymap.update([ + (obj_elem, by_key[elem[2]]) + for elem in raw if elem[4] + for obj_elem in elem[4] + ]) + + # update keymap with primary string names taking + # precedence + self._keymap.update(by_key) + else: + self._keymap.update([ + (elem[2], (elem[3], elem[4], elem[0])) + for elem in raw + ]) + # update keymap with "translated" names (sqlite-only thing) if translate_colname: - colname, untranslated = translate_colname(colname) - - if dialect.requires_name_normalize: - colname = dialect.normalize_name(colname) - - if context.result_map: - try: - name, obj, type_ = context.result_map[ - colname if self.case_sensitive else colname.lower()] - except KeyError: - name, obj, type_ = \ - colname, None, typemap.get(coltype, sqltypes.NULLTYPE) + self._keymap.update([ + (elem[5], self._keymap[elem[2]]) + for elem in raw if elem[5] + ]) + + @classmethod + def _create_result_map(cls, result_columns): + d = {} + for elem in result_columns: + key, rec = elem[0], elem[1:] + if key in d: + # conflicting keyname, just double up the list + # of objects. this will cause an "ambiguous name" + # error if an attempt is made by the result set to + # access. + e_name, e_obj, e_type = d[key] + d[key] = e_name, e_obj + rec[1], e_type else: - name, obj, type_ = \ - colname, None, typemap.get(coltype, sqltypes.NULLTYPE) - - processor = context.get_result_processor(type_, colname, coltype) - - processors.append(processor) - rec = (processor, obj, i) - - # indexes as keys. This is only needed for the Python version of - # RowProxy (the C version uses a faster path for integer indexes). - primary_keymap[i] = rec - - # populate primary keymap, looking for conflicts. 
- if primary_keymap.setdefault(
- name if self.case_sensitive
- else name.lower(),
- rec) is not rec:
- # place a record that doesn't have the "index" - this
- # is interpreted later as an AmbiguousColumnError,
- # but only when actually accessed. Columns
- # colliding by name is not a problem if those names
- # aren't used; integer access is always
- # unambiguous.
- primary_keymap[name
- if self.case_sensitive
- else name.lower()] = rec = (None, obj, None)
-
- self.keys.append(colname)
- if obj:
- for o in obj:
- keymap[o] = rec
- # technically we should be doing this but we
- # are saving on callcounts by not doing so.
- # if keymap.setdefault(o, rec) is not rec:
- # keymap[o] = (None, obj, None)
-
- if translate_colname and \
- untranslated:
- keymap[untranslated] = rec
-
- # overwrite keymap values with those of the
- # high precedence keymap.
- keymap.update(primary_keymap)
+ d[key] = rec
+ return d
@util.pending_deprecation("0.8", "sqlite dialect uses "
"_translate_colname() now")
@@ -403,11 +479,12 @@ class ResultProxy(object):
out_parameters = None
_can_close_connection = False
_metadata = None
+ _soft_closed = False
+ closed = False
def __init__(self, context):
self.context = context
self.dialect = context.dialect
- self.closed = False
self.cursor = self._saved_cursor = context.cursor
self.connection = context.root_connection
self._echo = self.connection._echo and \
@@ -544,33 +621,79 @@ class ResultProxy(object):
return self._saved_cursor.description
- def close(self, _autoclose_connection=True):
- """Close this ResultProxy.
-
- Closes the underlying DBAPI cursor corresponding to the execution.
+ def _soft_close(self, _autoclose_connection=True):
+ """Soft close this :class:`.ResultProxy`.
- Note that any data cached within this ResultProxy is still available.
- For some types of results, this may include buffered rows.
-
- If this ResultProxy was generated from an implicit execution,
- the underlying Connection will also be closed (returns the
- underlying DBAPI connection to the connection pool.)
+ This releases all DBAPI cursor resources, but leaves the
+ ResultProxy "open" from a semantic perspective, meaning the
+ fetchXXX() methods will continue to return empty results.
This method is called automatically when:
* all result rows are exhausted using the fetchXXX() methods.
* cursor.description is None.
+ This method is **not public**, but is documented in order to clarify
+ the "autoclose" process used.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`.ResultProxy.close`
+
+
+ """
+ if self._soft_closed:
+ return
+ self._soft_closed = True
+ cursor = self.cursor
+ self.connection._safe_close_cursor(cursor)
+ if _autoclose_connection and \
+ self.connection.should_close_with_result:
+ self.connection.close()
+ self.cursor = None
+
+ def close(self):
+ """Close this ResultProxy.
+
+ This closes out the underlying DBAPI cursor corresponding
+ to the statement execution, if one is still present. Note that the
+ DBAPI cursor is automatically released when the :class:`.ResultProxy`
+ exhausts all available rows. :meth:`.ResultProxy.close` is generally
+ an optional method except in the case when discarding a
+ :class:`.ResultProxy` that still has additional rows pending for fetch.
+
+ In the case of a result that is the product of
+ :ref:`connectionless execution <dbengine_implicit>`,
+ the underlying :class:`.Connection` object is also closed, which
+ :term:`releases` DBAPI connection resources.
+
+ After this method is called, it is no longer valid to call upon
+ the fetch methods, which will raise a :class:`.ResourceClosedError`
+ on subsequent use.
+
+ .. versionchanged:: 1.0.0 - the :meth:`.ResultProxy.close` method
+ has been separated out from the process that releases the underlying
+ DBAPI cursor resource. The "auto close" feature of the
+ :class:`.Connection` now performs a so-called "soft close", which
+ releases the underlying DBAPI cursor, but allows the
+ :class:`.ResultProxy` to still behave as an open-but-exhausted
+ result set; the actual :meth:`.ResultProxy.close` method is never
+ called. It is still safe to discard a :class:`.ResultProxy`
+ that has been fully exhausted without calling this method.
+
+ .. seealso::
+
+ :ref:`connections_toplevel`
+
+ :meth:`.ResultProxy._soft_close`
+
"""
if not self.closed:
+ self._soft_close()
self.closed = True
- self.connection._safe_close_cursor(self.cursor)
- if _autoclose_connection and \
- self.connection.should_close_with_result:
- self.connection.close()
- # allow consistent errors
- self.cursor = None
def __iter__(self):
while True:
@@ -761,7 +884,7 @@ class ResultProxy(object):
try:
return self.cursor.fetchone()
except AttributeError:
- self._non_result()
+ return self._non_result(None)
def _fetchmany_impl(self, size=None):
try:
@@ -770,22 +893,24 @@
else:
return self.cursor.fetchmany(size)
except AttributeError:
- self._non_result()
+ return self._non_result([])
def _fetchall_impl(self):
try:
return self.cursor.fetchall()
except AttributeError:
- self._non_result()
+ return self._non_result([])
- def _non_result(self):
+ def _non_result(self, default):
if self._metadata is None:
raise exc.ResourceClosedError(
"This result object does not return rows. "
"It has been closed automatically.",
)
- else:
+ elif self.closed:
raise exc.ResourceClosedError("This result object is closed.")
+ else:
+ return default
def process_rows(self, rows):
process_row = self._process_row
@@ -804,11 +929,25 @@
for row in rows]
def fetchall(self):
- """Fetch all rows, just like DB-API ``cursor.fetchall()``."""
+ """Fetch all rows, just like DB-API ``cursor.fetchall()``.
+
+ After all rows have been exhausted, the underlying DBAPI
+ cursor resource is released, and the object may be safely
+ discarded.
+
+ Subsequent calls to :meth:`.ResultProxy.fetchall` will return
+ an empty list. After the :meth:`.ResultProxy.close` method is
+ called, the method will raise :class:`.ResourceClosedError`.
+
+ .. versionchanged:: 1.0.0 - Added "soft close" behavior which
+ allows the result to be used in an "exhausted" state prior to
+ calling the :meth:`.ResultProxy.close` method.
+
+ """
try:
l = self.process_rows(self._fetchall_impl())
- self.close()
+ self._soft_close()
return l
except Exception as e:
self.connection._handle_dbapi_exception(
@@ -819,15 +958,25 @@
"""Fetch many rows, just like DB-API
``cursor.fetchmany(size=cursor.arraysize)``.
- If rows are present, the cursor remains open after this is called.
- Else the cursor is automatically closed and an empty list is returned.
+ After all rows have been exhausted, the underlying DBAPI
+ cursor resource is released, and the object may be safely
+ discarded.
+
+ Calls to :meth:`.ResultProxy.fetchmany` after all rows have been
+ exhausted will return
+ an empty list. After the :meth:`.ResultProxy.close` method is
+ called, the method will raise :class:`.ResourceClosedError`.
+
+ ..
versionchanged:: 1.0.0 - Added "soft close" behavior which
+ allows the result to be used in an "exhausted" state prior to
+ calling the :meth:`.ResultProxy.close` method.
"""
try:
l = self.process_rows(self._fetchmany_impl(size))
if len(l) == 0:
- self.close()
+ self._soft_close()
return l
except Exception as e:
self.connection._handle_dbapi_exception(
@@ -837,8 +986,18 @@
def fetchone(self):
"""Fetch one row, just like DB-API ``cursor.fetchone()``.
- If a row is present, the cursor remains open after this is called.
- Else the cursor is automatically closed and None is returned.
+ After all rows have been exhausted, the underlying DBAPI
+ cursor resource is released, and the object may be safely
+ discarded.
+
+ Calls to :meth:`.ResultProxy.fetchone` after all rows have
+ been exhausted will return ``None``.
+ After the :meth:`.ResultProxy.close` method is
+ called, the method will raise :class:`.ResourceClosedError`.
+
+ .. versionchanged:: 1.0.0 - Added "soft close" behavior which
+ allows the result to be used in an "exhausted" state prior to
+ calling the :meth:`.ResultProxy.close` method.
"""
try:
@@ -846,7 +1005,7 @@
if row is not None:
return self.process_rows([row])[0]
else:
- self.close()
+ self._soft_close()
return None
except Exception as e:
self.connection._handle_dbapi_exception(
@@ -858,9 +1017,12 @@
Returns None if no row is present.
+ After calling this method, the object is fully closed,
+ i.e. the :meth:`.ResultProxy.close` method will have been called.
+
"""
if self._metadata is None:
- self._non_result()
+ return self._non_result(None)
try:
row = self._fetchone_impl()
@@ -882,6 +1044,9 @@
Returns None if no row is present.
+ After calling this method, the object is fully closed,
+ i.e. the :meth:`.ResultProxy.close` method will have been called.
+ """ row = self.first() if row is not None: @@ -925,13 +1090,19 @@ class BufferedRowResultProxy(ResultProxy): } def __buffer_rows(self): + if self.cursor is None: + return size = getattr(self, '_bufsize', 1) self.__rowbuffer = collections.deque(self.cursor.fetchmany(size)) self._bufsize = self.size_growth.get(size, size) + def _soft_close(self, **kw): + self.__rowbuffer.clear() + super(BufferedRowResultProxy, self)._soft_close(**kw) + def _fetchone_impl(self): - if self.closed: - return None + if self.cursor is None: + return self._non_result(None) if not self.__rowbuffer: self.__buffer_rows() if not self.__rowbuffer: @@ -950,6 +1121,8 @@ class BufferedRowResultProxy(ResultProxy): return result def _fetchall_impl(self): + if self.cursor is None: + return self._non_result([]) self.__rowbuffer.extend(self.cursor.fetchall()) ret = self.__rowbuffer self.__rowbuffer = collections.deque() @@ -972,11 +1145,15 @@ class FullyBufferedResultProxy(ResultProxy): def _buffer_rows(self): return collections.deque(self.cursor.fetchall()) + def _soft_close(self, **kw): + self.__rowbuffer.clear() + super(FullyBufferedResultProxy, self)._soft_close(**kw) + def _fetchone_impl(self): if self.__rowbuffer: return self.__rowbuffer.popleft() else: - return None + return self._non_result(None) def _fetchmany_impl(self, size=None): if size is None: @@ -990,6 +1167,8 @@ class FullyBufferedResultProxy(ResultProxy): return result def _fetchall_impl(self): + if not self.cursor: + return self._non_result([]) ret = self.__rowbuffer self.__rowbuffer = collections.deque() return ret diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index 398ef8df6..1fd105d67 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -1,5 +1,5 @@ # engine/strategies.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -86,16 +86,7 @@ class DefaultEngineStrategy(EngineStrategy): pool = pop_kwarg('pool', None) if pool is None: def connect(): - try: - return dialect.connect(*cargs, **cparams) - except dialect.dbapi.Error as e: - invalidated = dialect.is_disconnect(e, None, None) - util.raise_from_cause( - exc.DBAPIError.instance( - None, None, e, dialect.dbapi.Error, - connection_invalidated=invalidated - ) - ) + return dialect.connect(*cargs, **cparams) creator = pop_kwarg('creator', connect) diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 637523a0e..0d6e1c0f1 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -1,5 +1,5 @@ # engine/threadlocal.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -59,7 +59,10 @@ class TLEngine(base.Engine): # guards against pool-level reapers, if desired. 
# or not connection.connection.is_valid: connection = self._tl_connection_cls( - self, self.pool.connect(), **kw) + self, + self._wrap_pool_connect( + self.pool.connect, connection), + **kw) self._connections.conn = weakref.ref(connection) return connection._increment_connect() diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index 6544cfbf3..d045961dd 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -1,5 +1,5 @@ # engine/url.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index d9eb1df10..3734c9960 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -1,5 +1,5 @@ # engine/util.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/event/__init__.py b/lib/sqlalchemy/event/__init__.py index b93c0ef85..c9bdb9a0e 100644 --- a/lib/sqlalchemy/event/__init__.py +++ b/lib/sqlalchemy/event/__init__.py @@ -1,5 +1,5 @@ # event/__init__.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/event/api.py b/lib/sqlalchemy/event/api.py index b3d79bcf4..86ef094d6 100644 --- a/lib/sqlalchemy/event/api.py +++ b/lib/sqlalchemy/event/api.py @@ -1,5 +1,5 @@ # event/api.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index be2a82208..a64c7d08d 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -1,5 +1,5 @@ # event/attr.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -40,17 +40,21 @@ import weakref import collections -class RefCollection(object): - @util.memoized_property - def ref(self): +class RefCollection(util.MemoizedSlots): + __slots__ = 'ref', + + def _memoized_attr_ref(self): return weakref.ref(self, registry._collection_gced) -class _DispatchDescriptor(RefCollection): - """Class-level attributes on :class:`._Dispatch` classes.""" +class _ClsLevelDispatch(RefCollection): + """Class-level events on :class:`._Dispatch` classes.""" + + __slots__ = ('name', 'arg_names', 'has_kw', + 'legacy_signatures', '_clslevel') def __init__(self, parent_dispatch_cls, fn): - self.__name__ = fn.__name__ + self.name = fn.__name__ argspec = util.inspect_getargspec(fn) self.arg_names = argspec.args[1:] self.has_kw = bool(argspec.keywords) @@ -60,11 +64,9 @@ class _DispatchDescriptor(RefCollection): key=lambda s: s[0] ) )) - self.__doc__ = fn.__doc__ = legacy._augment_fn_docs( - self, parent_dispatch_cls, fn) + fn.__doc__ = legacy._augment_fn_docs(self, parent_dispatch_cls, fn) self._clslevel = weakref.WeakKeyDictionary() - self._empty_listeners = weakref.WeakKeyDictionary() def 
_adjust_fn_spec(self, fn, named): if named: @@ -152,34 +154,23 @@ class _DispatchDescriptor(RefCollection): def for_modify(self, obj): """Return an event collection which can be modified. - For _DispatchDescriptor at the class level of + For _ClsLevelDispatch at the class level of a dispatcher, this returns self. """ return self - def __get__(self, obj, cls): - if obj is None: - return self - elif obj._parent_cls in self._empty_listeners: - ret = self._empty_listeners[obj._parent_cls] - else: - self._empty_listeners[obj._parent_cls] = ret = \ - _EmptyListener(self, obj._parent_cls) - # assigning it to __dict__ means - # memoized for fast re-access. but more memory. - obj.__dict__[self.__name__] = ret - return ret +class _InstanceLevelDispatch(RefCollection): + __slots__ = () -class _HasParentDispatchDescriptor(object): def _adjust_fn_spec(self, fn, named): return self.parent._adjust_fn_spec(fn, named) -class _EmptyListener(_HasParentDispatchDescriptor): - """Serves as a class-level interface to the events - served by a _DispatchDescriptor, when there are no +class _EmptyListener(_InstanceLevelDispatch): + """Serves as a proxy interface to the events + served by a _ClsLevelDispatch, when there are no instance-level events present. Is replaced by _ListenerCollection when instance-level @@ -187,14 +178,17 @@ class _EmptyListener(_HasParentDispatchDescriptor): """ + propagate = frozenset() + listeners = () + + __slots__ = 'parent', 'parent_listeners', 'name' + def __init__(self, parent, target_cls): if target_cls not in parent._clslevel: parent.update_subclass(target_cls) - self.parent = parent # _DispatchDescriptor + self.parent = parent # _ClsLevelDispatch self.parent_listeners = parent._clslevel[target_cls] - self.name = parent.__name__ - self.propagate = frozenset() - self.listeners = () + self.name = parent.name def for_modify(self, obj): """Return an event collection which can be modified. @@ -205,9 +199,11 @@ class _EmptyListener(_HasParentDispatchDescriptor): and returns it. """ - result = _ListenerCollection(self.parent, obj._parent_cls) - if obj.__dict__[self.name] is self: - obj.__dict__[self.name] = result + result = _ListenerCollection(self.parent, obj._instance_cls) + if getattr(obj, self.name) is self: + setattr(obj, self.name, result) + else: + assert isinstance(getattr(obj, self.name), _JoinedListener) return result def _needs_modify(self, *args, **kw): @@ -233,11 +229,12 @@ class _EmptyListener(_HasParentDispatchDescriptor): __nonzero__ = __bool__ -class _CompoundListener(_HasParentDispatchDescriptor): +class _CompoundListener(_InstanceLevelDispatch): _exec_once = False - @util.memoized_property - def _exec_once_mutex(self): + __slots__ = '_exec_once_mutex', + + def _memoized_attr__exec_once_mutex(self): return threading.Lock() def exec_once(self, *args, **kw): @@ -272,7 +269,7 @@ class _CompoundListener(_HasParentDispatchDescriptor): __nonzero__ = __bool__ -class _ListenerCollection(RefCollection, _CompoundListener): +class _ListenerCollection(_CompoundListener): """Instance-level attributes on instances of :class:`._Dispatch`. Represents a collection of listeners. 
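[editor's note] Returning to the result.py changes above, a small sketch of the "soft close" semantics those docstrings describe, using an in-memory SQLite database and a throwaway table as stand-ins::

    from sqlalchemy import create_engine

    engine = create_engine("sqlite://")
    engine.execute("CREATE TABLE t (x INTEGER)")
    engine.execute("INSERT INTO t (x) VALUES (1)")

    result = engine.execute("SELECT x FROM t")
    print(result.fetchall())   # [(1,)] - rows exhausted; soft close occurs
    print(result.fetchall())   # [] - open-but-exhausted, no error raised
    result.close()             # explicit "hard" close
    result.fetchone()          # now raises ResourceClosedError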
@@ -282,12 +279,14 @@ class _ListenerCollection(RefCollection, _CompoundListener): """ + __slots__ = 'parent_listeners', 'parent', 'name', 'listeners', 'propagate' + def __init__(self, parent, target_cls): if target_cls not in parent._clslevel: parent.update_subclass(target_cls) self.parent_listeners = parent._clslevel[target_cls] self.parent = parent - self.name = parent.__name__ + self.name = parent.name self.listeners = collections.deque() self.propagate = set() @@ -339,24 +338,11 @@ class _ListenerCollection(RefCollection, _CompoundListener): self.listeners.clear() -class _JoinedDispatchDescriptor(object): - def __init__(self, name): - self.name = name - - def __get__(self, obj, cls): - if obj is None: - return self - else: - obj.__dict__[self.name] = ret = _JoinedListener( - obj.parent, self.name, - getattr(obj.local, self.name) - ) - return ret - - class _JoinedListener(_CompoundListener): _exec_once = False + __slots__ = 'parent', 'name', 'local', 'parent_listeners' + def __init__(self, parent, name, local): self.parent = parent self.name = name diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py index 4925f6ffa..1fe83eea2 100644 --- a/lib/sqlalchemy/event/base.py +++ b/lib/sqlalchemy/event/base.py @@ -1,5 +1,5 @@ # event/base.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -17,9 +17,11 @@ instances of ``_Dispatch``. """ from __future__ import absolute_import +import weakref + from .. import util -from .attr import _JoinedDispatchDescriptor, \ - _EmptyListener, _DispatchDescriptor +from .attr import _JoinedListener, \ + _EmptyListener, _ClsLevelDispatch _registrars = util.defaultdict(list) @@ -34,10 +36,11 @@ class _UnpickleDispatch(object): """ - def __call__(self, _parent_cls): - for cls in _parent_cls.__mro__: + def __call__(self, _instance_cls): + for cls in _instance_cls.__mro__: if 'dispatch' in cls.__dict__: - return cls.__dict__['dispatch'].dispatch_cls(_parent_cls) + return cls.__dict__['dispatch'].\ + dispatch_cls._for_class(_instance_cls) else: raise AttributeError("No class with a 'dispatch' member present.") @@ -62,16 +65,53 @@ class _Dispatch(object): """ - _events = None - """reference the :class:`.Events` class which this - :class:`._Dispatch` is created for.""" + # in one ORM edge case, an attribute is added to _Dispatch, + # so __dict__ is used in just that case and potentially others. 
+ __slots__ = '_parent', '_instance_cls', '__dict__', '_empty_listeners' + + _empty_listener_reg = weakref.WeakKeyDictionary() + + def __init__(self, parent, instance_cls=None): + self._parent = parent + self._instance_cls = instance_cls + if instance_cls: + try: + self._empty_listeners = self._empty_listener_reg[instance_cls] + except KeyError: + self._empty_listeners = \ + self._empty_listener_reg[instance_cls] = dict( + (ls.name, _EmptyListener(ls, instance_cls)) + for ls in parent._event_descriptors + ) + else: + self._empty_listeners = {} + + def __getattr__(self, name): + # assign EmptyListeners as attributes on demand + # to reduce startup time for new dispatch objects + try: + ls = self._empty_listeners[name] + except KeyError: + raise AttributeError(name) + else: + setattr(self, ls.name, ls) + return ls + + @property + def _event_descriptors(self): + for k in self._event_names: + yield getattr(self, k) + + def _for_class(self, instance_cls): + return self.__class__(self, instance_cls) - def __init__(self, _parent_cls): - self._parent_cls = _parent_cls + def _for_instance(self, instance): + instance_cls = instance.__class__ + return self._for_class(instance_cls) - @util.classproperty - def _listen(cls): - return cls._events._listen + @property + def _listen(self): + return self._events._listen def _join(self, other): """Create a 'join' of this :class:`._Dispatch` and another. @@ -83,36 +123,27 @@ class _Dispatch(object): if '_joined_dispatch_cls' not in self.__class__.__dict__: cls = type( "Joined%s" % self.__class__.__name__, - (_JoinedDispatcher, self.__class__), {} + (_JoinedDispatcher, ), {'__slots__': self._event_names} ) - for ls in _event_descriptors(self): - setattr(cls, ls.name, _JoinedDispatchDescriptor(ls.name)) self.__class__._joined_dispatch_cls = cls return self._joined_dispatch_cls(self, other) def __reduce__(self): - return _UnpickleDispatch(), (self._parent_cls, ) + return _UnpickleDispatch(), (self._instance_cls, ) def _update(self, other, only_propagate=True): """Populate from the listeners in another :class:`_Dispatch` object.""" - - for ls in _event_descriptors(other): + for ls in other._event_descriptors: if isinstance(ls, _EmptyListener): continue getattr(self, ls.name).\ for_modify(self)._update(ls, only_propagate=only_propagate) - @util.hybridmethod def _clear(self): - for attr in dir(self): - if _is_event_name(attr): - getattr(self, attr).for_modify(self).clear() - - -def _event_descriptors(target): - return [getattr(target, k) for k in dir(target) if _is_event_name(k)] + for ls in self._event_descriptors: + ls.for_modify(self).clear() class _EventMeta(type): @@ -131,26 +162,37 @@ def _create_dispatcher_class(cls, classname, bases, dict_): # there's all kinds of ways to do this, # i.e. make a Dispatch class that shares the '_listen' method # of the Event class, this is the straight monkeypatch. 
- dispatch_base = getattr(cls, 'dispatch', _Dispatch)
+ if hasattr(cls, 'dispatch'):
+ dispatch_base = cls.dispatch.__class__
+ else:
+ dispatch_base = _Dispatch
+
+ event_names = [k for k in dict_ if _is_event_name(k)]
dispatch_cls = type("%sDispatch" % classname,
- (dispatch_base, ), {})
- cls._set_dispatch(cls, dispatch_cls)
+ (dispatch_base, ), {'__slots__': event_names})
+
+ dispatch_cls._event_names = event_names
- for k in dict_:
- if _is_event_name(k):
- setattr(dispatch_cls, k, _DispatchDescriptor(cls, dict_[k]))
- _registrars[k].append(cls)
+ dispatch_inst = cls._set_dispatch(cls, dispatch_cls)
+ for k in dispatch_cls._event_names:
+ setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k]))
+ _registrars[k].append(cls)
+
+ for super_ in dispatch_cls.__bases__:
+ if issubclass(super_, _Dispatch) and super_ is not _Dispatch:
+ for ls in super_._events.dispatch._event_descriptors:
+ setattr(dispatch_inst, ls.name, ls)
+ dispatch_cls._event_names.append(ls.name)
if getattr(cls, '_dispatch_target', None):
cls._dispatch_target.dispatch = dispatcher(cls)
def _remove_dispatcher(cls):
- for k in dir(cls):
- if _is_event_name(k):
- _registrars[k].remove(cls)
- if not _registrars[k]:
- del _registrars[k]
+ for k in cls.dispatch._event_names:
+ _registrars[k].remove(cls)
+ if not _registrars[k]:
+ del _registrars[k]
class Events(util.with_metaclass(_EventMeta, object)):
@@ -163,17 +205,30 @@ class Events(util.with_metaclass(_EventMeta, object)):
# "self.dispatch._events.<utilitymethod>"
# @staticmethod to allow easy "super" calls while in a metaclass
# constructor.
- cls.dispatch = dispatch_cls
+ cls.dispatch = dispatch_cls(None)
dispatch_cls._events = cls
+ return cls.dispatch
@classmethod
def _accept_with(cls, target):
# Mapper, ClassManager, Session override this to
# also accept classes, scoped_sessions, sessionmakers, etc.
if hasattr(target, 'dispatch') and ( - isinstance(target.dispatch, cls.dispatch) or - isinstance(target.dispatch, type) and - issubclass(target.dispatch, cls.dispatch) + + isinstance(target.dispatch, cls.dispatch.__class__) or + + + ( + isinstance(target.dispatch, type) and + isinstance(target.dispatch, cls.dispatch.__class__) + ) or + + ( + isinstance(target.dispatch, _JoinedDispatcher) and + isinstance(target.dispatch.parent, cls.dispatch.__class__) + ) + + ): return target else: @@ -195,10 +250,24 @@ class Events(util.with_metaclass(_EventMeta, object)): class _JoinedDispatcher(object): """Represent a connection between two _Dispatch objects.""" + __slots__ = 'local', 'parent', '_instance_cls' + def __init__(self, local, parent): self.local = local self.parent = parent - self._parent_cls = local._parent_cls + self._instance_cls = self.local._instance_cls + + def __getattr__(self, name): + # assign _JoinedListeners as attributes on demand + # to reduce startup time for new dispatch objects + ls = getattr(self.local, name) + jl = _JoinedListener(self.parent, ls.name, ls) + setattr(self, ls.name, jl) + return jl + + @property + def _listen(self): + return self.parent._listen class dispatcher(object): @@ -216,5 +285,5 @@ class dispatcher(object): def __get__(self, obj, cls): if obj is None: return self.dispatch_cls - obj.__dict__['dispatch'] = disp = self.dispatch_cls(cls) + obj.__dict__['dispatch'] = disp = self.dispatch_cls._for_instance(obj) return disp diff --git a/lib/sqlalchemy/event/legacy.py b/lib/sqlalchemy/event/legacy.py index 3b1519cb6..daa74226f 100644 --- a/lib/sqlalchemy/event/legacy.py +++ b/lib/sqlalchemy/event/legacy.py @@ -1,5 +1,5 @@ # event/legacy.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -22,8 +22,8 @@ def _legacy_signature(since, argnames, converter=None): return leg -def _wrap_fn_for_legacy(dispatch_descriptor, fn, argspec): - for since, argnames, conv in dispatch_descriptor.legacy_signatures: +def _wrap_fn_for_legacy(dispatch_collection, fn, argspec): + for since, argnames, conv in dispatch_collection.legacy_signatures: if argnames[-1] == "**kw": has_kw = True argnames = argnames[0:-1] @@ -40,7 +40,7 @@ def _wrap_fn_for_legacy(dispatch_descriptor, fn, argspec): return fn(*conv(*args)) else: def wrap_leg(*args, **kw): - argdict = dict(zip(dispatch_descriptor.arg_names, args)) + argdict = dict(zip(dispatch_collection.arg_names, args)) args = [argdict[name] for name in argnames] if has_kw: return fn(*args, **kw) @@ -58,16 +58,16 @@ def _indent(text, indent): ) -def _standard_listen_example(dispatch_descriptor, sample_target, fn): +def _standard_listen_example(dispatch_collection, sample_target, fn): example_kw_arg = _indent( "\n".join( "%(arg)s = kw['%(arg)s']" % {"arg": arg} - for arg in dispatch_descriptor.arg_names[0:2] + for arg in dispatch_collection.arg_names[0:2] ), " ") - if dispatch_descriptor.legacy_signatures: + if dispatch_collection.legacy_signatures: current_since = max(since for since, args, conv - in dispatch_descriptor.legacy_signatures) + in dispatch_collection.legacy_signatures) else: current_since = None text = ( @@ -80,7 +80,7 @@ def _standard_listen_example(dispatch_descriptor, sample_target, fn): "\n # ... 
(event handling logic) ...\n" ) - if len(dispatch_descriptor.arg_names) > 3: + if len(dispatch_collection.arg_names) > 3: text += ( "\n# named argument style (new in 0.9)\n" @@ -96,17 +96,17 @@ def _standard_listen_example(dispatch_descriptor, sample_target, fn): "current_since": " (arguments as of %s)" % current_since if current_since else "", "event_name": fn.__name__, - "has_kw_arguments": ", **kw" if dispatch_descriptor.has_kw else "", - "named_event_arguments": ", ".join(dispatch_descriptor.arg_names), + "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "", + "named_event_arguments": ", ".join(dispatch_collection.arg_names), "example_kw_arg": example_kw_arg, "sample_target": sample_target } return text -def _legacy_listen_examples(dispatch_descriptor, sample_target, fn): +def _legacy_listen_examples(dispatch_collection, sample_target, fn): text = "" - for since, args, conv in dispatch_descriptor.legacy_signatures: + for since, args, conv in dispatch_collection.legacy_signatures: text += ( "\n# legacy calling style (pre-%(since)s)\n" "@event.listens_for(%(sample_target)s, '%(event_name)s')\n" @@ -117,7 +117,7 @@ def _legacy_listen_examples(dispatch_descriptor, sample_target, fn): "since": since, "event_name": fn.__name__, "has_kw_arguments": " **kw" - if dispatch_descriptor.has_kw else "", + if dispatch_collection.has_kw else "", "named_event_arguments": ", ".join(args), "sample_target": sample_target } @@ -125,8 +125,8 @@ def _legacy_listen_examples(dispatch_descriptor, sample_target, fn): return text -def _version_signature_changes(dispatch_descriptor): - since, args, conv = dispatch_descriptor.legacy_signatures[0] +def _version_signature_changes(dispatch_collection): + since, args, conv = dispatch_collection.legacy_signatures[0] return ( "\n.. versionchanged:: %(since)s\n" " The ``%(event_name)s`` event now accepts the \n" @@ -135,14 +135,14 @@ def _version_signature_changes(dispatch_descriptor): " signature(s) listed above will be automatically \n" " adapted to the new signature." % { "since": since, - "event_name": dispatch_descriptor.__name__, - "named_event_arguments": ", ".join(dispatch_descriptor.arg_names), - "has_kw_arguments": ", **kw" if dispatch_descriptor.has_kw else "" + "event_name": dispatch_collection.name, + "named_event_arguments": ", ".join(dispatch_collection.arg_names), + "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "" } ) -def _augment_fn_docs(dispatch_descriptor, parent_dispatch_cls, fn): +def _augment_fn_docs(dispatch_collection, parent_dispatch_cls, fn): header = ".. 
container:: event_signatures\n\n"\
" Example argument forms::\n"\
"\n"
@@ -152,16 +152,16 @@
header +
_indent(
_standard_listen_example(
- dispatch_descriptor, sample_target, fn),
+ dispatch_collection, sample_target, fn),
" " * 8)
)
- if dispatch_descriptor.legacy_signatures:
+ if dispatch_collection.legacy_signatures:
text += _indent(
_legacy_listen_examples(
- dispatch_descriptor, sample_target, fn),
+ dispatch_collection, sample_target, fn),
" " * 8)
- text += _version_signature_changes(dispatch_descriptor)
+ text += _version_signature_changes(dispatch_collection)
return util.inject_docstring_text(fn.__doc__,
text,
diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py
index 5b422c401..a6eabb2ff 100644
--- a/lib/sqlalchemy/event/registry.py
+++ b/lib/sqlalchemy/event/registry.py
@@ -1,5 +1,5 @@
# event/registry.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -37,7 +37,7 @@ listener collections and the listener fn contained
_collection_to_key = collections.defaultdict(dict)
"""
-Given a _ListenerCollection or _DispatchDescriptor, can locate
+Given a _ListenerCollection or _ClsLevelDispatch, can locate
all the original listen() arguments and the listener fn contained
ref(listenercollection) -> {
@@ -140,6 +140,10 @@ class _EventKey(object):
"""Represent :func:`.listen` arguments.
"""
+ __slots__ = (
+ 'target', 'identifier', 'fn', 'fn_key', 'fn_wrap', 'dispatch_target'
+ )
+
def __init__(self, target, identifier,
fn, dispatch_target, _fn_wrap=None):
self.target = target
@@ -187,9 +191,9 @@ class _EventKey(object):
target, identifier, fn = \
self.dispatch_target, self.identifier, self._listen_fn
- dispatch_descriptor = getattr(target.dispatch, identifier)
+ dispatch_collection = getattr(target.dispatch, identifier)
- adjusted_fn = dispatch_descriptor._adjust_fn_spec(fn, named)
+ adjusted_fn = dispatch_collection._adjust_fn_spec(fn, named)
self = self.with_wrapper(adjusted_fn)
@@ -226,13 +230,13 @@ class _EventKey(object):
target, identifier, fn = \
self.dispatch_target, self.identifier, self._listen_fn
- dispatch_descriptor = getattr(target.dispatch, identifier)
+ dispatch_collection = getattr(target.dispatch, identifier)
if insert:
- dispatch_descriptor.\
+ dispatch_collection.\
for_modify(target.dispatch).insert(self, propagate)
else:
- dispatch_descriptor.\
+ dispatch_collection.\
for_modify(target.dispatch).append(self, propagate)
@property
diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py
index 1ff35b8b0..22e066c88 100644
--- a/lib/sqlalchemy/events.py
+++ b/lib/sqlalchemy/events.py
@@ -1,5 +1,5 @@
# sqlalchemy/events.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -338,7 +338,7 @@ class PoolEvents(event.Events):
"""
- def reset(self, dbapi_connnection, connection_record):
+ def reset(self, dbapi_connection, connection_record):
"""Called before the "reset" action occurs for a pooled connection.
This event represents @@ -420,6 +420,12 @@ class ConnectionEvents(event.Events): context, executemany): log.info("Received statement: %s" % statement) + When the methods are called with a `statement` parameter, such as in + :meth:`.after_cursor_execute`, :meth:`.before_cursor_execute` and + :meth:`.dbapi_error`, the statement is the exact SQL string that was + prepared for transmission to the DBAPI ``cursor`` in the connection's + :class:`.Dialect`. + The :meth:`.before_execute` and :meth:`.before_cursor_execute` events can also be established with the ``retval=True`` flag, which allows modification of the statement and parameters to be sent @@ -549,9 +555,8 @@ class ConnectionEvents(event.Events): def before_cursor_execute(self, conn, cursor, statement, parameters, context, executemany): """Intercept low-level cursor execute() events before execution, - receiving the string - SQL statement and DBAPI-specific parameter list to be invoked - against a cursor. + receiving the string SQL statement and DBAPI-specific parameter list to + be invoked against a cursor. This event is a good choice for logging as well as late modifications to the SQL string. It's less ideal for parameter modifications except @@ -571,7 +576,7 @@ class ConnectionEvents(event.Events): :param conn: :class:`.Connection` object :param cursor: DBAPI cursor object - :param statement: string SQL statement + :param statement: string SQL statement, as to be passed to the DBAPI :param parameters: Dictionary, tuple, or list of parameters being passed to the ``execute()`` or ``executemany()`` method of the DBAPI ``cursor``. In some cases may be ``None``. @@ -596,7 +601,7 @@ class ConnectionEvents(event.Events): :param cursor: DBAPI cursor object. Will have results pending if the statement was a SELECT, but these should not be consumed as they will be needed by the :class:`.ResultProxy`. - :param statement: string SQL statement + :param statement: string SQL statement, as passed to the DBAPI :param parameters: Dictionary, tuple, or list of parameters being passed to the ``execute()`` or ``executemany()`` method of the DBAPI ``cursor``. In some cases may be ``None``. @@ -640,7 +645,7 @@ class ConnectionEvents(event.Events): :param conn: :class:`.Connection` object :param cursor: DBAPI cursor object - :param statement: string SQL statement + :param statement: string SQL statement, as passed to the DBAPI :param parameters: Dictionary, tuple, or list of parameters being passed to the ``execute()`` or ``executemany()`` method of the DBAPI ``cursor``. In some cases may be ``None``. @@ -734,6 +739,12 @@ class ConnectionEvents(event.Events): .. versionadded:: 0.9.7 Added the :meth:`.ConnectionEvents.handle_error` hook. + .. versionchanged:: 1.0.0 The :meth:`.handle_error` event is now + invoked when an :class:`.Engine` fails during the initial + call to :meth:`.Engine.connect`, as well as when a + :class:`.Connection` object encounters an error during a + reconnect operation. + .. versionchanged:: 1.0.0 The :meth:`.handle_error` event is not fired off when a dialect makes use of the ``skip_user_error_events`` execution option. 
This is used diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index 5d35dc2e7..9b27436b3 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -1,5 +1,5 @@ # sqlalchemy/exc.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -54,8 +54,7 @@ class CircularDependencyError(SQLAlchemyError): or pre-deassociate one of the foreign key constrained values. The ``post_update`` flag described at :ref:`post_update` can resolve this cycle. - * In a :meth:`.MetaData.create_all`, :meth:`.MetaData.drop_all`, - :attr:`.MetaData.sorted_tables` operation, two :class:`.ForeignKey` + * In a :attr:`.MetaData.sorted_tables` operation, two :class:`.ForeignKey` or :class:`.ForeignKeyConstraint` objects mutually refer to each other. Apply the ``use_alter=True`` flag to one or both, see :ref:`use_alter`. @@ -63,7 +62,7 @@ class CircularDependencyError(SQLAlchemyError): """ def __init__(self, message, cycles, edges, msg=None): if msg is None: - message += " Cycles: %r all edges: %r" % (cycles, edges) + message += " (%s)" % ", ".join(repr(s) for s in cycles) else: message = msg SQLAlchemyError.__init__(self, message) @@ -238,14 +237,16 @@ class StatementError(SQLAlchemyError): def __str__(self): from sqlalchemy.sql import util - params_repr = util._repr_params(self.params, 10) + details = [SQLAlchemyError.__str__(self)] + if self.statement: + details.append("[SQL: %r]" % self.statement) + if self.params: + params_repr = util._repr_params(self.params, 10) + details.append("[parameters: %r]" % params_repr) return ' '.join([ "(%s)" % det for det in self.detail - ] + [ - SQLAlchemyError.__str__(self), - repr(self.statement), repr(params_repr) - ]) + ] + details) def __unicode__(self): return self.__str__() @@ -289,10 +290,10 @@ class DBAPIError(StatementError): # not a DBAPI error, statement is present. # raise a StatementError if not isinstance(orig, dbapi_base_err) and statement: - msg = traceback.format_exception_only( - orig.__class__, orig)[-1].strip() return StatementError( - "%s (original cause: %s)" % (str(orig), msg), + "(%s.%s) %s" % + (orig.__class__.__module__, orig.__class__.__name__, + orig), statement, params, orig ) @@ -316,7 +317,8 @@ class DBAPIError(StatementError): text = 'Error in str() of DB-API-generated exception: ' + str(e) StatementError.__init__( self, - '(%s) %s' % (orig.__class__.__name__, text), + '(%s.%s) %s' % ( + orig.__class__.__module__, orig.__class__.__name__, text, ), statement, params, orig diff --git a/lib/sqlalchemy/ext/__init__.py b/lib/sqlalchemy/ext/__init__.py index d213a0d30..60a17c65e 100644 --- a/lib/sqlalchemy/ext/__init__.py +++ b/lib/sqlalchemy/ext/__init__.py @@ -1,6 +1,11 @@ # ext/__init__.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php + +from .. 
import util as _sa_util + +_sa_util.dependencies.resolve_all("sqlalchemy.ext") + diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 1aa68ac32..a74141973 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -1,5 +1,5 @@ # ext/associationproxy.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -86,7 +86,7 @@ ASSOCIATION_PROXY = util.symbol('ASSOCIATION_PROXY') """ -class AssociationProxy(interfaces.InspectionAttr): +class AssociationProxy(interfaces.InspectionAttrInfo): """A descriptor that presents a read/write view of an object attribute.""" is_attribute = False @@ -527,7 +527,10 @@ class _AssociationList(_AssociationCollection): return self.setter(object, value) def __getitem__(self, index): - return self._get(self.col[index]) + if not isinstance(index, slice): + return self._get(self.col[index]) + else: + return [self._get(member) for member in self.col[index]] def __setitem__(self, index, value): if not isinstance(index, slice): diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py index c11795d37..ca550ded6 100644 --- a/lib/sqlalchemy/ext/automap.py +++ b/lib/sqlalchemy/ext/automap.py @@ -1,5 +1,5 @@ # ext/automap.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py new file mode 100644 index 000000000..65d6a8603 --- /dev/null +++ b/lib/sqlalchemy/ext/baked.py @@ -0,0 +1,499 @@ +# sqlalchemy/ext/baked.py +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Baked query extension. + +Provides a creational pattern for the :class:`.query.Query` object which +allows the fully constructed object, Core select statement, and string +compiled result to be fully cached. + + +""" + +from ..orm.query import Query +from ..orm import strategies, attributes, properties, \ + strategy_options, util as orm_util, interfaces +from .. import log as sqla_log +from ..sql import util as sql_util +from ..orm import exc as orm_exc +from .. import exc as sa_exc +from .. 
import util + +import copy +import logging + +log = logging.getLogger(__name__) + + +class BakedQuery(object): + """A builder object for :class:`.query.Query` objects.""" + + __slots__ = 'steps', '_bakery', '_cache_key', '_spoiled' + + def __init__(self, bakery, initial_fn, args=()): + if args: + self._cache_key = tuple(args) + else: + self._cache_key = () + self._update_cache_key(initial_fn) + self.steps = [initial_fn] + self._spoiled = False + self._bakery = bakery + + @classmethod + def bakery(cls, size=200): + """Construct a new bakery.""" + + _bakery = util.LRUCache(size) + + def call(initial_fn): + return cls(_bakery, initial_fn) + + return call + + def _clone(self): + b1 = BakedQuery.__new__(BakedQuery) + b1._cache_key = self._cache_key + b1.steps = list(self.steps) + b1._bakery = self._bakery + b1._spoiled = self._spoiled + return b1 + + def _update_cache_key(self, fn, args=()): + self._cache_key += (fn.__code__,) + args + + def __iadd__(self, other): + if isinstance(other, tuple): + self.add_criteria(*other) + else: + self.add_criteria(other) + return self + + def __add__(self, other): + if isinstance(other, tuple): + return self.with_criteria(*other) + else: + return self.with_criteria(other) + + def add_criteria(self, fn, *args): + """Add a criteria function to this :class:`.BakedQuery`. + + This is equivalent to using the ``+=`` operator to + modify a :class:`.BakedQuery` in-place. + + """ + self._update_cache_key(fn, args) + self.steps.append(fn) + return self + + def with_criteria(self, fn, *args): + """Add a criteria function to a :class:`.BakedQuery` cloned from this one. + + This is equivalent to using the ``+`` operator to + produce a new :class:`.BakedQuery` with modifications. + + """ + return self._clone().add_criteria(fn, *args) + + def for_session(self, session): + """Return a :class:`.Result` object for this :class:`.BakedQuery`. + + This is equivalent to calling the :class:`.BakedQuery` as a + Python callable, e.g. ``result = my_baked_query(session)``. + + """ + return Result(self, session) + + def __call__(self, session): + return self.for_session(session) + + def spoil(self, full=False): + """Cancel any query caching that will occur on this BakedQuery object. + + The BakedQuery can continue to be used normally, however additional + creational functions will not be cached; they will be called + on every invocation. + + This is to support the case where a particular step in constructing + a baked query disqualifies the query from being cacheable, such + as a variant that relies upon some uncacheable value. + + :param full: if False, only functions added to this + :class:`.BakedQuery` object subsequent to the spoil step will be + non-cached; the state of the :class:`.BakedQuery` up until + this point will be pulled from the cache. If True, then the + entire :class:`.Query` object is built from scratch each + time, with all creational functions being called on each + invocation. 
+ + """ + if not full: + _spoil_point = self._clone() + _spoil_point._cache_key += ('_query_only', ) + self.steps = [_spoil_point._retrieve_baked_query] + self._spoiled = True + return self + + def _retrieve_baked_query(self, session): + query = self._bakery.get(self._cache_key, None) + if query is None: + query = self._as_query(session) + self._bakery[self._cache_key] = query.with_session(None) + return query.with_session(session) + + def _bake(self, session): + query = self._as_query(session) + + context = query._compile_context() + self._bake_subquery_loaders(session, context) + context.session = None + context.query = query = context.query.with_session(None) + query._execution_options = query._execution_options.union( + {"compiled_cache": self._bakery} + ) + # we'll be holding onto the query for some of its state, + # so delete some compilation-use-only attributes that can take up + # space + for attr in ( + '_correlate', '_from_obj', '_mapper_adapter_map', + '_joinpath', '_joinpoint'): + query.__dict__.pop(attr, None) + self._bakery[self._cache_key] = context + return context + + def _as_query(self, session): + query = self.steps[0](session) + + for step in self.steps[1:]: + query = step(query) + return query + + def _bake_subquery_loaders(self, session, context): + """convert subquery eager loaders in the cache into baked queries. + + For subquery eager loading to work, all we need here is that the + Query point to the correct session when it is run. However, since + we are "baking" anyway, we may as well also turn the query into + a "baked" query so that we save on performance too. + + """ + context.attributes['baked_queries'] = baked_queries = [] + for k, v in list(context.attributes.items()): + if isinstance(v, Query): + if 'subquery' in k: + bk = BakedQuery(self._bakery, lambda *args: v) + bk._cache_key = self._cache_key + k + bk._bake(session) + baked_queries.append((k, bk._cache_key, v)) + del context.attributes[k] + + def _unbake_subquery_loaders(self, session, context, params): + """Retrieve subquery eager loaders stored by _bake_subquery_loaders + and turn them back into Result objects that will iterate just + like a Query object. + + """ + for k, cache_key, query in context.attributes["baked_queries"]: + bk = BakedQuery(self._bakery, lambda sess: query.with_session(sess)) + bk._cache_key = cache_key + context.attributes[k] = bk.for_session(session).params(**params) + + +class Result(object): + """Invokes a :class:`.BakedQuery` against a :class:`.Session`. + + The :class:`.Result` object is where the actual :class:`.query.Query` + object gets created, or retrieved from the cache, + against a target :class:`.Session`, and is then invoked for results. 
+ + """ + __slots__ = 'bq', 'session', '_params' + + def __init__(self, bq, session): + self.bq = bq + self.session = session + self._params = {} + + def params(self, *args, **kw): + """Specify parameters to be replaced into the string SQL statement.""" + + if len(args) == 1: + kw.update(args[0]) + elif len(args) > 0: + raise sa_exc.ArgumentError( + "params() takes zero or one positional argument, " + "which is a dictionary.") + self._params.update(kw) + return self + + def _as_query(self): + return self.bq._as_query(self.session).params(self._params) + + def __str__(self): + return str(self._as_query()) + + def __iter__(self): + bq = self.bq + if bq._spoiled: + return iter(self._as_query()) + + baked_context = bq._bakery.get(bq._cache_key, None) + if baked_context is None: + baked_context = bq._bake(self.session) + + context = copy.copy(baked_context) + context.session = self.session + context.attributes = context.attributes.copy() + + bq._unbake_subquery_loaders(self.session, context, self._params) + + context.statement.use_labels = True + if context.autoflush and not context.populate_existing: + self.session._autoflush() + return context.query.params(self._params).\ + with_session(self.session)._execute_and_instances(context) + + def first(self): + """Return the first row. + + Equivalent to :meth:`.Query.first`. + + """ + bq = self.bq.with_criteria(lambda q: q.slice(0, 1)) + ret = list(bq.for_session(self.session).params(self._params)) + if len(ret) > 0: + return ret[0] + else: + return None + + def one(self): + """Return exactly one result or raise an exception. + + Equivalent to :meth:`.Query.one`. + + """ + ret = list(self) + + l = len(ret) + if l == 1: + return ret[0] + elif l == 0: + raise orm_exc.NoResultFound("No row was found for one()") + else: + raise orm_exc.MultipleResultsFound( + "Multiple rows were found for one()") + + def all(self): + """Return all rows. + + Equivalent to :meth:`.Query.all`. + + """ + return list(self) + + def get(self, ident): + """Retrieve an object based on identity. + + Equivalent to :meth:`.Query.get`. + + """ + + query = self.bq.steps[0](self.session) + return query._get_impl(ident, self._load_on_ident) + + def _load_on_ident(self, query, key): + """Load the given identity key from the database.""" + + ident = key[1] + + mapper = query._mapper_zero() + + _get_clause, _get_params = mapper._get_clause + + def setup(query): + _lcl_get_clause = _get_clause + q = query._clone() + q._get_condition() + q._order_by = None + + # None present in ident - turn those comparisons + # into "IS NULL" + if None in ident: + nones = set([ + _get_params[col].key for col, value in + zip(mapper.primary_key, ident) if value is None + ]) + _lcl_get_clause = sql_util.adapt_criterion_to_null( + _lcl_get_clause, nones) + + _lcl_get_clause = q._adapt_clause(_lcl_get_clause, True, False) + q._criterion = _lcl_get_clause + return q + + # cache the query against a key that includes + # which positions in the primary key are NULL + # (remember, we can map to an OUTER JOIN) + bq = self.bq + + bq = bq.with_criteria(setup, tuple(elem is None for elem in ident)) + + params = dict([ + (_get_params[primary_key].key, id_val) + for id_val, primary_key in zip(ident, mapper.primary_key) + ]) + + result = list(bq.for_session(self.session).params(**params)) + l = len(result) + if l > 1: + raise orm_exc.MultipleResultsFound() + elif l: + return result[0] + else: + return None + + +def bake_lazy_loaders(): + """Enable the use of baked queries for all lazyloaders systemwide. 
+ + This operation should be safe for all lazy loaders, and will reduce + Python overhead for these operations. + + """ + strategies.LazyLoader._strategy_keys[:] = [] + BakedLazyLoader._strategy_keys[:] = [] + + properties.RelationshipProperty.strategy_for( + lazy="select")(BakedLazyLoader) + properties.RelationshipProperty.strategy_for( + lazy=True)(BakedLazyLoader) + properties.RelationshipProperty.strategy_for( + lazy="baked_select")(BakedLazyLoader) + + +def unbake_lazy_loaders(): + """Disable the use of baked queries for all lazyloaders systemwide. + + This operation reverts the changes produced by :func:`.bake_lazy_loaders`. + + """ + strategies.LazyLoader._strategy_keys[:] = [] + BakedLazyLoader._strategy_keys[:] = [] + + properties.RelationshipProperty.strategy_for( + lazy="select")(strategies.LazyLoader) + properties.RelationshipProperty.strategy_for( + lazy=True)(strategies.LazyLoader) + properties.RelationshipProperty.strategy_for( + lazy="baked_select")(BakedLazyLoader) + assert strategies.LazyLoader._strategy_keys + + +@sqla_log.class_logger +@properties.RelationshipProperty.strategy_for(lazy="baked_select") +class BakedLazyLoader(strategies.LazyLoader): + + def _emit_lazyload(self, session, state, ident_key, passive): + q = BakedQuery( + self.mapper._compiled_cache, + lambda session: session.query(self.mapper)) + q.add_criteria( + lambda q: q._adapt_all_clauses()._with_invoke_all_eagers(False), + self.parent_property) + + if not self.parent_property.bake_queries: + q.spoil(full=True) + + if self.parent_property.secondary is not None: + q.add_criteria( + lambda q: + q.select_from(self.mapper, self.parent_property.secondary)) + + pending = not state.key + + # don't autoflush on pending + if pending or passive & attributes.NO_AUTOFLUSH: + q.add_criteria(lambda q: q.autoflush(False)) + + if state.load_path: + q.spoil() + q.add_criteria( + lambda q: + q._with_current_path(state.load_path[self.parent_property])) + + if state.load_options: + q.spoil() + q.add_criteria( + lambda q: q._conditional_options(*state.load_options)) + + if self.use_get: + return q(session)._load_on_ident( + session.query(self.mapper), ident_key) + + if self.parent_property.order_by: + q.add_criteria( + lambda q: + q.order_by(*util.to_list(self.parent_property.order_by))) + + for rev in self.parent_property._reverse_property: + # reverse props that are MANYTOONE are loading *this* + # object from get(), so don't need to eager out to those. + if rev.direction is interfaces.MANYTOONE and \ + rev._use_get and \ + not isinstance(rev.strategy, strategies.LazyLoader): + q.add_criteria( + lambda q: + q.options( + strategy_options.Load( + rev.parent).baked_lazyload(rev.key))) + + lazy_clause, params = self._generate_lazy_clause(state, passive) + + if pending: + if orm_util._none_set.intersection(params.values()): + return None + + q.add_criteria(lambda q: q.filter(lazy_clause)) + result = q(session).params(**params).all() + if self.uselist: + return result + else: + l = len(result) + if l: + if l > 1: + util.warn( + "Multiple rows returned with " + "uselist=False for lazily-loaded attribute '%s' " + % self.parent_property) + + return result[0] + else: + return None + + +@strategy_options.loader_option() +def baked_lazyload(loadopt, attr): + """Indicate that the given attribute should be loaded using "lazy" + loading with a "baked" query used in the load. 
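+
+    A usage sketch, assuming ``User.addresses`` is a mapped
+    relationship and this module is importable as
+    ``sqlalchemy.ext.baked``::
+
+        from sqlalchemy.ext.baked import baked_lazyload
+
+        session.query(User).options(baked_lazyload(User.addresses))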
+ + """ + return loadopt.set_relationship_strategy(attr, {"lazy": "baked_select"}) + + +@baked_lazyload._add_unbound_fn +def baked_lazyload(*keys): + return strategy_options._UnboundLoad._from_keys( + strategy_options._UnboundLoad.baked_lazyload, keys, False, {}) + + +@baked_lazyload._add_unbound_all_fn +def baked_lazyload_all(*keys): + return strategy_options._UnboundLoad._from_keys( + strategy_options._UnboundLoad.baked_lazyload, keys, True, {}) + +baked_lazyload = baked_lazyload._unbound_fn +baked_lazyload_all = baked_lazyload_all._unbound_all_fn + +bakery = BakedQuery.bakery diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 8d169aa57..9717e41c0 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -1,5 +1,5 @@ # ext/compiler.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/declarative/__init__.py b/lib/sqlalchemy/ext/declarative/__init__.py index 2b611252a..f703000bb 100644 --- a/lib/sqlalchemy/ext/declarative/__init__.py +++ b/lib/sqlalchemy/ext/declarative/__init__.py @@ -1,1381 +1,10 @@ # ext/declarative/__init__.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -""" -Synopsis -======== - -SQLAlchemy object-relational configuration involves the -combination of :class:`.Table`, :func:`.mapper`, and class -objects to define a mapped class. -:mod:`~sqlalchemy.ext.declarative` allows all three to be -expressed at once within the class declaration. As much as -possible, regular SQLAlchemy schema and ORM constructs are -used directly, so that configuration between "classical" ORM -usage and declarative remain highly similar. - -As a simple example:: - - from sqlalchemy.ext.declarative import declarative_base - - Base = declarative_base() - - class SomeClass(Base): - __tablename__ = 'some_table' - id = Column(Integer, primary_key=True) - name = Column(String(50)) - -Above, the :func:`declarative_base` callable returns a new base class from -which all mapped classes should inherit. When the class definition is -completed, a new :class:`.Table` and :func:`.mapper` will have been generated. - -The resulting table and mapper are accessible via -``__table__`` and ``__mapper__`` attributes on the -``SomeClass`` class:: - - # access the mapped Table - SomeClass.__table__ - - # access the Mapper - SomeClass.__mapper__ - -Defining Attributes -=================== - -In the previous example, the :class:`.Column` objects are -automatically named with the name of the attribute to which they are -assigned. - -To name columns explicitly with a name distinct from their mapped attribute, -just give the column a name. 
Below, column "some_table_id" is mapped to the -"id" attribute of `SomeClass`, but in SQL will be represented as -"some_table_id":: - - class SomeClass(Base): - __tablename__ = 'some_table' - id = Column("some_table_id", Integer, primary_key=True) - -Attributes may be added to the class after its construction, and they will be -added to the underlying :class:`.Table` and -:func:`.mapper` definitions as appropriate:: - - SomeClass.data = Column('data', Unicode) - SomeClass.related = relationship(RelatedInfo) - -Classes which are constructed using declarative can interact freely -with classes that are mapped explicitly with :func:`.mapper`. - -It is recommended, though not required, that all tables -share the same underlying :class:`~sqlalchemy.schema.MetaData` object, -so that string-configured :class:`~sqlalchemy.schema.ForeignKey` -references can be resolved without issue. - -Accessing the MetaData -======================= - -The :func:`declarative_base` base class contains a -:class:`.MetaData` object where newly defined -:class:`.Table` objects are collected. This object is -intended to be accessed directly for -:class:`.MetaData`-specific operations. Such as, to issue -CREATE statements for all tables:: - - engine = create_engine('sqlite://') - Base.metadata.create_all(engine) - -:func:`declarative_base` can also receive a pre-existing -:class:`.MetaData` object, which allows a -declarative setup to be associated with an already -existing traditional collection of :class:`~sqlalchemy.schema.Table` -objects:: - - mymetadata = MetaData() - Base = declarative_base(metadata=mymetadata) - - -.. _declarative_configuring_relationships: - -Configuring Relationships -========================= - -Relationships to other classes are done in the usual way, with the added -feature that the class specified to :func:`~sqlalchemy.orm.relationship` -may be a string name. The "class registry" associated with ``Base`` -is used at mapper compilation time to resolve the name into the actual -class object, which is expected to have been defined once the mapper -configuration is used:: - - class User(Base): - __tablename__ = 'users' - - id = Column(Integer, primary_key=True) - name = Column(String(50)) - addresses = relationship("Address", backref="user") - - class Address(Base): - __tablename__ = 'addresses' - - id = Column(Integer, primary_key=True) - email = Column(String(50)) - user_id = Column(Integer, ForeignKey('users.id')) - -Column constructs, since they are just that, are immediately usable, -as below where we define a primary join condition on the ``Address`` -class using them:: - - class Address(Base): - __tablename__ = 'addresses' - - id = Column(Integer, primary_key=True) - email = Column(String(50)) - user_id = Column(Integer, ForeignKey('users.id')) - user = relationship(User, primaryjoin=user_id == User.id) - -In addition to the main argument for :func:`~sqlalchemy.orm.relationship`, -other arguments which depend upon the columns present on an as-yet -undefined class may also be specified as strings. These strings are -evaluated as Python expressions. The full namespace available within -this evaluation includes all classes mapped for this declarative base, -as well as the contents of the ``sqlalchemy`` package, including -expression functions like :func:`~sqlalchemy.sql.expression.desc` and -:attr:`~sqlalchemy.sql.expression.func`:: - - class User(Base): - # .... 
-        addresses = relationship("Address",
-                                 order_by="desc(Address.email)",
-                                 primaryjoin="Address.user_id==User.id")
-
-For the case where more than one module contains a class of the same name,
-string class names can also be specified as module-qualified paths
-within any of these string expressions::
-
-    class User(Base):
-        # ....
-        addresses = relationship("myapp.model.address.Address",
-                                 order_by="desc(myapp.model.address.Address.email)",
-                                 primaryjoin="myapp.model.address.Address.user_id=="
-                                             "myapp.model.user.User.id")
-
-The qualified path can be any partial path that removes ambiguity between
-the names.  For example, to disambiguate between
-``myapp.model.address.Address`` and ``myapp.model.lookup.Address``,
-we can specify ``address.Address`` or ``lookup.Address``::
-
-    class User(Base):
-        # ....
-        addresses = relationship("address.Address",
-                                 order_by="desc(address.Address.email)",
-                                 primaryjoin="address.Address.user_id=="
-                                             "User.id")
-
-.. versionadded:: 0.8
-   module-qualified paths can be used when specifying string arguments
-   with Declarative, in order to specify specific modules.
-
-Two alternatives also exist to using string-based attributes.  A lambda
-can also be used, which will be evaluated after all mappers have been
-configured::
-
-    class User(Base):
-        # ...
-        addresses = relationship(lambda: Address,
-                                 order_by=lambda: desc(Address.email),
-                                 primaryjoin=lambda: Address.user_id==User.id)
-
-Or, the relationship can be added to the class explicitly after the classes
-are available::
-
-    User.addresses = relationship(Address,
-                                  primaryjoin=Address.user_id==User.id)
-
-
-
-.. _declarative_many_to_many:
-
-Configuring Many-to-Many Relationships
-======================================
-
-Many-to-many relationships are also declared in the same way
-with declarative as with traditional mappings.  The
-``secondary`` argument to
-:func:`.relationship` is as usual passed a
-:class:`.Table` object, which is typically declared in the
-traditional way.  The :class:`.Table` usually shares
-the :class:`.MetaData` object used by the declarative base::
-
-    author_keywords = Table(
-        'author_keywords', Base.metadata,
-        Column('author_id', Integer, ForeignKey('authors.id')),
-        Column('keyword_id', Integer, ForeignKey('keywords.id'))
-    )
-
-    class Author(Base):
-        __tablename__ = 'authors'
-        id = Column(Integer, primary_key=True)
-        keywords = relationship("Keyword", secondary=author_keywords)
-
-Like other :func:`~sqlalchemy.orm.relationship` arguments, a string is
-accepted as well, passing the string name of the table as defined in the
-``Base.metadata.tables`` collection::
-
-    class Author(Base):
-        __tablename__ = 'authors'
-        id = Column(Integer, primary_key=True)
-        keywords = relationship("Keyword", secondary="author_keywords")
-
-As with traditional mapping, it's generally not a good idea to use
-a :class:`.Table` as the "secondary" argument which is also mapped to
-a class, unless the :func:`.relationship` is declared with ``viewonly=True``.
-Otherwise, the unit-of-work system may attempt duplicate INSERT and
-DELETE statements against the underlying table.
-
-.. _declarative_sql_expressions:
-
-Defining SQL Expressions
-========================
-
-See :ref:`mapper_sql_expressions` for examples on declaratively
-mapping attributes to SQL expressions.
-
-.. _declarative_table_args:
-
-Table Configuration
-===================
-
-Table arguments other than the name, metadata, and mapped Column
-arguments are specified using the ``__table_args__`` class attribute.
-This attribute accommodates both positional as well as keyword
-arguments that are normally sent to the
-:class:`~sqlalchemy.schema.Table` constructor.
-The attribute can be specified in one of two forms.  One is as a
-dictionary::
-
-    class MyClass(Base):
-        __tablename__ = 'sometable'
-        __table_args__ = {'mysql_engine': 'InnoDB'}
-
-The other is as a tuple, where each argument is positional
-(usually constraints)::
-
-    class MyClass(Base):
-        __tablename__ = 'sometable'
-        __table_args__ = (
-            ForeignKeyConstraint(['id'], ['remote_table.id']),
-            UniqueConstraint('foo'),
-        )
-
-Keyword arguments can be specified with the above form by
-specifying the last argument as a dictionary::
-
-    class MyClass(Base):
-        __tablename__ = 'sometable'
-        __table_args__ = (
-            ForeignKeyConstraint(['id'], ['remote_table.id']),
-            UniqueConstraint('foo'),
-            {'autoload': True}
-        )
-
-Using a Hybrid Approach with __table__
-=======================================
-
-As an alternative to ``__tablename__``, a direct
-:class:`~sqlalchemy.schema.Table` construct may be used.  The
-:class:`~sqlalchemy.schema.Column` objects, which in this case require
-their names, will be added to the mapping just like a regular mapping
-to a table::
-
-    class MyClass(Base):
-        __table__ = Table('my_table', Base.metadata,
-            Column('id', Integer, primary_key=True),
-            Column('name', String(50))
-        )
-
-``__table__`` provides a more focused point of control for establishing
-table metadata, while still getting most of the benefits of using declarative.
-An application that uses reflection might want to load table metadata elsewhere
-and pass it to declarative classes::
-
-    from sqlalchemy.ext.declarative import declarative_base
-
-    Base = declarative_base()
-    Base.metadata.reflect(some_engine)
-
-    class User(Base):
-        __table__ = Base.metadata.tables['user']
-
-    class Address(Base):
-        __table__ = Base.metadata.tables['address']
-
-Some configuration schemes may find it more appropriate to use ``__table__``,
-such as those which already take advantage of the data-driven nature of
-:class:`.Table` to customize and/or automate schema definition.
-
-Note that when the ``__table__`` approach is used, the object is immediately
-usable as a plain :class:`.Table` within the class declaration body itself,
-as a Python class is only another syntactical block.  
Below this is illustrated
-by using the ``id`` column in the ``primaryjoin`` condition of a
-:func:`.relationship`::
-
-    class MyClass(Base):
-        __table__ = Table('my_table', Base.metadata,
-            Column('id', Integer, primary_key=True),
-            Column('name', String(50))
-        )
-
-        widgets = relationship(Widget,
-                               primaryjoin=Widget.myclass_id==__table__.c.id)
-
-Similarly, mapped attributes which refer to ``__table__`` can be placed inline,
-as below where we assign the ``name`` column to the attribute ``_name``,
-generating a synonym for ``name``::
-
-    from sqlalchemy.ext.declarative import synonym_for
-
-    class MyClass(Base):
-        __table__ = Table('my_table', Base.metadata,
-            Column('id', Integer, primary_key=True),
-            Column('name', String(50))
-        )
-
-        _name = __table__.c.name
-
-        @synonym_for("_name")
-        def name(self):
-            return "Name: %s" % self._name
-
-Using Reflection with Declarative
-=================================
-
-It's easy to set up a :class:`.Table` that uses ``autoload=True``
-in conjunction with a mapped class::
-
-    class MyClass(Base):
-        __table__ = Table('mytable', Base.metadata,
-                          autoload=True, autoload_with=some_engine)
-
-However, one improvement that can be made here is to not
-require the :class:`.Engine` to be available when classes are
-being first declared.  To achieve this, use the
-:class:`.DeferredReflection` mixin, which sets up mappings
-only after a special ``prepare(engine)`` step is called::
-
-    from sqlalchemy.ext.declarative import declarative_base, DeferredReflection
-
-    Base = declarative_base(cls=DeferredReflection)
-
-    class Foo(Base):
-        __tablename__ = 'foo'
-        bars = relationship("Bar")
-
-    class Bar(Base):
-        __tablename__ = 'bar'
-
-        # illustrate overriding of "bar.foo_id" to have
-        # a foreign key constraint otherwise not
-        # reflected, such as when using MySQL
-        foo_id = Column(Integer, ForeignKey('foo.id'))
-
-    Base.prepare(some_engine)
-
-.. versionadded:: 0.8
-   Added :class:`.DeferredReflection`.
-
-Mapper Configuration
-====================
-
-Declarative makes use of the :func:`~.orm.mapper` function internally
-when it creates the mapping to the declared table.  The options
-for :func:`~.orm.mapper` are passed directly through via the
-``__mapper_args__`` class attribute.  As always, arguments which reference
-locally mapped columns can reference them directly from within the
-class declaration::
-
-    from datetime import datetime
-
-    class Widget(Base):
-        __tablename__ = 'widgets'
-
-        id = Column(Integer, primary_key=True)
-        timestamp = Column(DateTime, nullable=False)
-
-        __mapper_args__ = {
-            'version_id_col': timestamp,
-            'version_id_generator': lambda v: datetime.now()
-        }
-
-.. _declarative_inheritance:
-
-Inheritance Configuration
-=========================
-
-Declarative supports all three forms of inheritance as intuitively
-as possible.  The ``inherits`` mapper keyword argument is not needed
-as declarative will determine this from the class itself.  The various
-"polymorphic" keyword arguments are specified using ``__mapper_args__``.
- -Joined Table Inheritance -~~~~~~~~~~~~~~~~~~~~~~~~ - -Joined table inheritance is defined as a subclass that defines its own -table:: - - class Person(Base): - __tablename__ = 'people' - id = Column(Integer, primary_key=True) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} - - class Engineer(Person): - __tablename__ = 'engineers' - __mapper_args__ = {'polymorphic_identity': 'engineer'} - id = Column(Integer, ForeignKey('people.id'), primary_key=True) - primary_language = Column(String(50)) - -Note that above, the ``Engineer.id`` attribute, since it shares the -same attribute name as the ``Person.id`` attribute, will in fact -represent the ``people.id`` and ``engineers.id`` columns together, -with the "Engineer.id" column taking precedence if queried directly. -To provide the ``Engineer`` class with an attribute that represents -only the ``engineers.id`` column, give it a different attribute name:: - - class Engineer(Person): - __tablename__ = 'engineers' - __mapper_args__ = {'polymorphic_identity': 'engineer'} - engineer_id = Column('id', Integer, ForeignKey('people.id'), - primary_key=True) - primary_language = Column(String(50)) - - -.. versionchanged:: 0.7 joined table inheritance favors the subclass - column over that of the superclass, such as querying above - for ``Engineer.id``. Prior to 0.7 this was the reverse. - -.. _declarative_single_table: - -Single Table Inheritance -~~~~~~~~~~~~~~~~~~~~~~~~ - -Single table inheritance is defined as a subclass that does not have -its own table; you just leave out the ``__table__`` and ``__tablename__`` -attributes:: - - class Person(Base): - __tablename__ = 'people' - id = Column(Integer, primary_key=True) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} - - class Engineer(Person): - __mapper_args__ = {'polymorphic_identity': 'engineer'} - primary_language = Column(String(50)) - -When the above mappers are configured, the ``Person`` class is mapped -to the ``people`` table *before* the ``primary_language`` column is -defined, and this column will not be included in its own mapping. -When ``Engineer`` then defines the ``primary_language`` column, the -column is added to the ``people`` table so that it is included in the -mapping for ``Engineer`` and is also part of the table's full set of -columns. Columns which are not mapped to ``Person`` are also excluded -from any other single or joined inheriting classes using the -``exclude_properties`` mapper argument. Below, ``Manager`` will have -all the attributes of ``Person`` and ``Manager`` but *not* the -``primary_language`` attribute of ``Engineer``:: - - class Manager(Person): - __mapper_args__ = {'polymorphic_identity': 'manager'} - golf_swing = Column(String(50)) - -The attribute exclusion logic is provided by the -``exclude_properties`` mapper argument, and declarative's default -behavior can be disabled by passing an explicit ``exclude_properties`` -collection (empty or otherwise) to the ``__mapper_args__``. - -Resolving Column Conflicts -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Note above that the ``primary_language`` and ``golf_swing`` columns -are "moved up" to be applied to ``Person.__table__``, as a result of their -declaration on a subclass that has no table of its own. 
A tricky case -comes up when two subclasses want to specify *the same* column, as below:: - - class Person(Base): - __tablename__ = 'people' - id = Column(Integer, primary_key=True) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} - - class Engineer(Person): - __mapper_args__ = {'polymorphic_identity': 'engineer'} - start_date = Column(DateTime) - - class Manager(Person): - __mapper_args__ = {'polymorphic_identity': 'manager'} - start_date = Column(DateTime) - -Above, the ``start_date`` column declared on both ``Engineer`` and ``Manager`` -will result in an error:: - - sqlalchemy.exc.ArgumentError: Column 'start_date' on class - <class '__main__.Manager'> conflicts with existing - column 'people.start_date' - -In a situation like this, Declarative can't be sure -of the intent, especially if the ``start_date`` columns had, for example, -different types. A situation like this can be resolved by using -:class:`.declared_attr` to define the :class:`.Column` conditionally, taking -care to return the **existing column** via the parent ``__table__`` if it -already exists:: - - from sqlalchemy.ext.declarative import declared_attr - - class Person(Base): - __tablename__ = 'people' - id = Column(Integer, primary_key=True) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} - - class Engineer(Person): - __mapper_args__ = {'polymorphic_identity': 'engineer'} - - @declared_attr - def start_date(cls): - "Start date column, if not present already." - return Person.__table__.c.get('start_date', Column(DateTime)) - - class Manager(Person): - __mapper_args__ = {'polymorphic_identity': 'manager'} - - @declared_attr - def start_date(cls): - "Start date column, if not present already." - return Person.__table__.c.get('start_date', Column(DateTime)) - -Above, when ``Manager`` is mapped, the ``start_date`` column is -already present on the ``Person`` class. Declarative lets us return -that :class:`.Column` as a result in this case, where it knows to skip -re-assigning the same column. If the mapping is mis-configured such -that the ``start_date`` column is accidentally re-assigned to a -different table (such as, if we changed ``Manager`` to be joined -inheritance without fixing ``start_date``), an error is raised which -indicates an existing :class:`.Column` is trying to be re-assigned to -a different owning :class:`.Table`. - -.. versionadded:: 0.8 :class:`.declared_attr` can be used on a non-mixin - class, and the returned :class:`.Column` or other mapped attribute - will be applied to the mapping as any other attribute. Previously, - the resulting attribute would be ignored, and also result in a warning - being emitted when a subclass was created. - -.. versionadded:: 0.8 :class:`.declared_attr`, when used either with a - mixin or non-mixin declarative class, can return an existing - :class:`.Column` already assigned to the parent :class:`.Table`, - to indicate that the re-assignment of the :class:`.Column` should be - skipped, however should still be mapped on the target class, - in order to resolve duplicate column conflicts. 
-
-The same concept can be used with mixin classes (see
-:ref:`declarative_mixins`)::
-
-    class Person(Base):
-        __tablename__ = 'people'
-        id = Column(Integer, primary_key=True)
-        discriminator = Column('type', String(50))
-        __mapper_args__ = {'polymorphic_on': discriminator}
-
-    class HasStartDate(object):
-        @declared_attr
-        def start_date(cls):
-            return cls.__table__.c.get('start_date', Column(DateTime))
-
-    class Engineer(HasStartDate, Person):
-        __mapper_args__ = {'polymorphic_identity': 'engineer'}
-
-    class Manager(HasStartDate, Person):
-        __mapper_args__ = {'polymorphic_identity': 'manager'}
-
-The above mixin checks the local ``__table__`` attribute for the column.
-Because we're using single table inheritance, we're sure that in this case,
-``cls.__table__`` refers to ``Person.__table__``.  If we were mixing joined-
-and single-table inheritance, we might want our mixin to check more carefully
-if ``cls.__table__`` is really the :class:`.Table` we're looking for.
-
-Concrete Table Inheritance
-~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-Concrete table inheritance is defined as a subclass which has its own table
-and sets the ``concrete`` keyword argument to ``True``::
-
-    class Person(Base):
-        __tablename__ = 'people'
-        id = Column(Integer, primary_key=True)
-        name = Column(String(50))
-
-    class Engineer(Person):
-        __tablename__ = 'engineers'
-        __mapper_args__ = {'concrete': True}
-        id = Column(Integer, primary_key=True)
-        primary_language = Column(String(50))
-        name = Column(String(50))
-
-Usage of an abstract base class is a little less straightforward as it
-requires usage of :func:`~sqlalchemy.orm.util.polymorphic_union`,
-which needs to be created with the :class:`.Table` objects
-before the class is built::
-
-    engineers = Table('engineers', Base.metadata,
-        Column('id', Integer, primary_key=True),
-        Column('name', String(50)),
-        Column('primary_language', String(50))
-    )
-    managers = Table('managers', Base.metadata,
-        Column('id', Integer, primary_key=True),
-        Column('name', String(50)),
-        Column('golf_swing', String(50))
-    )
-
-    punion = polymorphic_union({
-        'engineer': engineers,
-        'manager': managers
-    }, 'type', 'punion')
-
-    class Person(Base):
-        __table__ = punion
-        __mapper_args__ = {'polymorphic_on': punion.c.type}
-
-    class Engineer(Person):
-        __table__ = engineers
-        __mapper_args__ = {'polymorphic_identity': 'engineer', 'concrete': True}
-
-    class Manager(Person):
-        __table__ = managers
-        __mapper_args__ = {'polymorphic_identity': 'manager', 'concrete': True}
-
-.. _declarative_concrete_helpers:
-
-Using the Concrete Helpers
-^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Helper classes provide a simpler pattern for concrete inheritance.
-With these objects, the ``__declare_first__`` helper is used to configure the
-"polymorphic" loader for the mapper after all subclasses have been declared.
-
-.. 
versionadded:: 0.7.3
-
-An abstract base can be declared using the
-:class:`.AbstractConcreteBase` class::
-
-    from sqlalchemy.ext.declarative import AbstractConcreteBase
-
-    class Employee(AbstractConcreteBase, Base):
-        pass
-
-To have a concrete ``employee`` table, use :class:`.ConcreteBase` instead::
-
-    from sqlalchemy.ext.declarative import ConcreteBase
-
-    class Employee(ConcreteBase, Base):
-        __tablename__ = 'employee'
-        employee_id = Column(Integer, primary_key=True)
-        name = Column(String(50))
-        __mapper_args__ = {
-            'polymorphic_identity': 'employee',
-            'concrete': True}
-
-
-Either ``Employee`` base can be used in the normal fashion::
-
-    class Manager(Employee):
-        __tablename__ = 'manager'
-        employee_id = Column(Integer, primary_key=True)
-        name = Column(String(50))
-        manager_data = Column(String(40))
-        __mapper_args__ = {
-            'polymorphic_identity': 'manager',
-            'concrete': True}
-
-    class Engineer(Employee):
-        __tablename__ = 'engineer'
-        employee_id = Column(Integer, primary_key=True)
-        name = Column(String(50))
-        engineer_info = Column(String(40))
-        __mapper_args__ = {'polymorphic_identity': 'engineer',
-                           'concrete': True}
-
-
-The :class:`.AbstractConcreteBase` class is itself mapped, and can be
-used as a target of relationships::
-
-    class Company(Base):
-        __tablename__ = 'company'
-
-        id = Column(Integer, primary_key=True)
-        employees = relationship("Employee",
-                                 primaryjoin="Company.id == Employee.company_id")
-
-
-.. versionchanged:: 0.9.3 Support for use of :class:`.AbstractConcreteBase`
-   as the target of a :func:`.relationship` has been improved.
-
-It can also be queried directly::
-
-    for employee in session.query(Employee).filter(Employee.name == 'qbert'):
-        print(employee)
-
-
-.. _declarative_mixins:
-
-Mixin and Custom Base Classes
-==============================
-
-A common need when using :mod:`~sqlalchemy.ext.declarative` is to
-share some functionality, such as a set of common columns, some common
-table options, or other mapped properties, across many
-classes.  The standard Python idiom for this is to have the classes
-inherit from a base which includes these common features.
-
-When using :mod:`~sqlalchemy.ext.declarative`, this idiom is supported
-through the use of a custom declarative base class, as well as a "mixin" class
-which is inherited from in addition to the primary base.  Declarative
-includes several helper features to make this work in terms of how
-mappings are declared.  An example of some commonly mixed-in
-idioms is below::
-
-    from sqlalchemy.ext.declarative import declared_attr
-
-    class MyMixin(object):
-
-        @declared_attr
-        def __tablename__(cls):
-            return cls.__name__.lower()
-
-        __table_args__ = {'mysql_engine': 'InnoDB'}
-        __mapper_args__ = {'always_refresh': True}
-
-        id = Column(Integer, primary_key=True)
-
-    class MyModel(MyMixin, Base):
-        name = Column(String(1000))
-
-Where above, the class ``MyModel`` will contain an "id" column
-as the primary key, a ``__tablename__`` attribute that derives
-from the name of the class itself, as well as ``__table_args__``
-and ``__mapper_args__`` defined by the ``MyMixin`` mixin class.
-
-There's no fixed convention over whether ``MyMixin`` precedes
-``Base`` or not.  Normal Python method resolution rules apply, and
-the above example would work just as well with::
-
-    class MyModel(Base, MyMixin):
-        name = Column(String(1000))
-
-This works because ``Base`` here doesn't define any of the
-variables that ``MyMixin`` defines, i.e. ``__tablename__``,
-``__table_args__``, ``id``, etc. 
If the ``Base`` did define -an attribute of the same name, the class placed first in the -inherits list would determine which attribute is used on the -newly defined class. - -Augmenting the Base -~~~~~~~~~~~~~~~~~~~ - -In addition to using a pure mixin, most of the techniques in this -section can also be applied to the base class itself, for patterns that -should apply to all classes derived from a particular base. This is achieved -using the ``cls`` argument of the :func:`.declarative_base` function:: - - from sqlalchemy.ext.declarative import declared_attr - - class Base(object): - @declared_attr - def __tablename__(cls): - return cls.__name__.lower() - - __table_args__ = {'mysql_engine': 'InnoDB'} - - id = Column(Integer, primary_key=True) - - from sqlalchemy.ext.declarative import declarative_base - - Base = declarative_base(cls=Base) - - class MyModel(Base): - name = Column(String(1000)) - -Where above, ``MyModel`` and all other classes that derive from ``Base`` will -have a table name derived from the class name, an ``id`` primary key column, -as well as the "InnoDB" engine for MySQL. - -Mixing in Columns -~~~~~~~~~~~~~~~~~ - -The most basic way to specify a column on a mixin is by simple -declaration:: - - class TimestampMixin(object): - created_at = Column(DateTime, default=func.now()) - - class MyModel(TimestampMixin, Base): - __tablename__ = 'test' - - id = Column(Integer, primary_key=True) - name = Column(String(1000)) - -Where above, all declarative classes that include ``TimestampMixin`` -will also have a column ``created_at`` that applies a timestamp to -all row insertions. - -Those familiar with the SQLAlchemy expression language know that -the object identity of clause elements defines their role in a schema. -Two ``Table`` objects ``a`` and ``b`` may both have a column called -``id``, but the way these are differentiated is that ``a.c.id`` -and ``b.c.id`` are two distinct Python objects, referencing their -parent tables ``a`` and ``b`` respectively. - -In the case of the mixin column, it seems that only one -:class:`.Column` object is explicitly created, yet the ultimate -``created_at`` column above must exist as a distinct Python object -for each separate destination class. To accomplish this, the declarative -extension creates a **copy** of each :class:`.Column` object encountered on -a class that is detected as a mixin. - -This copy mechanism is limited to simple columns that have no foreign -keys, as a :class:`.ForeignKey` itself contains references to columns -which can't be properly recreated at this level. For columns that -have foreign keys, as well as for the variety of mapper-level constructs -that require destination-explicit context, the -:class:`~.declared_attr` decorator is provided so that -patterns common to many classes can be defined as callables:: - - from sqlalchemy.ext.declarative import declared_attr - - class ReferenceAddressMixin(object): - @declared_attr - def address_id(cls): - return Column(Integer, ForeignKey('address.id')) - - class User(ReferenceAddressMixin, Base): - __tablename__ = 'user' - id = Column(Integer, primary_key=True) - -Where above, the ``address_id`` class-level callable is executed at the -point at which the ``User`` class is constructed, and the declarative -extension can use the resulting :class:`.Column` object as returned by -the method without the need to copy it. - -.. versionchanged:: > 0.6.5 - Rename 0.6.5 ``sqlalchemy.util.classproperty`` - into :class:`~.declared_attr`. 
- -Columns generated by :class:`~.declared_attr` can also be -referenced by ``__mapper_args__`` to a limited degree, currently -by ``polymorphic_on`` and ``version_id_col``; the declarative extension -will resolve them at class construction time:: - - class MyMixin: - @declared_attr - def type_(cls): - return Column(String(50)) - - __mapper_args__= {'polymorphic_on':type_} - - class MyModel(MyMixin, Base): - __tablename__='test' - id = Column(Integer, primary_key=True) - - -Mixing in Relationships -~~~~~~~~~~~~~~~~~~~~~~~ - -Relationships created by :func:`~sqlalchemy.orm.relationship` are provided -with declarative mixin classes exclusively using the -:class:`.declared_attr` approach, eliminating any ambiguity -which could arise when copying a relationship and its possibly column-bound -contents. Below is an example which combines a foreign key column and a -relationship so that two classes ``Foo`` and ``Bar`` can both be configured to -reference a common target class via many-to-one:: - - class RefTargetMixin(object): - @declared_attr - def target_id(cls): - return Column('target_id', ForeignKey('target.id')) - - @declared_attr - def target(cls): - return relationship("Target") - - class Foo(RefTargetMixin, Base): - __tablename__ = 'foo' - id = Column(Integer, primary_key=True) - - class Bar(RefTargetMixin, Base): - __tablename__ = 'bar' - id = Column(Integer, primary_key=True) - - class Target(Base): - __tablename__ = 'target' - id = Column(Integer, primary_key=True) - - -Using Advanced Relationship Arguments (e.g. ``primaryjoin``, etc.) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -:func:`~sqlalchemy.orm.relationship` definitions which require explicit -primaryjoin, order_by etc. expressions should in all but the most -simplistic cases use **late bound** forms -for these arguments, meaning, using either the string form or a lambda. -The reason for this is that the related :class:`.Column` objects which are to -be configured using ``@declared_attr`` are not available to another -``@declared_attr`` attribute; while the methods will work and return new -:class:`.Column` objects, those are not the :class:`.Column` objects that -Declarative will be using as it calls the methods on its own, thus using -*different* :class:`.Column` objects. - -The canonical example is the primaryjoin condition that depends upon -another mixed-in column:: - - class RefTargetMixin(object): - @declared_attr - def target_id(cls): - return Column('target_id', ForeignKey('target.id')) - - @declared_attr - def target(cls): - return relationship(Target, - primaryjoin=Target.id==cls.target_id # this is *incorrect* - ) - -Mapping a class using the above mixin, we will get an error like:: - - sqlalchemy.exc.InvalidRequestError: this ForeignKey's parent column is not - yet associated with a Table. - -This is because the ``target_id`` :class:`.Column` we've called upon in our -``target()`` method is not the same :class:`.Column` that declarative is -actually going to map to our table. 
- -The condition above is resolved using a lambda:: - - class RefTargetMixin(object): - @declared_attr - def target_id(cls): - return Column('target_id', ForeignKey('target.id')) - - @declared_attr - def target(cls): - return relationship(Target, - primaryjoin=lambda: Target.id==cls.target_id - ) - -or alternatively, the string form (which ultimately generates a lambda):: - - class RefTargetMixin(object): - @declared_attr - def target_id(cls): - return Column('target_id', ForeignKey('target.id')) - - @declared_attr - def target(cls): - return relationship("Target", - primaryjoin="Target.id==%s.target_id" % cls.__name__ - ) - -Mixing in deferred(), column_property(), and other MapperProperty classes -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Like :func:`~sqlalchemy.orm.relationship`, all -:class:`~sqlalchemy.orm.interfaces.MapperProperty` subclasses such as -:func:`~sqlalchemy.orm.deferred`, :func:`~sqlalchemy.orm.column_property`, -etc. ultimately involve references to columns, and therefore, when -used with declarative mixins, have the :class:`.declared_attr` -requirement so that no reliance on copying is needed:: - - class SomethingMixin(object): - - @declared_attr - def dprop(cls): - return deferred(Column(Integer)) - - class Something(SomethingMixin, Base): - __tablename__ = "something" - -The :func:`.column_property` or other construct may refer -to other columns from the mixin. These are copied ahead of time before -the :class:`.declared_attr` is invoked:: - - class SomethingMixin(object): - x = Column(Integer) - - y = Column(Integer) - - @declared_attr - def x_plus_y(cls): - return column_property(cls.x + cls.y) - - -.. versionchanged:: 1.0.0 mixin columns are copied to the final mapped class - so that :class:`.declared_attr` methods can access the actual column - that will be mapped. - -Mixing in Association Proxy and Other Attributes -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Mixins can specify user-defined attributes as well as other extension -units such as :func:`.association_proxy`. The usage of -:class:`.declared_attr` is required in those cases where the attribute must -be tailored specifically to the target subclass. An example is when -constructing multiple :func:`.association_proxy` attributes which each -target a different type of child object. 
Below is an
-:func:`.association_proxy` / mixin example which provides a scalar list of
-string values to an implementing class::
-
-    from sqlalchemy import Column, Integer, ForeignKey, String
-    from sqlalchemy.orm import relationship
-    from sqlalchemy.ext.associationproxy import association_proxy
-    from sqlalchemy.ext.declarative import declarative_base, declared_attr
-
-    Base = declarative_base()
-
-    class HasStringCollection(object):
-        @declared_attr
-        def _strings(cls):
-            class StringAttribute(Base):
-                __tablename__ = cls.string_table_name
-                id = Column(Integer, primary_key=True)
-                value = Column(String(50), nullable=False)
-                parent_id = Column(Integer,
-                                   ForeignKey('%s.id' % cls.__tablename__),
-                                   nullable=False)
-                def __init__(self, value):
-                    self.value = value
-
-            return relationship(StringAttribute)
-
-        @declared_attr
-        def strings(cls):
-            return association_proxy('_strings', 'value')
-
-    class TypeA(HasStringCollection, Base):
-        __tablename__ = 'type_a'
-        string_table_name = 'type_a_strings'
-        id = Column(Integer(), primary_key=True)
-
-    class TypeB(HasStringCollection, Base):
-        __tablename__ = 'type_b'
-        string_table_name = 'type_b_strings'
-        id = Column(Integer(), primary_key=True)
-
-Above, the ``HasStringCollection`` mixin produces a :func:`.relationship`
-which refers to a newly generated class called ``StringAttribute``.  The
-``StringAttribute`` class is generated with its own :class:`.Table`
-definition which is local to the parent class making usage of the
-``HasStringCollection`` mixin.  It also produces an :func:`.association_proxy`
-object which proxies references to the ``strings`` attribute onto the ``value``
-attribute of each ``StringAttribute`` instance.
-
-``TypeA`` or ``TypeB`` can be instantiated given the constructor
-argument ``strings``, a list of strings::
-
-    ta = TypeA(strings=['foo', 'bar'])
-    tb = TypeB(strings=['bat', 'bar'])
-
-This list will generate a collection of ``StringAttribute`` objects,
-which are persisted into a table local to each class, either
-``type_a_strings`` or ``type_b_strings``::
-
-    >>> print ta._strings
-    [<__main__.StringAttribute object at 0x10151cd90>,
-        <__main__.StringAttribute object at 0x10151ce10>]
-
-When constructing the :func:`.association_proxy`, the
-:class:`.declared_attr` decorator must be used so that a distinct
-:func:`.association_proxy` object is created for each of the ``TypeA``
-and ``TypeB`` classes.
-
-.. versionadded:: 0.8 :class:`.declared_attr` is usable with non-mapped
-   attributes, including user-defined attributes as well as
-   :func:`.association_proxy`.
-
-
-Controlling table inheritance with mixins
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-The ``__tablename__`` attribute may be used to provide a function that
-will determine the name of the table used for each class in an inheritance
-hierarchy, as well as whether a class has its own distinct table.
-
-This is achieved using the :class:`.declared_attr` indicator in conjunction
-with a method named ``__tablename__()``.  Declarative will always
-invoke the :class:`.declared_attr` function for the special names
-``__tablename__``, ``__mapper_args__`` and ``__table_args__``
-**for each mapped class in the hierarchy**.  The function therefore
-needs to expect to receive each class individually and to provide the
-correct answer for each.
-
-For example, to create a mixin that gives every class a simple table
-name based on class name::
-
-    from sqlalchemy.ext.declarative import declared_attr
-
-    class Tablename:
-        @declared_attr
-        def __tablename__(cls):
-            return cls.__name__.lower()
-
-    class Person(Tablename, Base):
-        id = Column(Integer, primary_key=True)
-        discriminator = Column('type', String(50))
-        __mapper_args__ = {'polymorphic_on': discriminator}
-
-    class Engineer(Person):
-        __tablename__ = None
-        __mapper_args__ = {'polymorphic_identity': 'engineer'}
-        primary_language = Column(String(50))
-
-Alternatively, we can modify our ``__tablename__`` function to return
-``None`` for subclasses, using :func:`.has_inherited_table`.  This has
-the effect of those subclasses being mapped with single table inheritance
-against the parent::
-
-    from sqlalchemy.ext.declarative import declared_attr
-    from sqlalchemy.ext.declarative import has_inherited_table
-
-    class Tablename(object):
-        @declared_attr
-        def __tablename__(cls):
-            if has_inherited_table(cls):
-                return None
-            return cls.__name__.lower()
-
-    class Person(Tablename, Base):
-        id = Column(Integer, primary_key=True)
-        discriminator = Column('type', String(50))
-        __mapper_args__ = {'polymorphic_on': discriminator}
-
-    class Engineer(Person):
-        primary_language = Column(String(50))
-        __mapper_args__ = {'polymorphic_identity': 'engineer'}
-
-.. _mixin_inheritance_columns:
-
-Mixing in Columns in Inheritance Scenarios
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-In contrast to how ``__tablename__`` and other special names are handled when
-used with :class:`.declared_attr`, when we mix in columns and properties (e.g.
-relationships, column properties, etc.), the function is
-invoked for the **base class only** in the hierarchy.  Below, only the
-``Person`` class will receive a column
-called ``id``; the mapping will fail on ``Engineer``, which is not given
-a primary key::
-
-    class HasId(object):
-        @declared_attr
-        def id(cls):
-            return Column('id', Integer, primary_key=True)
-
-    class Person(HasId, Base):
-        __tablename__ = 'person'
-        discriminator = Column('type', String(50))
-        __mapper_args__ = {'polymorphic_on': discriminator}
-
-    class Engineer(Person):
-        __tablename__ = 'engineer'
-        primary_language = Column(String(50))
-        __mapper_args__ = {'polymorphic_identity': 'engineer'}
-
-It is usually the case in joined-table inheritance that we want distinctly
-named columns on each subclass.  However in this case, we may want to have
-an ``id`` column on every table, and have them refer to each other via
-foreign key.  We can achieve this as a mixin by using the
-:attr:`.declared_attr.cascading` modifier, which indicates that the
-function should be invoked **for each class in the hierarchy**, just like
-it does for ``__tablename__``::
-
-    class HasId(object):
-        @declared_attr.cascading
-        def id(cls):
-            if has_inherited_table(cls):
-                return Column('id',
-                              Integer,
-                              ForeignKey('person.id'), primary_key=True)
-            else:
-                return Column('id', Integer, primary_key=True)
-
-    class Person(HasId, Base):
-        __tablename__ = 'person'
-        discriminator = Column('type', String(50))
-        __mapper_args__ = {'polymorphic_on': discriminator}
-
-    class Engineer(Person):
-        __tablename__ = 'engineer'
-        primary_language = Column(String(50))
-        __mapper_args__ = {'polymorphic_identity': 'engineer'}
-
-
-.. versionadded:: 1.0.0 added :attr:`.declared_attr.cascading`.
-
-Combining Table/Mapper Arguments from Multiple Mixins
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-In the case of ``__table_args__`` or ``__mapper_args__``
-specified with declarative mixins, you may want to combine
-some parameters from several mixins with those you wish to
-define on the class itself.  The
-:class:`.declared_attr` decorator can be used
-here to create user-defined routines that collate arguments
-from multiple collections::
-
-    from sqlalchemy.ext.declarative import declared_attr
-
-    class MySQLSettings(object):
-        __table_args__ = {'mysql_engine': 'InnoDB'}
-
-    class MyOtherMixin(object):
-        __table_args__ = {'info': 'foo'}
-
-    class MyModel(MySQLSettings, MyOtherMixin, Base):
-        __tablename__ = 'my_model'
-
-        @declared_attr
-        def __table_args__(cls):
-            args = dict()
-            args.update(MySQLSettings.__table_args__)
-            args.update(MyOtherMixin.__table_args__)
-            return args
-
-        id = Column(Integer, primary_key=True)
-
-Creating Indexes with Mixins
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-To define a named, potentially multicolumn :class:`.Index` that applies to all
-tables derived from a mixin, use the "inline" form of :class:`.Index` and
-establish it as part of ``__table_args__``::
-
-    class MyMixin(object):
-        a = Column(Integer)
-        b = Column(Integer)
-
-        @declared_attr
-        def __table_args__(cls):
-            return (Index('test_idx_%s' % cls.__tablename__, 'a', 'b'),)
-
-    class MyModel(MyMixin, Base):
-        __tablename__ = 'atable'
-        c = Column(Integer, primary_key=True)
-
-Special Directives
-==================
-
-``__declare_last__()``
-~~~~~~~~~~~~~~~~~~~~~~
-
-The ``__declare_last__()`` hook allows definition of
-a class level function that is automatically called by the
-:meth:`.MapperEvents.after_configured` event, which occurs after mappings are
-assumed to be completed and the 'configure' step has finished::
-
-    class MyClass(Base):
-        @classmethod
-        def __declare_last__(cls):
-            ""
-            # do something with mappings
-
-.. versionadded:: 0.7.3
-
-``__declare_first__()``
-~~~~~~~~~~~~~~~~~~~~~~~
-
-Like ``__declare_last__()``, but is called at the beginning of mapper
-configuration via the :meth:`.MapperEvents.before_configured` event::
-
-    class MyClass(Base):
-        @classmethod
-        def __declare_first__(cls):
-            ""
-            # do something before mappings are configured
-
-.. versionadded:: 0.9.3
-
-.. _declarative_abstract:
-
-``__abstract__``
-~~~~~~~~~~~~~~~~~~~
-
-``__abstract__`` causes declarative to skip the production
-of a table or mapper for the class entirely.  A class can be added within a
-hierarchy in the same way as a mixin (see :ref:`declarative_mixins`), allowing
-subclasses to extend just from the special class::
-
-    class SomeAbstractBase(Base):
-        __abstract__ = True
-
-        def some_helpful_method(self):
-            ""
-
-        @declared_attr
-        def __mapper_args__(cls):
-            return {"helpful mapper arguments": True}
-
-    class MyMappedClass(SomeAbstractBase):
-        ""
-
-One possible use of ``__abstract__`` is to use a distinct
-:class:`.MetaData` for different bases::
-
-    Base = declarative_base()
-
-    class DefaultBase(Base):
-        __abstract__ = True
-        metadata = MetaData()
-
-    class OtherBase(Base):
-        __abstract__ = True
-        metadata = MetaData()
-
-Above, classes which inherit from ``DefaultBase`` will use one
-:class:`.MetaData` as the registry of tables, and those which inherit from
-``OtherBase`` will use a different one.  The tables themselves can then be
-created perhaps within distinct databases::
-
-    DefaultBase.metadata.create_all(some_engine)
-    OtherBase.metadata.create_all(some_other_engine)
-
-.. 
versionadded:: 0.7.3 - -Class Constructor -================= - -As a convenience feature, the :func:`declarative_base` sets a default -constructor on classes which takes keyword arguments, and assigns them -to the named attributes:: - - e = Engineer(primary_language='python') - -Sessions -======== - -Note that ``declarative`` does nothing special with sessions, and is -only intended as an easier way to configure mappers and -:class:`~sqlalchemy.schema.Table` objects. A typical application -setup using :class:`~sqlalchemy.orm.scoping.scoped_session` might look like:: - - engine = create_engine('postgresql://scott:tiger@localhost/test') - Session = scoped_session(sessionmaker(autocommit=False, - autoflush=False, - bind=engine)) - Base = declarative_base() - -Mapped instances then make usage of -:class:`~sqlalchemy.orm.session.Session` in the usual way. - -""" - from .api import declarative_base, synonym_for, comparable_using, \ instrument_declarative, ConcreteBase, AbstractConcreteBase, \ DeclarativeMeta, DeferredReflection, has_inherited_table,\ @@ -1384,5 +13,6 @@ from .api import declarative_base, synonym_for, comparable_using, \ __all__ = ['declarative_base', 'synonym_for', 'has_inherited_table', 'comparable_using', 'instrument_declarative', 'declared_attr', + 'as_declarative', 'ConcreteBase', 'AbstractConcreteBase', 'DeclarativeMeta', 'DeferredReflection'] diff --git a/lib/sqlalchemy/ext/declarative/api.py b/lib/sqlalchemy/ext/declarative/api.py index 66fe05fd0..713ea0aba 100644 --- a/lib/sqlalchemy/ext/declarative/api.py +++ b/lib/sqlalchemy/ext/declarative/api.py @@ -1,5 +1,5 @@ # ext/declarative/api.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -175,15 +175,12 @@ class declared_attr(interfaces._MappedAttribute, property): "non-mapped class %s" % (desc.fget.__name__, cls.__name__)) return desc.fget(cls) - try: - reg = manager.info['declared_attr_reg'] - except KeyError: - raise exc.InvalidRequestError( - "@declared_attr called outside of the " - "declarative mapping process; is declarative_base() being " - "used correctly?") - - if desc in reg: + + reg = manager.info.get('declared_attr_reg', None) + + if reg is None: + return desc.fget(cls) + elif desc in reg: return reg[desc] else: reg[desc] = obj = desc.fget(cls) diff --git a/lib/sqlalchemy/ext/declarative/base.py b/lib/sqlalchemy/ext/declarative/base.py index 291608b6c..7d4020b24 100644 --- a/lib/sqlalchemy/ext/declarative/base.py +++ b/lib/sqlalchemy/ext/declarative/base.py @@ -1,5 +1,5 @@ # ext/declarative/base.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -35,6 +35,21 @@ def _declared_mapping_info(cls): return None +def _resolve_for_abstract(cls): + if cls is object: + return None + + if _get_immediate_cls_attr(cls, '__abstract__'): + for sup in cls.__bases__: + sup = _resolve_for_abstract(sup) + if sup is not None: + return sup + else: + return None + else: + return cls + + def _get_immediate_cls_attr(cls, attrname): """return an attribute of the class that is either present directly on the class, e.g. not on a superclass, or is from a superclass but @@ -46,6 +61,9 @@ def _get_immediate_cls_attr(cls, attrname): inherit from. 
""" + if not issubclass(cls, object): + return None + for base in cls.__mro__: _is_declarative_inherits = hasattr(base, '_decl_class_registry') if attrname in base.__dict__: @@ -202,6 +220,7 @@ class _MapperConfig(object): if not oldclassprop and obj._cascading: dict_[name] = column_copies[obj] = \ ret = obj.__get__(obj, cls) + setattr(cls, name, ret) else: if oldclassprop: util.warn_deprecated( @@ -278,7 +297,7 @@ class _MapperConfig(object): elif not isinstance(value, (Column, MapperProperty)): # using @declared_attr for some object that # isn't Column/MapperProperty; remove from the dict_ - # and place the evaulated value onto the class. + # and place the evaluated value onto the class. if not k.startswith('__'): dict_.pop(k) setattr(cls, k, value) @@ -388,6 +407,9 @@ class _MapperConfig(object): table_args = self.table_args declared_columns = self.declared_columns for c in cls.__bases__: + c = _resolve_for_abstract(c) + if c is None: + continue if _declared_mapping_info(c) is not None and \ not _get_immediate_cls_attr( c, '_sa_decl_prepare_nocascade'): @@ -439,6 +461,7 @@ class _MapperConfig(object): def _prepare_mapper_arguments(self): properties = self.properties + if self.mapper_args_fn: mapper_args = self.mapper_args_fn() else: diff --git a/lib/sqlalchemy/ext/declarative/clsregistry.py b/lib/sqlalchemy/ext/declarative/clsregistry.py index 3ef63a5ae..c3887d6cf 100644 --- a/lib/sqlalchemy/ext/declarative/clsregistry.py +++ b/lib/sqlalchemy/ext/declarative/clsregistry.py @@ -1,5 +1,5 @@ # ext/declarative/clsregistry.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -71,6 +71,8 @@ class _MultipleClassMarker(object): """ + __slots__ = 'on_remove', 'contents', '__weakref__' + def __init__(self, classes, on_remove=None): self.on_remove = on_remove self.contents = set([ @@ -127,6 +129,8 @@ class _ModuleMarker(object): """ + __slots__ = 'parent', 'name', 'contents', 'mod_ns', 'path', '__weakref__' + def __init__(self, name, parent): self.parent = parent self.name = name @@ -172,6 +176,8 @@ class _ModuleMarker(object): class _ModNS(object): + __slots__ = '__parent', + def __init__(self, parent): self.__parent = parent @@ -193,6 +199,8 @@ class _ModNS(object): class _GetColumns(object): + __slots__ = 'cls', + def __init__(self, cls): self.cls = cls @@ -221,6 +229,8 @@ inspection._inspects(_GetColumns)( class _GetTable(object): + __slots__ = 'key', 'metadata' + def __init__(self, key, metadata): self.key = key self.metadata = metadata diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index d311fb2d4..c9fb0b044 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -1,5 +1,5 @@ # ext/horizontal_shard.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index e2739d1de..f94c2079e 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -1,5 +1,5 @@ # ext/hybrid.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -145,7 
+145,7 @@ usage of the absolute value function:: return func.abs(cls.length) / 2 Above the Python function ``abs()`` is used for instance-level -operations, the SQL function ``ABS()`` is used via the :attr:`.func` +operations, the SQL function ``ABS()`` is used via the :data:`.func` object for class-level expressions:: >>> i1.radius @@ -660,7 +660,7 @@ HYBRID_PROPERTY = util.symbol('HYBRID_PROPERTY') """ -class hybrid_method(interfaces.InspectionAttr): +class hybrid_method(interfaces.InspectionAttrInfo): """A decorator which allows definition of a Python object method with both instance-level and class-level behavior. @@ -703,7 +703,7 @@ class hybrid_method(interfaces.InspectionAttr): return self -class hybrid_property(interfaces.InspectionAttr): +class hybrid_property(interfaces.InspectionAttrInfo): """A decorator which allows definition of a Python descriptor with both instance-level and class-level behavior. diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index e49e9ea8b..24fc37a42 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -1,5 +1,5 @@ # ext/mutable.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index 61155731c..ac31c7cf7 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -1,5 +1,5 @@ # ext/orderinglist.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/serializer.py b/lib/sqlalchemy/ext/serializer.py index bf8d67d8e..555f3760b 100644 --- a/lib/sqlalchemy/ext/serializer.py +++ b/lib/sqlalchemy/ext/serializer.py @@ -1,5 +1,5 @@ # ext/serializer.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/inspection.py b/lib/sqlalchemy/inspection.py index ab9f2ae38..a4738cc61 100644 --- a/lib/sqlalchemy/inspection.py +++ b/lib/sqlalchemy/inspection.py @@ -1,5 +1,5 @@ # sqlalchemy/inspect.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/interfaces.py b/lib/sqlalchemy/interfaces.py index ae11d1930..717e99b5e 100644 --- a/lib/sqlalchemy/interfaces.py +++ b/lib/sqlalchemy/interfaces.py @@ -1,5 +1,5 @@ # sqlalchemy/interfaces.py -# Copyright (C) 2007-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2007-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # Copyright (C) 2007 Jason Kirtland jek@discorporate.us # diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py index b3c9ae024..c23412e38 100644 --- a/lib/sqlalchemy/log.py +++ b/lib/sqlalchemy/log.py @@ -1,5 +1,5 @@ # sqlalchemy/log.py -# Copyright (C) 2006-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2006-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # Includes alterations by Vinay Sajip vinay_sajip@yahoo.co.uk # diff --git 
a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 741e79b9d..e02a271e3 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -1,5 +1,5 @@ # orm/__init__.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 2b4c3ec75..41803c8bf 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -1,5 +1,5 @@ # orm/attributes.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -345,18 +345,16 @@ class Event(object): .. versionadded:: 0.9.0 - """ - - impl = None - """The :class:`.AttributeImpl` which is the current event initiator. - """ + :var impl: The :class:`.AttributeImpl` which is the current event + initiator. - op = None - """The symbol :attr:`.OP_APPEND`, :attr:`.OP_REMOVE` or :attr:`.OP_REPLACE`, - indicating the source operation. + :var op: The symbol :attr:`.OP_APPEND`, :attr:`.OP_REMOVE` or + :attr:`.OP_REPLACE`, indicating the source operation. """ + __slots__ = 'impl', 'op', 'parent_token' + def __init__(self, attribute_impl, op): self.impl = attribute_impl self.op = op @@ -455,6 +453,11 @@ class AttributeImpl(object): self.expire_missing = expire_missing + __slots__ = ( + 'class_', 'key', 'callable_', 'dispatch', 'trackparent', + 'parent_token', 'send_modified_events', 'is_equal', 'expire_missing' + ) + def __str__(self): return "%s.%s" % (self.class_.__name__, self.key) @@ -524,23 +527,6 @@ class AttributeImpl(object): state.parents[id_] = False - def set_callable(self, state, callable_): - """Set a callable function for this attribute on the given object. - - This callable will be executed when the attribute is next - accessed, and is assumed to construct part of the instances - previously stored state. When its value or values are loaded, - they will be established as part of the instance's *committed - state*. While *trackparent* information will be assembled for - these instances, attribute-level event handlers will not be - fired. - - The callable overrides the class level callable set in the - ``InstrumentedAttribute`` constructor. 
- - """ - state.callables[self.key] = callable_ - def get_history(self, state, dict_, passive=PASSIVE_OFF): raise NotImplementedError() @@ -583,7 +569,9 @@ class AttributeImpl(object): if not passive & CALLABLES_OK: return PASSIVE_NO_RESULT - if key in state.callables: + if key in state.expired_attributes: + value = state._load_expired(state, passive) + elif key in state.callables: callable_ = state.callables[key] value = callable_(state, passive) elif self.callable_: @@ -654,6 +642,23 @@ class ScalarAttributeImpl(AttributeImpl): supports_population = True collection = False + __slots__ = '_replace_token', '_append_token', '_remove_token' + + def __init__(self, *arg, **kw): + super(ScalarAttributeImpl, self).__init__(*arg, **kw) + self._replace_token = self._append_token = None + self._remove_token = None + + def _init_append_token(self): + self._replace_token = self._append_token = Event(self, OP_REPLACE) + return self._replace_token + + _init_append_or_replace_token = _init_append_token + + def _init_remove_token(self): + self._remove_token = Event(self, OP_REMOVE) + return self._remove_token + def delete(self, state, dict_): # TODO: catch key errors, convert to attributeerror? @@ -692,27 +697,18 @@ class ScalarAttributeImpl(AttributeImpl): state._modified_event(dict_, self, old) dict_[self.key] = value - @util.memoized_property - def _replace_token(self): - return Event(self, OP_REPLACE) - - @util.memoized_property - def _append_token(self): - return Event(self, OP_REPLACE) - - @util.memoized_property - def _remove_token(self): - return Event(self, OP_REMOVE) - def fire_replace_event(self, state, dict_, value, previous, initiator): for fn in self.dispatch.set: value = fn( - state, value, previous, initiator or self._replace_token) + state, value, previous, + initiator or self._replace_token or + self._init_append_or_replace_token()) return value def fire_remove_event(self, state, dict_, value, initiator): for fn in self.dispatch.remove: - fn(state, value, initiator or self._remove_token) + fn(state, value, + initiator or self._remove_token or self._init_remove_token()) @property def type(self): @@ -732,9 +728,13 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): supports_population = True collection = False + __slots__ = () + def delete(self, state, dict_): old = self.get(state, dict_) - self.fire_remove_event(state, dict_, old, self._remove_token) + self.fire_remove_event( + state, dict_, old, + self._remove_token or self._init_remove_token()) del dict_[self.key] def get_history(self, state, dict_, passive=PASSIVE_OFF): @@ -807,7 +807,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): self.sethasparent(instance_state(value), state, False) for fn in self.dispatch.remove: - fn(state, value, initiator or self._remove_token) + fn(state, value, initiator or + self._remove_token or self._init_remove_token()) state._modified_event(dict_, self, value) @@ -819,7 +820,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): for fn in self.dispatch.set: value = fn( - state, value, previous, initiator or self._replace_token) + state, value, previous, initiator or + self._replace_token or self._init_append_or_replace_token()) state._modified_event(dict_, self, previous) @@ -846,6 +848,8 @@ class CollectionAttributeImpl(AttributeImpl): supports_population = True collection = True + __slots__ = 'copy', 'collection_factory', '_append_token', '_remove_token' + def __init__(self, class_, key, callable_, dispatch, typecallable=None, trackparent=False, extension=None, copy_function=None, 
compare_function=None, **kwargs): @@ -862,6 +866,8 @@ class CollectionAttributeImpl(AttributeImpl): copy_function = self.__copy self.copy = copy_function self.collection_factory = typecallable + self._append_token = None + self._remove_token = None if getattr(self.collection_factory, "_sa_linker", None): @@ -873,6 +879,14 @@ class CollectionAttributeImpl(AttributeImpl): def unlink(target, collection, collection_adapter): collection._sa_linker(None) + def _init_append_token(self): + self._append_token = Event(self, OP_APPEND) + return self._append_token + + def _init_remove_token(self): + self._remove_token = Event(self, OP_REMOVE) + return self._remove_token + def __copy(self, item): return [y for y in collections.collection_adapter(item)] @@ -915,17 +929,11 @@ class CollectionAttributeImpl(AttributeImpl): return [(instance_state(o), o) for o in current] - @util.memoized_property - def _append_token(self): - return Event(self, OP_APPEND) - - @util.memoized_property - def _remove_token(self): - return Event(self, OP_REMOVE) - def fire_append_event(self, state, dict_, value, initiator): for fn in self.dispatch.append: - value = fn(state, value, initiator or self._append_token) + value = fn( + state, value, + initiator or self._append_token or self._init_append_token()) state._modified_event(dict_, self, NEVER_SET, True) @@ -942,7 +950,8 @@ class CollectionAttributeImpl(AttributeImpl): self.sethasparent(instance_state(value), state, False) for fn in self.dispatch.remove: - fn(state, value, initiator or self._remove_token) + fn(state, value, + initiator or self._remove_token or self._init_remove_token()) state._modified_event(dict_, self, NEVER_SET, True) @@ -1134,7 +1143,8 @@ def backref_listeners(attribute, key, uselist): impl.pop(old_state, old_dict, state.obj(), - parent_impl._append_token, + parent_impl._append_token or + parent_impl._init_append_token(), passive=PASSIVE_NO_FETCH) if child is not None: diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 3390ceec4..c259878f0 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -1,5 +1,5 @@ # orm/base.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -183,6 +183,10 @@ NOT_EXTENSION = util.symbol( _none_set = frozenset([None, NEVER_SET, PASSIVE_NO_RESULT]) +_SET_DEFERRED_EXPIRED = util.symbol("SET_DEFERRED_EXPIRED") + +_DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE") + def _generative(*assertions): """Mark a method as generative, e.g. method-chained.""" @@ -323,10 +327,9 @@ def _is_mapped_class(entity): insp = inspection.inspect(entity, False) return insp is not None and \ - hasattr(insp, "mapper") and \ + not insp.is_clause_element and \ ( - insp.is_mapper - or insp.is_aliased_class + insp.is_mapper or insp.is_aliased_class ) @@ -437,6 +440,7 @@ class InspectionAttr(object): here intact for forwards-compatibility. """ + __slots__ = () is_selectable = False """Return True if this object is an instance of :class:`.Selectable`.""" @@ -488,6 +492,16 @@ class InspectionAttr(object): """ + +class InspectionAttrInfo(InspectionAttr): + """Adds the ``.info`` attribute to :class:`.InspectionAttr`. + + The rationale for :class:`.InspectionAttr` vs. :class:`.InspectionAttrInfo` + is that the former is compatible as a mixin for classes that specify + ``__slots__``; this is essentially an implementation artifact. 
+ + """ + @util.memoized_property def info(self): """Info dictionary associated with the object, allowing user-defined @@ -501,9 +515,10 @@ class InspectionAttr(object): .. versionadded:: 0.8 Added support for .info to all :class:`.MapperProperty` subclasses. - .. versionchanged:: 1.0.0 :attr:`.InspectionAttr.info` moved - from :class:`.MapperProperty` so that it can apply to a wider - variety of ORM and extension constructs. + .. versionchanged:: 1.0.0 :attr:`.MapperProperty.info` is also + available on extension types via the + :attr:`.InspectionAttrInfo.info` attribute, so that it can apply + to a wider variety of ORM and extension constructs. .. seealso:: @@ -520,3 +535,4 @@ class _MappedAttribute(object): attributes. """ + __slots__ = () diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 356a8a3b9..4f988a8d4 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -1,5 +1,5 @@ # orm/collections.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -1507,8 +1507,8 @@ class MappedCollection(dict): def __init__(self, keyfunc): """Create a new collection with keying provided by keyfunc. - keyfunc may be any callable any callable that takes an object and - returns an object for use as a dictionary key. + keyfunc may be any callable that takes an object and returns an object + for use as a dictionary key. The keyfunc will be called every time the ORM needs to add a member by value-only (such as when loading instances from the database) or diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index d10a38394..d8989939b 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -1,5 +1,5 @@ # orm/dependency.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/deprecated_interfaces.py b/lib/sqlalchemy/orm/deprecated_interfaces.py index 275582323..bb6d185d4 100644 --- a/lib/sqlalchemy/orm/deprecated_interfaces.py +++ b/lib/sqlalchemy/orm/deprecated_interfaces.py @@ -1,5 +1,5 @@ # orm/deprecated_interfaces.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 19ff71f73..17c2d28ce 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -1,5 +1,5 @@ # orm/descriptor_props.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -143,6 +143,7 @@ class CompositeProperty(DescriptorProperty): class. **Deprecated.** Please see :class:`.AttributeEvents`. 
""" + super(CompositeProperty, self).__init__() self.attrs = attrs self.composite_class = class_ @@ -471,6 +472,7 @@ class ConcreteInheritedProperty(DescriptorProperty): return comparator_callable def __init__(self): + super(ConcreteInheritedProperty, self).__init__() def warn(): raise AttributeError("Concrete %s does not implement " "attribute %r at the instance level. Add " @@ -555,6 +557,7 @@ class SynonymProperty(DescriptorProperty): more complicated attribute-wrapping schemes than synonyms. """ + super(SynonymProperty, self).__init__() self.name = name self.map_column = map_column @@ -684,6 +687,7 @@ class ComparableProperty(DescriptorProperty): .. versionadded:: 1.0.0 """ + super(ComparableProperty, self).__init__() self.descriptor = descriptor self.comparator_factory = comparator_factory self.doc = doc or (descriptor and descriptor.__doc__) or None diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index a4ccfe417..aedd863f8 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -1,5 +1,5 @@ # orm/dynamic.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index 2026e5d0a..1e828ff86 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -1,5 +1,5 @@ # orm/evaluator.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 9ea0dd834..233cd66a6 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -1,5 +1,5 @@ # orm/events.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -17,7 +17,7 @@ from . import mapperlib, instrumentation from .session import Session, sessionmaker from .scoping import scoped_session from .attributes import QueryableAttribute - +from .query import Query class InstrumentationEvents(event.Events): """Events related to class instrumentation events. @@ -1479,8 +1479,9 @@ class AttributeEvents(event.Events): @staticmethod def _set_dispatch(cls, dispatch_cls): - event.Events._set_dispatch(cls, dispatch_cls) + dispatch = event.Events._set_dispatch(cls, dispatch_cls) dispatch_cls._active_history = False + return dispatch @classmethod def _accept_with(cls, target): @@ -1650,3 +1651,56 @@ class AttributeEvents(event.Events): the :class:`.collection.linker` hook. """ + + +class QueryEvents(event.Events): + """Represent events within the construction of a :class:`.Query` object. + + The events here are intended to be used with an as-yet-unreleased + inspection system for :class:`.Query`. Some very basic operations + are possible now, however the inspection system is intended to allow + complex query manipulations to be automated. + + .. versionadded:: 1.0.0 + + """ + + _target_class_doc = "SomeQuery" + _dispatch_target = Query + + def before_compile(self, query): + """Receive the :class:`.Query` object before it is composed into a + core :class:`.Select` object. 
+ + This event is intended to allow changes to the query given:: + + @event.listens_for(Query, "before_compile", retval=True) + def no_deleted(query): + for desc in query.column_descriptions: + if desc['type'] is User: + entity = desc['expr'] + query = query.filter(entity.deleted == False) + return query + + The event should normally be listened with the ``retval=True`` + parameter set, so that the modified query may be returned. + + + """ + + @classmethod + def _listen( + cls, event_key, retval=False, **kw): + fn = event_key._listen_fn + + if not retval: + def wrap(*arg, **kw): + if not retval: + query = arg[0] + fn(*arg, **kw) + return query + else: + return fn(*arg, **kw) + event_key = event_key.with_wrapper(wrap) + + event_key.base_listen(**kw) diff --git a/lib/sqlalchemy/orm/exc.py b/lib/sqlalchemy/orm/exc.py index ff0ece411..e010a295d 100644 --- a/lib/sqlalchemy/orm/exc.py +++ b/lib/sqlalchemy/orm/exc.py @@ -1,5 +1,5 @@ # orm/exc.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py index 24dd47859..46be2b719 100644 --- a/lib/sqlalchemy/orm/identity.py +++ b/lib/sqlalchemy/orm/identity.py @@ -1,5 +1,5 @@ # orm/identity.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -44,7 +44,8 @@ class IdentityMap(object): def _manage_removed_state(self, state): del state._instance_dict - self._modified.discard(state) + if state.modified: + self._modified.discard(state) def _dirty_states(self): return self._modified @@ -186,6 +187,9 @@ class WeakInstanceDict(IdentityMap): else: return list(self._dict.values()) + def _fast_discard(self, state): + self._dict.pop(state.key, None) + def discard(self, state): st = self._dict.pop(state.key, None) if st: @@ -264,6 +268,9 @@ class StrongInstanceDict(IdentityMap): self._dict[key] = state.obj() state._instance_dict = self._wr + def _fast_discard(self, state): + self._dict.pop(state.key, None) + def discard(self, state): obj = self._dict.pop(state.key, None) if obj is not None: diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index ad7d2d53d..be2fe91c2 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -1,5 +1,5 @@ # orm/instrumentation.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -35,6 +35,9 @@ from .. import util from . 
import base +_memoized_key_collection = util.group_expirable_memoized_property() + + class ClassManager(dict): """tracks state information at the class level.""" @@ -92,6 +95,21 @@ class ClassManager(dict): def is_mapped(self): return 'mapper' in self.__dict__ + @_memoized_key_collection + def _all_key_set(self): + return frozenset(self) + + @_memoized_key_collection + def _collection_impl_keys(self): + return frozenset([ + attr.key for attr in self.values() if attr.impl.collection]) + + @_memoized_key_collection + def _scalar_loader_impls(self): + return frozenset([ + attr.impl for attr in + self.values() if attr.impl.accepts_scalar_loader]) + @util.memoized_property def mapper(self): # raises unless self.mapper has been assigned @@ -195,6 +213,7 @@ class ClassManager(dict): else: self.local_attrs[key] = inst self.install_descriptor(key, inst) + _memoized_key_collection.expire_instance(self) self[key] = inst for cls in self.class_.__subclasses__(): @@ -223,6 +242,7 @@ class ClassManager(dict): else: del self.local_attrs[key] self.uninstall_descriptor(key) + _memoized_key_collection.expire_instance(self) del self[key] for cls in self.class_.__subclasses__(): manager = manager_of_class(cls) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index ad2452c1b..6cc613baa 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -1,5 +1,5 @@ # orm/interfaces.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -24,8 +24,10 @@ from .. import util from ..sql import operators from .base import (ONETOMANY, MANYTOONE, MANYTOMANY, EXT_CONTINUE, EXT_STOP, NOT_EXTENSION) -from .base import InspectionAttr, _MappedAttribute +from .base import (InspectionAttr, InspectionAttr, + InspectionAttrInfo, _MappedAttribute) import collections +from .. import inspect # imported later MapperExtension = SessionExtension = AttributeExtension = None @@ -48,11 +50,8 @@ __all__ = ( ) -class MapperProperty(_MappedAttribute, InspectionAttr): - """Manage the relationship of a ``Mapper`` to a single class - attribute, as well as that attribute as it appears on individual - instances of the class, including attribute instrumentation, - attribute access, loading behavior, and dependency calculations. +class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots): + """Represent a particular class attribute mapped by :class:`.Mapper`. The most common occurrences of :class:`.MapperProperty` are the mapped :class:`.Column`, which is represented in a mapping as @@ -63,6 +62,11 @@ class MapperProperty(_MappedAttribute, InspectionAttr): """ + __slots__ = ( + '_configure_started', '_configure_finished', 'parent', 'key', + 'info' + ) + cascade = frozenset() """The set of 'cascade' attribute names. @@ -78,6 +82,32 @@ class MapperProperty(_MappedAttribute, InspectionAttr): """ + def _memoized_attr_info(self): + """Info dictionary associated with the object, allowing user-defined + data to be associated with this :class:`.InspectionAttr`. + + The dictionary is generated when first accessed. Alternatively, + it can be specified as a constructor argument to the + :func:`.column_property`, :func:`.relationship`, or :func:`.composite` + functions. + + .. versionadded:: 0.8 Added support for .info to all + :class:`.MapperProperty` subclasses. + + .. 
versionchanged:: 1.0.0 :attr:`.MapperProperty.info` is also + available on extension types via the + :attr:`.InspectionAttrInfo.info` attribute, so that it can apply + to a wider variety of ORM and extension constructs. + + .. seealso:: + + :attr:`.QueryableAttribute.info` + + :attr:`.SchemaItem.info` + + """ + return {} + def setup(self, context, entity, path, adapter, **kwargs): """Called by Query for the purposes of constructing a SQL statement. @@ -139,8 +169,9 @@ class MapperProperty(_MappedAttribute, InspectionAttr): """ - _configure_started = False - _configure_finished = False + def __init__(self): + self._configure_started = False + self._configure_finished = False def init(self): """Called after all mappers are created to assemble @@ -303,9 +334,11 @@ class PropComparator(operators.ColumnOperators): """ + __slots__ = 'prop', 'property', '_parententity', '_adapt_to_entity' + def __init__(self, prop, parentmapper, adapt_to_entity=None): self.prop = self.property = prop - self._parentmapper = parentmapper + self._parententity = parentmapper self._adapt_to_entity = adapt_to_entity def __clause_element__(self): @@ -318,7 +351,13 @@ class PropComparator(operators.ColumnOperators): """Return a copy of this PropComparator which will use the given :class:`.AliasedInsp` to produce corresponding expressions. """ - return self.__class__(self.prop, self._parentmapper, adapt_to_entity) + return self.__class__(self.prop, self._parententity, adapt_to_entity) + + @property + def _parentmapper(self): + """legacy; this is renamed to _parententity to be + compatible with QueryableAttribute.""" + return inspect(self._parententity).mapper @property def adapter(self): @@ -331,7 +370,7 @@ class PropComparator(operators.ColumnOperators): else: return self._adapt_to_entity._adapt_element - @util.memoized_property + @property def info(self): return self.property.info @@ -420,6 +459,8 @@ class StrategizedProperty(MapperProperty): """ + __slots__ = '_strategies', 'strategy' + strategy_wildcard_key = None def _get_context_loader(self, context, path): @@ -454,7 +495,8 @@ class StrategizedProperty(MapperProperty): def _get_strategy_by_cls(self, cls): return self._get_strategy(cls._strategy_keys[0]) - def setup(self, context, entity, path, adapter, **kwargs): + def setup( + self, context, entity, path, adapter, **kwargs): loader = self._get_context_loader(context, path) if loader and loader.strategy: strat = self._get_strategy(loader.strategy) @@ -483,14 +525,17 @@ class StrategizedProperty(MapperProperty): not mapper.class_manager._attr_has_impl(self.key): self.strategy.init_class_attribute(mapper) - _strategies = collections.defaultdict(dict) + _all_strategies = collections.defaultdict(dict) @classmethod def strategy_for(cls, **kw): def decorate(dec_cls): - dec_cls._strategy_keys = [] + # ensure each subclass of the strategy has its + # own _strategy_keys collection + if '_strategy_keys' not in dec_cls.__dict__: + dec_cls._strategy_keys = [] key = tuple(sorted(kw.items())) - cls._strategies[cls][key] = dec_cls + cls._all_strategies[cls][key] = dec_cls dec_cls._strategy_keys.append(key) return dec_cls return decorate @@ -498,8 +543,8 @@ class StrategizedProperty(MapperProperty): @classmethod def _strategy_lookup(cls, *key): for prop_cls in cls.__mro__: - if prop_cls in cls._strategies: - strategies = cls._strategies[prop_cls] + if prop_cls in cls._all_strategies: + strategies = cls._all_strategies[prop_cls] try: return strategies[key] except KeyError: @@ -558,6 +603,8 @@ class LoaderStrategy(object): """ + 
__slots__ = 'parent_property', 'is_class_level', 'parent', 'key' + def __init__(self, parent): self.parent_property = parent self.is_class_level = False diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 380afcdc7..50afaf601 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -1,5 +1,5 @@ # orm/loading.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -18,6 +18,7 @@ from .. import util from . import attributes, exc as orm_exc from ..sql import util as sql_util from .util import _none_set, state_str +from .base import _SET_DEFERRED_EXPIRED, _DEFER_FOR_STATE from .. import exc as sa_exc import collections @@ -42,41 +43,45 @@ def instances(query, cursor, context): def filter_fn(row): return tuple(fn(x) for x, fn in zip(row, filter_fns)) - (process, labels) = \ - list(zip(*[ - query_entity.row_processor(query, - context, cursor) - for query_entity in query._entities - ])) - - if not single_entity: - keyed_tuple = util.lightweight_named_tuple('result', labels) - - while True: - context.partials = {} - - if query._yield_per: - fetch = cursor.fetchmany(query._yield_per) - if not fetch: - break - else: - fetch = cursor.fetchall() + try: + (process, labels) = \ + list(zip(*[ + query_entity.row_processor(query, + context, cursor) + for query_entity in query._entities + ])) + + if not single_entity: + keyed_tuple = util.lightweight_named_tuple('result', labels) + + while True: + context.partials = {} + + if query._yield_per: + fetch = cursor.fetchmany(query._yield_per) + if not fetch: + break + else: + fetch = cursor.fetchall() - if single_entity: - proc = process[0] - rows = [proc(row) for row in fetch] - else: - rows = [keyed_tuple([proc(row) for proc in process]) - for row in fetch] + if single_entity: + proc = process[0] + rows = [proc(row) for row in fetch] + else: + rows = [keyed_tuple([proc(row) for proc in process]) + for row in fetch] - if filtered: - rows = util.unique_list(rows, filter_fn) + if filtered: + rows = util.unique_list(rows, filter_fn) - for row in rows: - yield row + for row in rows: + yield row - if not query._yield_per: - break + if not query._yield_per: + break + except Exception as err: + cursor.close() + util.raise_from_cause(err) @util.dependencies("sqlalchemy.orm.query") @@ -142,7 +147,7 @@ def get_from_identity(session, key, passive): # expired state will be checked soon enough, if necessary return instance try: - state(state, passive) + state._load_expired(state, passive) except orm_exc.ObjectDeletedError: session._remove_newly_deleted([state]) return None @@ -214,10 +219,56 @@ def load_on_ident(query, key, return None -def instance_processor(mapper, context, result, path, adapter, - only_load_props=None, refresh_state=None, - polymorphic_discriminator=None, - _polymorphic_from=None): +def _setup_entity_query( + context, mapper, query_entity, + path, adapter, column_collection, + with_polymorphic=None, only_load_props=None, + polymorphic_discriminator=None, **kw): + + if with_polymorphic: + poly_properties = mapper._iterate_polymorphic_properties( + with_polymorphic) + else: + poly_properties = mapper._polymorphic_properties + + quick_populators = {} + + path.set( + context.attributes, + "memoized_setups", + quick_populators) + + for value in poly_properties: + if only_load_props and \ + value.key not in only_load_props: + continue + 
value.setup( + context, + query_entity, + path, + adapter, + only_load_props=only_load_props, + column_collection=column_collection, + memoized_populators=quick_populators, + **kw + ) + + if polymorphic_discriminator is not None and \ + polymorphic_discriminator \ + is not mapper.polymorphic_on: + + if adapter: + pd = adapter.columns[polymorphic_discriminator] + else: + pd = polymorphic_discriminator + column_collection.append(pd) + + +def _instance_processor( + mapper, context, result, path, adapter, + only_load_props=None, refresh_state=None, + polymorphic_discriminator=None, + _polymorphic_from=None): """Produce a mapper level row processor callable which processes rows into mapped instances.""" @@ -236,13 +287,41 @@ def instance_processor(mapper, context, result, path, adapter, populators = collections.defaultdict(list) - props = mapper._props.values() + props = mapper._prop_set if only_load_props is not None: - props = (p for p in props if p.key in only_load_props) + props = props.intersection( + mapper._props[k] for k in only_load_props) + + quick_populators = path.get( + context.attributes, "memoized_setups", _none_set) for prop in props: - prop.create_row_processor( - context, path, mapper, result, adapter, populators) + if prop in quick_populators: + # this is an inlined path just for column-based attributes. + col = quick_populators[prop] + if col is _DEFER_FOR_STATE: + populators["new"].append( + (prop.key, prop._deferred_column_loader)) + elif col is _SET_DEFERRED_EXPIRED: + # note that in this path, we are no longer + # searching in the result to see if the column might + # be present in some unexpected way. + populators["expire"].append((prop.key, False)) + else: + if adapter: + col = adapter.columns[col] + getter = result._getter(col) + if getter: + populators["quick"].append((prop.key, getter)) + else: + # fall back to the ColumnProperty itself, which + # will iterate through all of its columns + # to see if one fits + prop.create_row_processor( + context, path, mapper, result, adapter, populators) + else: + prop.create_row_processor( + context, path, mapper, result, adapter, populators) propagate_options = context.propagate_options if propagate_options: @@ -384,7 +463,7 @@ def instance_processor(mapper, context, result, path, adapter, return instance - if not _polymorphic_from and not refresh_state: + if mapper.polymorphic_map and not _polymorphic_from and not refresh_state: # if we are doing polymorphic, dispatch to a different _instance() # method specific to the subclass mapper _instance = _decorate_polymorphic_switch( @@ -407,11 +486,11 @@ def _populate_full( for key, set_callable in populators["expire"]: dict_.pop(key, None) if set_callable: - state.callables[key] = state + state.expired_attributes.add(key) else: for key, set_callable in populators["expire"]: if set_callable: - state.callables[key] = state + state.expired_attributes.add(key) for key, populator in populators["new"]: populator(state, dict_, row) for key, populator in populators["delayed"]: @@ -441,7 +520,7 @@ def _populate_partial( if key in to_load: dict_.pop(key, None) if set_callable: - state.callables[key] = state + state.expired_attributes.add(key) for key, populator in populators["new"]: if key in to_load: populator(state, dict_, row) @@ -499,7 +578,7 @@ def _decorate_polymorphic_switch( if sub_mapper is mapper: return None - return instance_processor( + return _instance_processor( sub_mapper, context, result, path, adapter, _polymorphic_from=mapper) diff --git a/lib/sqlalchemy/orm/mapper.py 
b/lib/sqlalchemy/orm/mapper.py index 7e88ba161..4554f78f9 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1,5 +1,5 @@ # orm/mapper.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -974,6 +974,15 @@ class Mapper(InspectionAttr): self._all_tables = self.inherits._all_tables if self.polymorphic_identity is not None: + if self.polymorphic_identity in self.polymorphic_map: + util.warn( + "Reassigning polymorphic association for identity %r " + "from %r to %r: Check for duplicate use of %r as " + "value for polymorphic_identity." % + (self.polymorphic_identity, + self.polymorphic_map[self.polymorphic_identity], + self, self.polymorphic_identity) + ) self.polymorphic_map[self.polymorphic_identity] = self else: @@ -1248,7 +1257,7 @@ class Mapper(InspectionAttr): self._readonly_props = set( self._columntoproperty[col] for col in self._columntoproperty - if self._columntoproperty[col] not in self._primary_key_props and + if self._columntoproperty[col] not in self._identity_key_props and (not hasattr(col, 'table') or col.table not in self._cols_by_table)) @@ -1492,6 +1501,10 @@ class Mapper(InspectionAttr): return identities + @_memoized_configured_property + def _prop_set(self): + return frozenset(self._props.values()) + def _adapt_inherited_property(self, key, prop, init): if not self.concrete: self._configure_property(key, prop, init=False, setparent=False) @@ -1581,6 +1594,8 @@ class Mapper(InspectionAttr): self, prop, )) + oldprop = self._props[key] + self._path_registry.pop(oldprop, None) self._props[key] = prop @@ -2371,16 +2386,31 @@ class Mapper(InspectionAttr): manager[prop.key]. 
impl.get(state, dict_, attributes.PASSIVE_RETURN_NEVER_SET) - for prop in self._primary_key_props + for prop in self._identity_key_props ] @_memoized_configured_property - def _primary_key_props(self): - # TODO: this should really be called "identity key props", - # as it does not necessarily include primary key columns within - # individual tables + def _identity_key_props(self): return [self._columntoproperty[col] for col in self.primary_key] + @_memoized_configured_property + def _all_pk_props(self): + collection = set() + for table in self.tables: + collection.update(self._pks_by_table[table]) + return collection + + @_memoized_configured_property + def _should_undefer_in_wildcard(self): + cols = set(self.primary_key) + if self.polymorphic_on is not None: + cols.add(self.polymorphic_on) + return cols + + @_memoized_configured_property + def _primary_key_propkeys(self): + return set([prop.key for prop in self._all_pk_props]) + def _get_state_attr_by_column( self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NEVER_SET): @@ -2633,7 +2663,7 @@ def configure_mappers(): if not Mapper._new_mappers: return - Mapper.dispatch(Mapper).before_configured() + Mapper.dispatch._for_class(Mapper).before_configured() # initialize properties on all mappers # note that _mapper_registry is unordered, which # may randomly conceal/reveal issues related to @@ -2665,7 +2695,7 @@ def configure_mappers(): _already_compiling = False finally: _CONFIGURE_MUTEX.release() - Mapper.dispatch(Mapper).after_configured() + Mapper.dispatch._for_class(Mapper).after_configured() def reconstructor(fn): @@ -2777,6 +2807,8 @@ def _event_on_init(state, args, kwargs): class _ColumnMapping(dict): """Error reporting helper for mapper._columntoproperty.""" + __slots__ = 'mapper', + def __init__(self, mapper): self.mapper = mapper diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index f10a125a8..9670a07fb 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -1,5 +1,5 @@ # orm/path_registry.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -13,6 +13,9 @@ from .. import util from .. 
import exc from itertools import chain from .base import class_mapper +import logging + +log = logging.getLogger(__name__) def _unreduce_path(path): @@ -49,14 +52,19 @@ class PathRegistry(object): """ + is_token = False + is_root = False + def __eq__(self, other): return other is not None and \ self.path == other.path def set(self, attributes, key, value): + log.debug("set '%s' on path '%s' to '%s'", key, self, value) attributes[(key, self.path)] = value def setdefault(self, attributes, key, value): + log.debug("setdefault '%s' on path '%s' to '%s'", key, self, value) attributes.setdefault((key, self.path), value) def get(self, attributes, key, value=None): @@ -148,6 +156,8 @@ class RootRegistry(PathRegistry): """ path = () has_entity = False + is_aliased_class = False + is_root = True def __getitem__(self, entity): return entity._path_registry @@ -163,6 +173,15 @@ class TokenRegistry(PathRegistry): has_entity = False + is_token = True + + def generate_for_superclasses(self): + if not self.parent.is_aliased_class and not self.parent.is_root: + for ent in self.parent.mapper.iterate_to_root(): + yield TokenRegistry(self.parent.parent[ent], self.token) + else: + yield self + def __getitem__(self, entity): raise NotImplementedError() @@ -184,6 +203,11 @@ class PropRegistry(PathRegistry): self.parent = parent self.path = parent.path + (prop,) + def __str__(self): + return " -> ".join( + str(elem) for elem in self.path + ) + @util.memoized_property def has_entity(self): return hasattr(self.prop, "mapper") diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 114b79ea5..ff5dda7b3 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -1,5 +1,5 @@ # orm/persistence.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -15,15 +15,114 @@ in unitofwork.py. """ import operator -from itertools import groupby -from .. import sql, util, exc as sa_exc, schema +from itertools import groupby, chain +from .. import sql, util, exc as sa_exc from . import attributes, sync, exc as orm_exc, evaluator from .base import state_str, _attr_as_key, _entity_descriptor from ..sql import expression +from ..sql.base import _from_objects from . 
import loading -def save_obj(base_mapper, states, uowtransaction, single=False): +def _bulk_insert( + mapper, mappings, session_transaction, isstates, return_defaults): + base_mapper = mapper.base_mapper + + cached_connections = _cached_connection_dict(base_mapper) + + if session_transaction.session.connection_callable: + raise NotImplementedError( + "connection_callable / per-instance sharding " + "not supported in bulk_insert()") + + if isstates: + if return_defaults: + states = [(state, state.dict) for state in mappings] + mappings = [dict_ for (state, dict_) in states] + else: + mappings = [state.dict for state in mappings] + else: + mappings = list(mappings) + + connection = session_transaction.connection(base_mapper) + for table, super_mapper in base_mapper._sorted_tables.items(): + if not mapper.isa(super_mapper): + continue + + records = ( + (None, state_dict, params, mapper, + connection, value_params, has_all_pks, has_all_defaults) + for + state, state_dict, params, mp, + conn, value_params, has_all_pks, + has_all_defaults in _collect_insert_commands(table, ( + (None, mapping, mapper, connection) + for mapping in mappings), + bulk=True, return_defaults=return_defaults + ) + ) + _emit_insert_statements(base_mapper, None, + cached_connections, + super_mapper, table, records, + bookkeeping=return_defaults) + + if return_defaults and isstates: + identity_cls = mapper._identity_class + identity_props = [p.key for p in mapper._identity_key_props] + for state, dict_ in states: + state.key = ( + identity_cls, + tuple([dict_[key] for key in identity_props]) + ) + + +def _bulk_update(mapper, mappings, session_transaction, + isstates, update_changed_only): + base_mapper = mapper.base_mapper + + cached_connections = _cached_connection_dict(base_mapper) + + def _changed_dict(mapper, state): + return dict( + (k, v) + for k, v in state.dict.items() if k in state.committed_state or k + in mapper._primary_key_propkeys + ) + + if isstates: + if update_changed_only: + mappings = [_changed_dict(mapper, state) for state in mappings] + else: + mappings = [state.dict for state in mappings] + else: + mappings = list(mappings) + + if session_transaction.session.connection_callable: + raise NotImplementedError( + "connection_callable / per-instance sharding " + "not supported in bulk_update()") + + connection = session_transaction.connection(base_mapper) + + for table, super_mapper in base_mapper._sorted_tables.items(): + if not mapper.isa(super_mapper): + continue + + records = _collect_update_commands(None, table, ( + (None, mapping, mapper, connection, + (mapping[mapper._version_id_prop.key] + if mapper._version_id_prop else None)) + for mapping in mappings + ), bulk=True) + + _emit_update_statements(base_mapper, None, + cached_connections, + super_mapper, table, records, + bookkeeping=False) + + +def save_obj( + base_mapper, states, uowtransaction, single=False): """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. 
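The new ``_bulk_insert()`` helper bypasses most unit-of-work bookkeeping and backs the bulk operations added to :class:`.Session` in the 1.0 series. A minimal usage sketch of the public ``Session.bulk_insert_mappings()`` entry point follows; the ``User`` mapping and engine URL are hypothetical, supplied only for illustration::

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'user'
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = Session(engine)

    # plain dictionaries, no per-object state tracking; rows with
    # matching key sets are batched into executemany() INSERTs
    session.bulk_insert_mappings(User, [
        {'id': 1, 'name': 'u1'},
        {'id': 2, 'name': 'u2'},
    ])
    session.commit()

Because no per-object state is consulted on this path, the ``bulk=True`` branch visible in ``_collect_insert_commands()`` below skips the None-column backfill and primary-key bookkeeping unless ``return_defaults`` is set.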
@@ -76,17 +175,16 @@ def save_obj(base_mapper, states, uowtransaction, single=False): _finalize_insert_update_commands( base_mapper, uowtransaction, - ( - (state, state_dict, mapper, connection, False) - for state, state_dict, mapper, connection in states_to_insert - ) - ) - _finalize_insert_update_commands( - base_mapper, uowtransaction, - ( - (state, state_dict, mapper, connection, True) - for state, state_dict, mapper, connection, - update_version_id in states_to_update + chain( + ( + (state, state_dict, mapper, connection, False) + for state, state_dict, mapper, connection in states_to_insert + ), + ( + (state, state_dict, mapper, connection, True) + for state, state_dict, mapper, connection, + update_version_id in states_to_update + ) ) ) @@ -261,7 +359,9 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction): state, dict_, mapper, connection, update_version_id) -def _collect_insert_commands(table, states_to_insert): +def _collect_insert_commands( + table, states_to_insert, + bulk=False, return_defaults=False): """Identify sets of values to use in INSERT statements for a list of states. @@ -280,22 +380,26 @@ def _collect_insert_commands(table, states_to_insert): col = propkey_to_col[propkey] if value is None: continue - elif isinstance(value, sql.ClauseElement): + elif not bulk and isinstance(value, sql.ClauseElement): value_params[col.key] = value else: params[col.key] = value - for colkey in mapper._insert_cols_as_none[table].\ - difference(params).difference(value_params): - params[colkey] = None + if not bulk: + for colkey in mapper._insert_cols_as_none[table].\ + difference(params).difference(value_params): + params[colkey] = None - has_all_pks = mapper._pk_keys_by_table[table].issubset(params) + if not bulk or return_defaults: + has_all_pks = mapper._pk_keys_by_table[table].issubset(params) - if mapper.base_mapper.eager_defaults: - has_all_defaults = mapper._server_default_cols[table].\ - issubset(params) + if mapper.base_mapper.eager_defaults: + has_all_defaults = mapper._server_default_cols[table].\ + issubset(params) + else: + has_all_defaults = True else: - has_all_defaults = True + has_all_defaults = has_all_pks = True if mapper.version_id_generator is not False \ and mapper.version_id_col is not None and \ @@ -309,7 +413,9 @@ def _collect_insert_commands(table, states_to_insert): has_all_defaults) -def _collect_update_commands(uowtransaction, table, states_to_update): +def _collect_update_commands( + uowtransaction, table, states_to_update, + bulk=False): """Identify sets of values to use in UPDATE statements for a list of states. 
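The ``bulk=True`` flag being added to ``_collect_update_commands()`` similarly serves the 1.0 ``Session.bulk_update_mappings()`` entry point. Continuing the hypothetical ``User`` mapping from the previous sketch, each dictionary must carry the primary key value to locate the row plus the values to change::

    # primary key values locate the rows; the remaining keys become
    # the SET clause of an executemany() UPDATE
    session.bulk_update_mappings(User, [
        {'id': 1, 'name': 'u1-renamed'},
        {'id': 2, 'name': 'u2-renamed'},
    ])
    session.commit()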
@@ -329,23 +435,32 @@ def _collect_update_commands(uowtransaction, table, states_to_update): pks = mapper._pks_by_table[table] - params = {} value_params = {} propkey_to_col = mapper._propkey_to_col[table] - for propkey in set(propkey_to_col).intersection(state.committed_state): - value = state_dict[propkey] - col = propkey_to_col[propkey] - - if not state.manager[propkey].impl.is_equal( - value, state.committed_state[propkey]): - if isinstance(value, sql.ClauseElement): - value_params[col] = value - else: - params[col.key] = value + if bulk: + params = dict( + (propkey_to_col[propkey].key, state_dict[propkey]) + for propkey in + set(propkey_to_col).intersection(state_dict) + ) + else: + params = {} + for propkey in set(propkey_to_col).intersection( + state.committed_state): + value = state_dict[propkey] + col = propkey_to_col[propkey] + + if not state.manager[propkey].impl.is_equal( + value, state.committed_state[propkey]): + if isinstance(value, sql.ClauseElement): + value_params[col] = value + else: + params[col.key] = value - if update_version_id is not None: + if update_version_id is not None and \ + mapper.version_id_col in mapper._cols_by_table[table]: col = mapper.version_id_col params[col._label] = update_version_id @@ -357,30 +472,39 @@ def _collect_update_commands(uowtransaction, table, states_to_update): if not (params or value_params): continue - pk_params = {} - for col in pks: - propkey = mapper._columntoproperty[col].key - history = state.manager[propkey].impl.get_history( - state, state_dict, attributes.PASSIVE_OFF) - - if history.added: - if not history.deleted or \ - ("pk_cascaded", state, col) in \ - uowtransaction.attributes: - pk_params[col._label] = history.added[0] - params.pop(col.key, None) + if bulk: + pk_params = dict( + (propkey_to_col[propkey]._label, state_dict.get(propkey)) + for propkey in + set(propkey_to_col). 
+ intersection(mapper._pk_keys_by_table[table]) + ) + else: + pk_params = {} + for col in pks: + propkey = mapper._columntoproperty[col].key + + history = state.manager[propkey].impl.get_history( + state, state_dict, attributes.PASSIVE_OFF) + + if history.added: + if not history.deleted or \ + ("pk_cascaded", state, col) in \ + uowtransaction.attributes: + pk_params[col._label] = history.added[0] + params.pop(col.key, None) + else: + # else, use the old value to locate the row + pk_params[col._label] = history.deleted[0] + params[col.key] = history.added[0] else: - # else, use the old value to locate the row - pk_params[col._label] = history.deleted[0] - params[col.key] = history.added[0] - else: - pk_params[col._label] = history.unchanged[0] + pk_params[col._label] = history.unchanged[0] + if pk_params[col._label] is None: + raise orm_exc.FlushError( + "Can't update table %s using NULL for primary " + "key value on column %s" % (table, col)) if params or value_params: - if None in pk_params.values(): - raise orm_exc.FlushError( - "Can't update table using NULL for primary " - "key value") params.update(pk_params) yield ( state, state_dict, params, mapper, @@ -441,23 +565,24 @@ def _collect_delete_commands(base_mapper, uowtransaction, table, state, state_dict, col) if value is None: raise orm_exc.FlushError( - "Can't delete from table " + "Can't delete from table %s " "using NULL for primary " - "key value") + "key value on column %s" % (table, col)) if update_version_id is not None and \ - table.c.contains_column(mapper.version_id_col): + mapper.version_id_col in mapper._cols_by_table[table]: params[mapper.version_id_col.key] = update_version_id yield params, connection def _emit_update_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, update): + cached_connections, mapper, table, update, + bookkeeping=True): """Emit UPDATE statements corresponding to value lists collected by _collect_update_commands().""" needs_version_id = mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col) + mapper.version_id_col in mapper._cols_by_table[table] def update_stmt(): clause = sql.and_() @@ -486,32 +611,42 @@ def _emit_update_statements(base_mapper, uowtransaction, records in groupby( update, lambda rec: ( - rec[4], - tuple(sorted(rec[2])), - bool(rec[5]))): + rec[4], # connection + set(rec[2]), # set of parameter keys + bool(rec[5]))): # whether or not we have "value" parameters rows = 0 records = list(records) + + # TODO: would be super-nice to not have to determine this boolean + # inside the loop here, in the 99.9999% of the time there's only + # one connection in use + assert_singlerow = connection.dialect.supports_sane_rowcount + assert_multirow = assert_singlerow and \ + connection.dialect.supports_sane_multi_rowcount + allow_multirow = not needs_version_id or assert_multirow + if hasvalue: for state, state_dict, params, mapper, \ connection, value_params in records: c = connection.execute( statement.values(value_params), params) - _postfetch( - mapper, - uowtransaction, - table, - state, - state_dict, - c, - c.context.compiled_parameters[0], - value_params) + if bookkeeping: + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params) rows += c.rowcount + check_rowcount = True else: - if needs_version_id and \ - not connection.dialect.supports_sane_multi_rowcount and \ - connection.dialect.supports_sane_rowcount: + if not allow_multirow: + check_rowcount = 
assert_singlerow for state, state_dict, params, mapper, \ connection, value_params in records: c = cached_connections[connection].\ @@ -528,6 +663,12 @@ def _emit_update_statements(base_mapper, uowtransaction, rows += c.rowcount else: multiparams = [rec[2] for rec in records] + + check_rowcount = assert_multirow or ( + assert_singlerow and + len(multiparams) == 1 + ) + c = cached_connections[connection].\ execute(statement, multiparams) @@ -544,7 +685,7 @@ def _emit_update_statements(base_mapper, uowtransaction, c.context.compiled_parameters[0], value_params) - if connection.dialect.supports_sane_rowcount: + if check_rowcount: if rows != len(records): raise orm_exc.StaleDataError( "UPDATE statement on table '%s' expected to " @@ -558,20 +699,23 @@ def _emit_update_statements(base_mapper, uowtransaction, def _emit_insert_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, insert): + cached_connections, mapper, table, insert, + bookkeeping=True): """Emit INSERT statements corresponding to value lists collected by _collect_insert_commands().""" statement = base_mapper._memo(('insert', table), table.insert) for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \ - records in groupby(insert, - lambda rec: (rec[4], - tuple(sorted(rec[2].keys())), - bool(rec[5]), - rec[6], rec[7]) - ): - if \ + records in groupby( + insert, + lambda rec: ( + rec[4], # connection + set(rec[2]), # parameter keys + bool(rec[5]), # whether we have "value" parameters + rec[6], + rec[7])): + if not bookkeeping or \ ( has_all_defaults or not base_mapper.eager_defaults @@ -584,19 +728,20 @@ def _emit_insert_statements(base_mapper, uowtransaction, c = cached_connections[connection].\ execute(statement, multiparams) - for (state, state_dict, params, mapper_rec, - conn, value_params, has_all_pks, has_all_defaults), \ - last_inserted_params in \ - zip(records, c.context.compiled_parameters): - _postfetch( - mapper_rec, - uowtransaction, - table, - state, - state_dict, - c, - last_inserted_params, - value_params) + if bookkeeping: + for (state, state_dict, params, mapper_rec, + conn, value_params, has_all_pks, has_all_defaults), \ + last_inserted_params in \ + zip(records, c.context.compiled_parameters): + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + last_inserted_params, + value_params) else: if not has_all_defaults and base_mapper.eager_defaults: @@ -657,7 +802,10 @@ def _emit_post_update_statements(base_mapper, uowtransaction, # also group them into common (connection, cols) sets # to support executemany(). 
for key, grouper in groupby( - update, lambda rec: (rec[1], sorted(rec[0])) + update, lambda rec: ( + rec[1], # connection + set(rec[0]) # parameter keys + ) ): connection = key[0] multiparams = [params for params, conn in grouper] @@ -671,7 +819,7 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, by _collect_delete_commands().""" need_version_id = mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col) + mapper.version_id_col in mapper._cols_by_table[table] def delete_stmt(): clause = sql.and_() @@ -693,12 +841,9 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, statement = base_mapper._memo(('delete', table), delete_stmt) for connection, recs in groupby( delete, - lambda rec: rec[1] + lambda rec: rec[1] # connection ): - del_objects = [ - params - for params, connection in recs - ] + del_objects = [params for params, connection in recs] connection = cached_connections[connection] @@ -775,9 +920,8 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): toload_now.extend(state._unloaded_non_object) elif mapper.version_id_col is not None and \ mapper.version_id_generator is False: - prop = mapper._columntoproperty[mapper.version_id_col] - if prop.key in state.unloaded: - toload_now.extend([prop.key]) + if mapper._version_id_prop.key in state.unloaded: + toload_now.extend([mapper._version_id_prop.key]) if toload_now: state.key = base_mapper._identity_key_from_state(state) @@ -794,7 +938,7 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): def _postfetch(mapper, uowtransaction, table, - state, dict_, result, params, value_params): + state, dict_, result, params, value_params, bulk=False): """Expire attributes in need of newly persisted database state, after an INSERT or UPDATE statement has proceeded for that state.""" @@ -803,7 +947,8 @@ def _postfetch(mapper, uowtransaction, table, postfetch_cols = result.context.compiled.postfetch returning_cols = result.context.compiled.returning - if mapper.version_id_col is not None: + if mapper.version_id_col is not None and \ + mapper.version_id_col in mapper._cols_by_table[table]: prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] if returning_cols: @@ -829,10 +974,13 @@ def _postfetch(mapper, uowtransaction, table, # TODO: this still goes a little too often. 
would be nice to # have definitive list of "columns that changed" here for m, equated_pairs in mapper._table_to_equated[table]: - sync.populate(state, m, state, m, - equated_pairs, - uowtransaction, - mapper.passive_updates) + if state is None: + sync.bulk_populate_inherit_keys(dict_, m, equated_pairs) + else: + sync.populate(state, m, state, m, + equated_pairs, + uowtransaction, + mapper.passive_updates) def _connections_for_states(base_mapper, uowtransaction, states): @@ -883,6 +1031,27 @@ class BulkUD(object): def __init__(self, query): self.query = query.enable_eagerloads(False) + self.mapper = self.query._bind_mapper() + self._validate_query_state() + + def _validate_query_state(self): + for attr, methname, notset in ( + ('_limit', 'limit()', None), + ('_offset', 'offset()', None), + ('_order_by', 'order_by()', False), + ('_group_by', 'group_by()', False), + ('_distinct', 'distinct()', False), + ( + '_from_obj', + 'join(), outerjoin(), select_from(), or from_self()', + ()) + ): + if getattr(self.query, attr) is not notset: + raise sa_exc.InvalidRequestError( + "Can't call Query.update() or Query.delete() " + "when %s has been called" % + (methname, ) + ) @property def session(self): @@ -907,18 +1076,34 @@ class BulkUD(object): self._do_post_synchronize() self._do_post() - def _do_pre(self): + @util.dependencies("sqlalchemy.orm.query") + def _do_pre(self, querylib): query = self.query - self.context = context = query._compile_context() - if len(context.statement.froms) != 1 or \ - not isinstance(context.statement.froms[0], schema.Table): + self.context = querylib.QueryContext(query) + + if isinstance(query._entities[0], querylib._ColumnEntity): + # check for special case of query(table) + tables = set() + for ent in query._entities: + if not isinstance(ent, querylib._ColumnEntity): + tables.clear() + break + else: + tables.update(_from_objects(ent.column)) + if len(tables) != 1: + raise sa_exc.InvalidRequestError( + "This operation requires only one Table or " + "entity be specified as the target." + ) + else: + self.primary_table = tables.pop() + + else: self.primary_table = query._only_entity_zero( "This operation requires only one Table or " "entity be specified as the target." 
).mapper.local_table - else: - self.primary_table = context.statement.froms[0] session = query.session @@ -973,10 +1158,12 @@ class BulkFetch(BulkUD): def _do_pre_synchronize(self): query = self.query session = query.session - select_stmt = self.context.statement.with_only_columns( + context = query._compile_context() + select_stmt = context.statement.with_only_columns( self.primary_table.primary_key) self.matched_rows = session.execute( select_stmt, + mapper=self.mapper, params=query._params).fetchall() @@ -985,9 +1172,7 @@ class BulkUpdate(BulkUD): def __init__(self, query, values): super(BulkUpdate, self).__init__(query) - self.query._no_select_modifiers("update") self.values = values - self.mapper = self.query._mapper_zero_or_none() @classmethod def factory(cls, query, synchronize_session, values): @@ -1033,7 +1218,8 @@ class BulkUpdate(BulkUD): self.context.whereclause, values) self.result = self.query.session.execute( - update_stmt, params=self.query._params) + update_stmt, params=self.query._params, + mapper=self.mapper) self.rowcount = self.result.rowcount def _do_post(self): @@ -1046,7 +1232,6 @@ class BulkDelete(BulkUD): def __init__(self, query): super(BulkDelete, self).__init__(query) - self.query._no_select_modifiers("delete") @classmethod def factory(cls, query, synchronize_session): @@ -1060,8 +1245,10 @@ class BulkDelete(BulkUD): delete_stmt = sql.delete(self.primary_table, self.context.whereclause) - self.result = self.query.session.execute(delete_stmt, - params=self.query._params) + self.result = self.query.session.execute( + delete_stmt, + params=self.query._params, + mapper=self.mapper) self.rowcount = self.result.rowcount def _do_post(self): diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 62ea93fb3..5694f7255 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -1,5 +1,5 @@ # orm/properties.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -34,6 +34,13 @@ class ColumnProperty(StrategizedProperty): strategy_wildcard_key = 'column' + __slots__ = ( + '_orig_columns', 'columns', 'group', 'deferred', + 'instrument', 'comparator_factory', 'descriptor', 'extension', + 'active_history', 'expire_on_flush', 'info', 'doc', + 'strategy_class', '_creation_order', '_is_polymorphic_discriminator', + '_mapped_by_synonym', '_deferred_loader') + def __init__(self, *columns, **kwargs): """Provide a column-level property for use with a Mapper. @@ -109,6 +116,7 @@ class ColumnProperty(StrategizedProperty): **Deprecated.** Please see :class:`.AttributeEvents`. """ + super(ColumnProperty, self).__init__() self._orig_columns = [expression._labeled(c) for c in columns] self.columns = [expression._labeled(_orm_full_deannotate(c)) for c in columns] @@ -149,6 +157,12 @@ class ColumnProperty(StrategizedProperty): ("instrument", self.instrument) ) + @util.dependencies("sqlalchemy.orm.state", "sqlalchemy.orm.strategies") + def _memoized_attr__deferred_column_loader(self, state, strategies): + return state.InstanceState._instance_level_callable_processor( + self.parent.class_manager, + strategies.LoadDeferredColumns(self.key), self.key) + @property def expression(self): """Return the primary column or expression for this ColumnProperty. 
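The ``_memoized_attr__deferred_column_loader`` hook above, like the ``util.MemoizedSlots`` mixin adopted by the ``Comparator`` in the next hunk, replaces ``@util.memoized_property`` on classes that now declare ``__slots__``: a slotted instance has no ``__dict__`` for a memoized property to write into, so memoization is driven by a naming convention instead. A minimal sketch of the idea (an illustrative simplification, not SQLAlchemy's actual ``util.MemoizedSlots``, which also supports ``_memoized_method_`` names and a ``_fallback_getattr()`` hook)::

    class MemoizedSlots(object):
        # __getattr__ fires only while the slot is still unfilled
        __slots__ = ()

        def __getattr__(self, key):
            if key.startswith('_memoized'):
                raise AttributeError(key)
            fn = getattr(self.__class__, '_memoized_attr_%s' % key, None)
            if fn is not None:
                value = fn(self)
                # fill the slot; subsequent reads bypass __getattr__
                setattr(self, key, value)
                return value
            raise AttributeError(key)

    class Point(MemoizedSlots):
        __slots__ = ('x', 'y', 'length')

        def __init__(self, x, y):
            self.x, self.y = x, y

        def _memoized_attr_length(self):
            # computed once, on first access of .length
            return (self.x ** 2 + self.y ** 2) ** 0.5

The computed value lands directly in the declared slot, so the per-instance memory savings of ``__slots__`` are preserved while keeping lazy one-time computation.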
@@ -206,7 +220,7 @@ class ColumnProperty(StrategizedProperty): elif dest_state.has_identity and self.key not in dest_dict: dest_state._expire_attributes(dest_dict, [self.key]) - class Comparator(PropComparator): + class Comparator(util.MemoizedSlots, PropComparator): """Produce boolean, comparison, and other operators for :class:`.ColumnProperty` attributes. @@ -224,24 +238,25 @@ class ColumnProperty(StrategizedProperty): :attr:`.TypeEngine.comparator_factory` """ - @util.memoized_instancemethod - def __clause_element__(self): + + __slots__ = '__clause_element__', 'info' + + def _memoized_method___clause_element__(self): if self.adapter: return self.adapter(self.prop.columns[0]) else: return self.prop.columns[0]._annotate({ - "parententity": self._parentmapper, - "parentmapper": self._parentmapper}) + "parententity": self._parententity, + "parentmapper": self._parententity}) - @util.memoized_property - def info(self): + def _memoized_attr_info(self): ce = self.__clause_element__() try: return ce.info except AttributeError: return self.prop.info - def __getattr__(self, key): + def _fallback_getattr(self, key): """proxy attribute access down to the mapped column. this allows user-defined comparison methods to be accessed. diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index fce7a3665..9aa2e3d99 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1,5 +1,5 @@ # orm/query.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -26,7 +26,7 @@ from . import ( exc as orm_exc, loading ) from .base import _entity_descriptor, _is_aliased_class, \ - _is_mapped_class, _orm_columns, _generative + _is_mapped_class, _orm_columns, _generative, InspectionAttr from .path_registry import PathRegistry from .util import ( AliasedClass, ORMAdapter, join as orm_join, with_parent, aliased @@ -75,6 +75,7 @@ class Query(object): _having = None _distinct = False _prefixes = None + _suffixes = None _offset = None _limit = None _for_update_arg = None @@ -99,7 +100,8 @@ class Query(object): _with_options = () _with_hints = () _enable_single_crit = True - + _orm_only_adapt = True + _orm_only_from_obj_alias = True _current_path = _path_registry def __init__(self, entities, session=None): @@ -159,7 +161,6 @@ class Query(object): for from_obj in obj: info = inspect(from_obj) - if hasattr(info, 'mapper') and \ (info.is_mapper or info.is_aliased_class): self._select_from_entity = from_obj @@ -231,7 +232,8 @@ class Query(object): adapters = [] # do we adapt all expression elements or only those # tagged as 'ORM' constructs ? - orm_only = getattr(self, '_orm_only_adapt', orm_only) + if not self._orm_only_adapt: + orm_only = False if as_filter and self._filter_aliases: for fa in self._filter_aliases._visitor_iterator: @@ -248,7 +250,7 @@ class Query(object): # to all SQL constructs. 
adapters.append( ( - getattr(self, '_orm_only_from_obj_alias', orm_only), + orm_only if self._orm_only_from_obj_alias else False, self._from_obj_alias.replace ) ) @@ -285,8 +287,9 @@ class Query(object): return self._entities[0] def _mapper_zero(self): - return self._select_from_entity or \ - self._entity_zero().entity_zero + return self._select_from_entity \ + if self._select_from_entity is not None \ + else self._entity_zero().entity_zero @property def _mapper_entities(self): @@ -300,11 +303,14 @@ class Query(object): self._mapper_zero() ) - def _mapper_zero_or_none(self): - if self._primary_entity: - return self._primary_entity.mapper - else: - return None + def _bind_mapper(self): + ezero = self._mapper_zero() + if ezero is not None: + insp = inspect(ezero) + if not insp.is_clause_element: + return insp.mapper + + return None def _only_mapper_zero(self, rationale=None): if len(self._entities) > 1: @@ -393,22 +399,6 @@ class Query(object): % (meth, meth) ) - def _no_select_modifiers(self, meth): - if not self._enable_assertions: - return - for attr, methname, notset in ( - ('_limit', 'limit()', None), - ('_offset', 'offset()', None), - ('_order_by', 'order_by()', False), - ('_group_by', 'group_by()', False), - ('_distinct', 'distinct()', False), - ): - if getattr(self, attr) is not notset: - raise sa_exc.InvalidRequestError( - "Can't call Query.%s() when %s has been called" % - (meth, methname) - ) - def _get_options(self, populate_existing=None, version_check=None, only_load_props=None, @@ -810,7 +800,7 @@ class Query(object): foreign-key-to-primary-key criterion, will also use an operation equivalent to :meth:`~.Query.get` in order to retrieve the target value from the local identity map - before querying the database. See :doc:`/orm/loading` + before querying the database. See :doc:`/orm/loading_relationships` for further details on relationship loading. :param ident: A scalar or tuple value representing @@ -825,7 +815,9 @@ class Query(object): :return: The object instance, or ``None``. """ + return self._get_impl(ident, loading.load_on_ident) + def _get_impl(self, ident, fallback_fn): # convert composite types to individual args if hasattr(ident, '__composite_values__'): ident = ident.__composite_values__() @@ -856,7 +848,7 @@ class Query(object): return None return instance - return loading.load_on_ident(self, key) + return fallback_fn(self, key) @_generative() def correlate(self, *args): @@ -987,6 +979,7 @@ class Query(object): statement.correlate(None) q = self._from_selectable(fromclause) q._enable_single_crit = False + q._select_from_entity = self._mapper_zero() if entities: q._set_entities(entities) return q @@ -1003,7 +996,7 @@ class Query(object): '_limit', '_offset', '_joinpath', '_joinpoint', '_distinct', '_having', - '_prefixes', + '_prefixes', '_suffixes' ): self.__dict__.pop(attr, None) self._set_select_from([fromclause], True) @@ -1099,7 +1092,7 @@ class Query(object): Most supplied options regard changing how column- and relationship-mapped attributes are loaded. See the sections - :ref:`deferred` and :doc:`/orm/loading` for reference + :ref:`deferred` and :doc:`/orm/loading_relationships` for reference documentation. """ @@ -1740,6 +1733,14 @@ class Query(object): anonymously aliased. Subsequent calls to :meth:`~.Query.filter` and similar will adapt the incoming criterion to the target alias, until :meth:`~.Query.reset_joinpoint` is called. 
+ :param isouter=False: If True, the join used will be a left outer join, + just as if the :meth:`.Query.outerjoin` method were called. This + flag is here to maintain consistency with the same flag as accepted + by :meth:`.FromClause.join` and other Core constructs. + + + .. versionadded:: 1.0.0 + :param from_joinpoint=False: When using ``aliased=True``, a setting of True here will cause the join to be from the most recent joined target, rather than starting back from the original @@ -1757,13 +1758,14 @@ class Query(object): SQLAlchemy versions was the primary ORM-level joining interface. """ - aliased, from_joinpoint = kwargs.pop('aliased', False),\ - kwargs.pop('from_joinpoint', False) + aliased, from_joinpoint, isouter = kwargs.pop('aliased', False),\ + kwargs.pop('from_joinpoint', False),\ + kwargs.pop('isouter', False) if kwargs: raise TypeError("unknown arguments: %s" % - ','.join(kwargs.keys)) + ', '.join(sorted(kwargs))) return self._join(props, - outerjoin=False, create_aliases=aliased, + outerjoin=isouter, create_aliases=aliased, from_joinpoint=from_joinpoint) def outerjoin(self, *props, **kwargs): @@ -1777,7 +1779,7 @@ class Query(object): kwargs.pop('from_joinpoint', False) if kwargs: raise TypeError("unknown arguments: %s" % - ','.join(kwargs)) + ', '.join(sorted(kwargs))) return self._join(props, outerjoin=True, create_aliases=aliased, from_joinpoint=from_joinpoint) @@ -1835,6 +1837,11 @@ class Query(object): left_entity = prop = None + if isinstance(onclause, interfaces.PropComparator): + of_type = getattr(onclause, '_of_type', None) + else: + of_type = None + if isinstance(onclause, util.string_types): left_entity = self._joinpoint_zero() @@ -1861,8 +1868,6 @@ class Query(object): if isinstance(onclause, interfaces.PropComparator): if right_entity is None: - right_entity = onclause.property.mapper - of_type = getattr(onclause, '_of_type', None) if of_type: right_entity = of_type else: @@ -1944,11 +1949,9 @@ class Query(object): from_obj, r_info.selectable): overlap = True break - elif sql_util.selectables_overlap(l_info.selectable, - r_info.selectable): - overlap = True - if overlap and l_info.selectable is r_info.selectable: + if (overlap or not create_aliases) and \ + l_info.selectable is r_info.selectable: raise sa_exc.InvalidRequestError( "Can't join table/selectable '%s' to itself" % l_info.selectable) @@ -2348,12 +2351,38 @@ class Query(object): .. versionadded:: 0.7.7 + .. seealso:: + + :meth:`.HasPrefixes.prefix_with` + """ if self._prefixes: self._prefixes += prefixes else: self._prefixes = prefixes + @_generative() + def suffix_with(self, *suffixes): + """Apply the suffix to the query and return the newly resulting + ``Query``. + + :param \*suffixes: optional suffixes, typically strings, + not using any commas. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :meth:`.Query.prefix_with` + + :meth:`.HasSuffixes.suffix_with` + + """ + if self._suffixes: + self._suffixes += suffixes + else: + self._suffixes = suffixes + def all(self): """Return the results represented by this ``Query`` as a list. 
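The ``isouter`` flag and the ``suffix_with()`` method added above are small conveniences; a usage sketch, assuming ordinary mapped ``User`` and ``Address`` classes (hypothetical names)::

    # isouter=True makes join() emit a LEFT OUTER JOIN,
    # matching what outerjoin() already produces
    q1 = session.query(User).join(Address, isouter=True)
    q2 = session.query(User).outerjoin(Address)   # equivalent

    # suffix_with() appends raw, dialect-specific text after the
    # rendered SELECT, mirroring prefix_with() at the other end
    q3 = session.query(User.id).suffix_with("FOR SHARE")

As with ``prefix_with()``, repeated calls to ``suffix_with()`` accumulate, and the suffixes render in the order given.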
@@ -2488,7 +2517,7 @@ class Query(object): def _execute_and_instances(self, querycontext): conn = self._connection_from_session( - mapper=self._mapper_zero_or_none(), + mapper=self._bind_mapper(), clause=querycontext.statement, close_with_result=True) @@ -2515,18 +2544,21 @@ class Query(object): 'type':User, 'aliased':False, 'expr':User, + 'entity': User }, { 'name':'id', 'type':Integer(), 'aliased':False, 'expr':User.id, + 'entity': User }, { 'name':'user2', 'type':User, 'aliased':True, - 'expr':user_alias + 'expr':user_alias, + 'entity': user_alias } ] @@ -2536,7 +2568,10 @@ class Query(object): 'name': ent._label_name, 'type': ent.type, 'aliased': getattr(ent, 'is_aliased_class', False), - 'expr': ent.expr + 'expr': ent.expr, + 'entity': + ent.entity_zero.entity if ent.entity_zero is not None + else None } for ent in self._entities ] @@ -2590,6 +2625,7 @@ class Query(object): 'offset': self._offset, 'distinct': self._distinct, 'prefixes': self._prefixes, + 'suffixes': self._suffixes, 'group_by': self._group_by or None, 'having': self._having } @@ -2686,6 +2722,18 @@ class Query(object): Deletes rows matched by this query from the database. + E.g.:: + + sess.query(User).filter(User.age == 25).\\ + delete(synchronize_session=False) + + sess.query(User).filter(User.age == 25).\\ + delete(synchronize_session='evaluate') + + .. warning:: The :meth:`.Query.delete` method is a "bulk" operation, + which bypasses ORM unit-of-work automation in favor of greater + performance. **Please read all caveats and warnings below.** + :param synchronize_session: chooses the strategy for the removal of matched objects from the session. Valid values are: @@ -2704,8 +2752,7 @@ class Query(object): ``'evaluate'`` - Evaluate the query's criteria in Python straight on the objects in the session. If evaluation of the criteria isn't - implemented, an error is raised. In that case you probably - want to use the 'fetch' strategy as a fallback. + implemented, an error is raised. The expression evaluator currently doesn't account for differing string collations between the database and Python. @@ -2713,29 +2760,42 @@ class Query(object): :return: the count of rows matched as returned by the database's "row count" feature. - This method has several key caveats: - - * The method does **not** offer in-Python cascading of relationships - - it is assumed that ON DELETE CASCADE/SET NULL/etc. is configured - for any foreign key references which require it, otherwise the - database may emit an integrity violation if foreign key references - are being enforced. - - After the DELETE, dependent objects in the :class:`.Session` which - were impacted by an ON DELETE may not contain the current - state, or may have been deleted. This issue is resolved once the - :class:`.Session` is expired, - which normally occurs upon :meth:`.Session.commit` or can be forced - by using :meth:`.Session.expire_all`. Accessing an expired object - whose row has been deleted will invoke a SELECT to locate the - row; when the row is not found, an - :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised. - - * The :meth:`.MapperEvents.before_delete` and - :meth:`.MapperEvents.after_delete` - events are **not** invoked from this method. Instead, the - :meth:`.SessionEvents.after_bulk_delete` method is provided to act - upon a mass DELETE of entity rows. + .. warning:: **Additional Caveats for bulk query deletes** + + * The method does **not** offer in-Python cascading of + relationships - it is assumed that ON DELETE CASCADE/SET + NULL/etc. 
is configured for any foreign key references + which require it, otherwise the database may emit an + integrity violation if foreign key references are being + enforced. + + After the DELETE, dependent objects in the + :class:`.Session` which were impacted by an ON DELETE + may not contain the current state, or may have been + deleted. This issue is resolved once the + :class:`.Session` is expired, which normally occurs upon + :meth:`.Session.commit` or can be forced by using + :meth:`.Session.expire_all`. Accessing an expired + object whose row has been deleted will invoke a SELECT + to locate the row; when the row is not found, an + :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is + raised. + + * The ``'fetch'`` strategy results in an additional + SELECT statement emitted and will significantly reduce + performance. + + * The ``'evaluate'`` strategy performs a scan of + all matching objects within the :class:`.Session`; if the + contents of the :class:`.Session` are expired, such as + via a preceding :meth:`.Session.commit` call, **this will + result in SELECT queries emitted for every matching object**. + + * The :meth:`.MapperEvents.before_delete` and + :meth:`.MapperEvents.after_delete` + events **are not invoked** from this method. Instead, the + :meth:`.SessionEvents.after_bulk_delete` method is provided to + act upon a mass DELETE of entity rows. .. seealso:: @@ -2758,17 +2818,21 @@ class Query(object): E.g.:: - sess.query(User).filter(User.age == 25).\ - update({User.age: User.age - 10}, synchronize_session='fetch') + sess.query(User).filter(User.age == 25).\\ + update({User.age: User.age - 10}, synchronize_session=False) - - sess.query(User).filter(User.age == 25).\ + sess.query(User).filter(User.age == 25).\\ + update({"age": User.age - 10}, synchronize_session='evaluate') + .. warning:: The :meth:`.Query.update` method is a "bulk" operation, + which bypasses ORM unit-of-work automation in favor of greater + performance. **Please read all caveats and warnings below.** + + :param values: a dictionary with attribute names, or alternatively - mapped attributes or SQL expressions, as keys, and literal - values or sql expressions as values. + mapped attributes or SQL expressions, as keys, and literal + values or sql expressions as values. .. versionchanged:: 1.0.0 - string names in the values dictionary are now resolved against the mapped entity; previously, these @@ -2776,7 +2840,7 @@ translation. :param synchronize_session: chooses the strategy to update the - attributes on objects in the session. Valid values are: + attributes on objects in the session. Valid values are: ``False`` - don't synchronize the session. This option is the most efficient and is reliable once the session is expired, which @@ -2797,43 +2861,56 @@ string collations between the database and Python. :return: the count of rows matched as returned by the database's - "row count" feature. - - This method has several key caveats: - - * The method does **not** offer in-Python cascading of relationships - - it is assumed that ON UPDATE CASCADE is configured for any foreign - key references which require it, otherwise the database may emit an - integrity violation if foreign key references are being enforced. 
- - After the UPDATE, dependent objects in the :class:`.Session` which - were impacted by an ON UPDATE CASCADE may not contain the current - state; this issue is resolved once the :class:`.Session` is expired, - which normally occurs upon :meth:`.Session.commit` or can be forced - by using :meth:`.Session.expire_all`. - - * The method supports multiple table updates, as - detailed in :ref:`multi_table_updates`, and this behavior does - extend to support updates of joined-inheritance and other multiple - table mappings. However, the **join condition of an inheritance - mapper is currently not automatically rendered**. - Care must be taken in any multiple-table update to explicitly - include the joining condition between those tables, even in mappings - where this is normally automatic. - E.g. if a class ``Engineer`` subclasses ``Employee``, an UPDATE of - the ``Engineer`` local table using criteria against the ``Employee`` - local table might look like:: - - session.query(Engineer).\\ - filter(Engineer.id == Employee.id).\\ - filter(Employee.name == 'dilbert').\\ - update({"engineer_type": "programmer"}) - - * The :meth:`.MapperEvents.before_update` and - :meth:`.MapperEvents.after_update` - events are **not** invoked from this method. Instead, the - :meth:`.SessionEvents.after_bulk_update` method is provided to act - upon a mass UPDATE of entity rows. + "row count" feature. + + .. warning:: **Additional Caveats for bulk query updates** + + * The method does **not** offer in-Python cascading of + relationships - it is assumed that ON UPDATE CASCADE is + configured for any foreign key references which require + it, otherwise the database may emit an integrity + violation if foreign key references are being enforced. + + After the UPDATE, dependent objects in the + :class:`.Session` which were impacted by an ON UPDATE + CASCADE may not contain the current state; this issue is + resolved once the :class:`.Session` is expired, which + normally occurs upon :meth:`.Session.commit` or can be + forced by using :meth:`.Session.expire_all`. + + * The ``'fetch'`` strategy results in an additional + SELECT statement emitted and will significantly reduce + performance. + + * The ``'evaluate'`` strategy performs a scan of + all matching objects within the :class:`.Session`; if the + contents of the :class:`.Session` are expired, such as + via a preceding :meth:`.Session.commit` call, **this will + result in SELECT queries emitted for every matching object**. + + * The method supports multiple table updates, as detailed + in :ref:`multi_table_updates`, and this behavior does + extend to support updates of joined-inheritance and + other multiple table mappings. However, the **join + condition of an inheritance mapper is not + automatically rendered**. Care must be taken in any + multiple-table update to explicitly include the joining + condition between those tables, even in mappings where + this is normally automatic. E.g. if a class ``Engineer`` + subclasses ``Employee``, an UPDATE of the ``Engineer`` + local table using criteria against the ``Employee`` + local table might look like:: + + session.query(Engineer).\\ + filter(Engineer.id == Employee.id).\\ + filter(Employee.name == 'dilbert').\\ + update({"engineer_type": "programmer"}) + + * The :meth:`.MapperEvents.before_update` and + :meth:`.MapperEvents.after_update` + events **are not invoked from this method**. Instead, the + :meth:`.SessionEvents.after_bulk_update` method is provided to + act upon a mass UPDATE of entity rows. .. 
seealso:: @@ -2849,6 +2926,12 @@ class Query(object): return update_op.rowcount def _compile_context(self, labels=True): + if self.dispatch.before_compile: + for fn in self.dispatch.before_compile: + new_query = fn(self) + if new_query is not None: + self = new_query + context = QueryContext(self) if context.statement is not None: @@ -2869,10 +2952,8 @@ class Query(object): # "load from explicit FROMs" mode, # i.e. when select_from() or join() is used context.froms = list(context.from_clause) - else: - # "load from discrete FROMs" mode, - # i.e. when each _MappedEntity has its own FROM - context.froms = context.froms + # else "load from discrete FROMs" mode, + # i.e. when each _MappedEntity has its own FROM if self._enable_single_crit: self._adjust_for_single_inheritance(context) @@ -2892,6 +2973,7 @@ class Query(object): context.statement = self._compound_eager_statement(context) else: context.statement = self._simple_statement(context) + return context def _compound_eager_statement(self, context): @@ -3189,25 +3271,21 @@ class _MapperEntity(_QueryEntity): self.mapper._equivalent_columns) if query._primary_entity is self: - _instance = loading.instance_processor( - self.mapper, - context, - result, - self.path, - adapter, - only_load_props=query._only_load_props, - refresh_state=context.refresh_state, - polymorphic_discriminator=self._polymorphic_discriminator - ) + only_load_props = query._only_load_props + refresh_state = context.refresh_state else: - _instance = loading.instance_processor( - self.mapper, - context, - result, - self.path, - adapter, - polymorphic_discriminator=self._polymorphic_discriminator - ) + only_load_props = refresh_state = None + + _instance = loading._instance_processor( + self.mapper, + context, + result, + self.path, + adapter, + only_load_props=only_load_props, + refresh_state=refresh_state, + polymorphic_discriminator=self._polymorphic_discriminator + ) return _instance, self._label_name @@ -3228,41 +3306,19 @@ class _MapperEntity(_QueryEntity): ) ) - if self._with_polymorphic: - poly_properties = self.mapper._iterate_polymorphic_properties( - self._with_polymorphic) - else: - poly_properties = self.mapper._polymorphic_properties - - for value in poly_properties: - if query._only_load_props and \ - value.key not in query._only_load_props: - continue - value.setup( - context, - self, - self.path, - adapter, - only_load_props=query._only_load_props, - column_collection=context.primary_columns - ) - - if self._polymorphic_discriminator is not None and \ - self._polymorphic_discriminator \ - is not self.mapper.polymorphic_on: - - if adapter: - pd = adapter.columns[self._polymorphic_discriminator] - else: - pd = self._polymorphic_discriminator - context.primary_columns.append(pd) + loading._setup_entity_query( + context, self.mapper, self, + self.path, adapter, context.primary_columns, + with_polymorphic=self._with_polymorphic, + only_load_props=query._only_load_props, + polymorphic_discriminator=self._polymorphic_discriminator) def __str__(self): return str(self.mapper) @inspection._self_inspects -class Bundle(object): +class Bundle(InspectionAttr): """A grouping of SQL expressions that are returned by a :class:`.Query` under one namespace. 
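For orientation before the :class:`.Bundle` changes below: a bundle groups several column expressions under one name in each result row; a usage sketch, assuming a mapped ``User`` class (hypothetical)::

    from sqlalchemy.orm import Bundle

    b = Bundle('user_data', User.id, User.name)
    for row in session.query(b).filter(b.c.name.like('e%')):
        # each row carries a single namespaced tuple per bundle
        print(row.user_data.id, row.user_data.name)

Deriving :class:`.Bundle` from ``InspectionAttr`` and stamping the ``is_clause_element`` / ``is_mapper`` / ``is_aliased_class`` flags, as the next hunk does, presumably lets inspection-driven paths such as the new ``_bind_mapper()`` seen earlier test entities uniformly instead of via isinstance checks.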
@@ -3285,6 +3341,12 @@ class Bundle(object): """If True, queries for a single Bundle will be returned as a single entity, rather than an element within a keyed tuple.""" + is_clause_element = False + + is_mapper = False + + is_aliased_class = False + def __init__(self, name, *exprs, **kw): """Construct a new :class:`.Bundle`. @@ -3395,7 +3457,6 @@ class _BundleEntity(_QueryEntity): self.supports_single_entity = self.bundle.single_entity - @property def entity_zero(self): for ent in self._entities: @@ -3453,36 +3514,43 @@ class _ColumnEntity(_QueryEntity): def __init__(self, query, column, namespace=None): self.expr = column self.namespace = namespace + search_entities = True if isinstance(column, util.string_types): column = sql.literal_column(column) self._label_name = column.name + search_entities = False + _entity = None elif isinstance(column, ( attributes.QueryableAttribute, interfaces.PropComparator )): + _entity = getattr(column, '_parententity', None) + if _entity is not None: + search_entities = False self._label_name = column.key column = column._query_clause_element() - else: - self._label_name = getattr(column, 'key', None) - - if not isinstance(column, expression.ColumnElement) and \ - hasattr(column, '_select_iterable'): - for c in column._select_iterable: - if c is column: - break - _ColumnEntity(query, c, namespace=column) - else: + if isinstance(column, Bundle): + _BundleEntity(query, column) return - elif isinstance(column, Bundle): - _BundleEntity(query, column) - return + elif not isinstance(column, sql.ColumnElement): + if hasattr(column, '_select_iterable'): + # break out an object like Table into + # individual columns + for c in column._select_iterable: + if c is column: + break + _ColumnEntity(query, c, namespace=column) + else: + return - if not isinstance(column, sql.ColumnElement): raise sa_exc.InvalidRequestError( "SQL expression, column, or mapped entity " "expected - got '%r'" % (column, ) ) + else: + self._label_name = getattr(column, 'key', None) + search_entities = True self.type = type_ = column.type if type_.hashable: @@ -3513,19 +3581,38 @@ class _ColumnEntity(_QueryEntity): # leaking out their entities into the main select construct self.actual_froms = actual_froms = set(column._from_objects) - self.entities = util.OrderedSet( - elem._annotations['parententity'] - for elem in visitors.iterate(column, {}) - if 'parententity' in elem._annotations - and actual_froms.intersection(elem._from_objects) - ) - - if self.entities: - self.entity_zero = list(self.entities)[0] - elif self.namespace is not None: - self.entity_zero = self.namespace + if not search_entities: + self.entity_zero = _entity + if _entity: + self.entities = [_entity] + else: + self.entities = [] + self._from_entities = set(self.entities) else: - self.entity_zero = None + all_elements = [ + elem for elem in visitors.iterate(column, {}) + if 'parententity' in elem._annotations + ] + + self.entities = util.unique_list([ + elem._annotations['parententity'] + for elem in all_elements + if 'parententity' in elem._annotations + ]) + + self._from_entities = set([ + elem._annotations['parententity'] + for elem in all_elements + if 'parententity' in elem._annotations + and actual_froms.intersection(elem._from_objects) + ]) + + if self.entities: + self.entity_zero = self.entities[0] + elif self.namespace is not None: + self.entity_zero = self.namespace + else: + self.entity_zero = None supports_single_entity = False @@ -3547,7 +3634,9 @@ class _ColumnEntity(_QueryEntity): def setup_entity(self, 
ext_info, aliased_adapter): if 'selectable' not in self.__dict__: self.selectable = ext_info.selectable - self.froms.add(ext_info.selectable) + + if self.actual_froms.intersection(ext_info.selectable._from_objects): + self.froms.add(ext_info.selectable) def corresponds_to(self, entity): # TODO: just returning False here, @@ -3561,12 +3650,11 @@ class _ColumnEntity(_QueryEntity): return not _is_aliased_class(self.entity_zero) and \ entity.common_parent(self.entity_zero) - def _resolve_expr_against_query_aliases(self, query, expr, context): - return query._adapt_clause(expr, False, True) - def row_processor(self, query, context, result): - column = self._resolve_expr_against_query_aliases( - query, self.column, context) + if ('fetch_column', self) in context.attributes: + column = context.attributes[('fetch_column', self)] + else: + column = query._adapt_clause(self.column, False, True) if context.adapter: column = context.adapter.columns[column] @@ -3575,20 +3663,26 @@ class _ColumnEntity(_QueryEntity): return getter, self._label_name def setup_context(self, query, context): - column = self._resolve_expr_against_query_aliases( - query, self.column, context) + column = query._adapt_clause(self.column, False, True) context.froms += tuple(self.froms) context.primary_columns.append(column) + context.attributes[('fetch_column', self)] = column + def __str__(self): return str(self.column) class QueryContext(object): - multi_row_eager_loaders = False - adapter = None - froms = () - for_update = None + __slots__ = ( + 'multi_row_eager_loaders', 'adapter', 'froms', 'for_update', + 'query', 'session', 'autoflush', 'populate_existing', + 'invoke_all_eagers', 'version_check', 'refresh_state', + 'primary_columns', 'secondary_columns', 'eager_order_by', + 'eager_joins', 'create_eager_joins', 'propagate_options', + 'attributes', 'statement', 'from_clause', 'whereclause', + 'order_by', 'labels', '_for_update_arg', 'runid', 'partials' + ) def __init__(self, query): @@ -3605,8 +3699,13 @@ class QueryContext(object): self.whereclause = query._criterion self.order_by = query._order_by + self.multi_row_eager_loaders = False + self.adapter = None + self.froms = () + self.for_update = None self.query = query self.session = query.session + self.autoflush = query._autoflush self.populate_existing = query._populate_existing self.invoke_all_eagers = query._invoke_all_eagers self.version_check = query._version_check diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 56a33742d..e36a644da 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -1,5 +1,5 @@ # orm/relationships.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -16,13 +16,14 @@ and `secondaryjoin` aspects of :func:`.relationship`. from __future__ import absolute_import from .. import sql, util, exc as sa_exc, schema, log +import weakref from .util import CascadeOptions, _orm_annotate, _orm_deannotate from . import dependency from . 
import attributes from ..sql.util import ( ClauseAdapter, join_condition, _shallow_annotate, visit_binary_product, - _deep_deannotate, selectables_overlap + _deep_deannotate, selectables_overlap, adapt_criterion_to_null ) from ..sql import operators, expression, visitors from .interfaces import (MANYTOMANY, MANYTOONE, ONETOMANY, @@ -112,6 +113,7 @@ class RelationshipProperty(StrategizedProperty): active_history=False, cascade_backrefs=True, load_on_pending=False, + bake_queries=True, strategy_class=None, _local_remote_pairs=None, query_class=None, info=None): @@ -273,6 +275,15 @@ class RelationshipProperty(StrategizedProperty): :paramref:`~.relationship.backref` - alternative form of backref specification. + :param bake_queries: + Use the :class:`.BakedQuery` cache to cache queries used in lazy + loads. True by default, as this typically improves performance + significantly. Set to False to reduce ORM memory use, or + if unresolved stability issues are observed with the baked query + cache system. + + .. versionadded:: 1.0.0 + :param cascade: a comma-separated list of cascade rules which determines how Session operations should be "cascaded" from parent to child. @@ -527,7 +538,7 @@ class RelationshipProperty(StrategizedProperty): .. seealso:: - :doc:`/orm/loading` - Full documentation on relationship loader + :doc:`/orm/loading_relationships` - Full documentation on relationship loader configuration. :ref:`dynamic_relationship` - detail on the ``dynamic`` option. @@ -774,6 +785,7 @@ class RelationshipProperty(StrategizedProperty): """ + super(RelationshipProperty, self).__init__() self.uselist = uselist self.argument = argument @@ -800,6 +812,7 @@ class RelationshipProperty(StrategizedProperty): self.join_depth = join_depth self.local_remote_pairs = _local_remote_pairs self.extension = extension + self.bake_queries = bake_queries self.load_on_pending = load_on_pending self.comparator_factory = comparator_factory or \ RelationshipProperty.Comparator @@ -871,13 +884,13 @@ class RelationshipProperty(StrategizedProperty): """ self.prop = prop - self._parentmapper = parentmapper + self._parententity = parentmapper self._adapt_to_entity = adapt_to_entity if of_type: self._of_type = of_type def adapt_to_entity(self, adapt_to_entity): - return self.__class__(self.property, self._parentmapper, + return self.__class__(self.property, self._parententity, adapt_to_entity=adapt_to_entity, of_type=self._of_type) @@ -929,7 +942,7 @@ class RelationshipProperty(StrategizedProperty): """ return RelationshipProperty.Comparator( self.property, - self._parentmapper, + self._parententity, adapt_to_entity=self._adapt_to_entity, of_type=cls) @@ -1289,8 +1302,9 @@ class RelationshipProperty(StrategizedProperty): """ if isinstance(other, (util.NoneType, expression.Null)): if self.property.direction == MANYTOONE: - return sql.or_(*[x != None for x in - self.property._calculated_foreign_keys]) + return _orm_annotate(~self.property._optimized_compare( + None, adapt_source=self.adapter)) + else: return self._criterion_exists() elif self.property.uselist: @@ -1299,7 +1313,7 @@ class RelationshipProperty(StrategizedProperty): " to an object or collection; use " "contains() to test for membership.") else: - return self.__negated_contains_or_equals(other) + return _orm_annotate(self.__negated_contains_or_equals(other)) @util.memoized_property def property(self): @@ -1312,16 +1326,69 @@ class RelationshipProperty(StrategizedProperty): return self._optimized_compare( instance, value_is_parent=True, 
alias_secondary=alias_secondary) - def _optimized_compare(self, value, value_is_parent=False, + def _optimized_compare(self, state, value_is_parent=False, adapt_source=None, alias_secondary=True): - if value is not None: - value = attributes.instance_state(value) - return self._lazy_strategy.lazy_clause( - value, - reverse_direction=not value_is_parent, - alias_secondary=alias_secondary, - adapt_source=adapt_source) + if state is not None: + state = attributes.instance_state(state) + + reverse_direction = not value_is_parent + + if state is None: + return self._lazy_none_clause( + reverse_direction, + adapt_source=adapt_source) + + if not reverse_direction: + criterion, bind_to_col = \ + self._lazy_strategy._lazywhere, \ + self._lazy_strategy._bind_to_col + else: + criterion, bind_to_col = \ + self._lazy_strategy._rev_lazywhere, \ + self._lazy_strategy._rev_bind_to_col + + if reverse_direction: + mapper = self.mapper + else: + mapper = self.parent + + dict_ = attributes.instance_dict(state.obj()) + + def visit_bindparam(bindparam): + if bindparam._identifying_key in bind_to_col: + bindparam.callable = \ + lambda: mapper._get_state_attr_by_column( + state, dict_, + bind_to_col[bindparam._identifying_key]) + + if self.secondary is not None and alias_secondary: + criterion = ClauseAdapter( + self.secondary.alias()).\ + traverse(criterion) + + criterion = visitors.cloned_traverse( + criterion, {}, {'bindparam': visit_bindparam}) + + if adapt_source: + criterion = adapt_source(criterion) + return criterion + + def _lazy_none_clause(self, reverse_direction=False, adapt_source=None): + if not reverse_direction: + criterion, bind_to_col = \ + self._lazy_strategy._lazywhere, \ + self._lazy_strategy._bind_to_col + else: + criterion, bind_to_col = \ + self._lazy_strategy._rev_lazywhere, \ + self._lazy_strategy._rev_bind_to_col + + criterion = adapt_criterion_to_null(criterion, bind_to_col) + + if adapt_source: + criterion = adapt_source(criterion) + return criterion def __str__(self): return str(self.parent.class_.__name__) + "." + self.key @@ -1532,6 +1599,7 @@ class RelationshipProperty(StrategizedProperty): self._check_cascade_settings(self._cascade) self._post_init() self._generate_backref() + self._join_condition._warn_for_conflicting_sync_targets() super(RelationshipProperty, self).do_init() self._lazy_strategy = self._get_strategy((("lazy", "select"),)) @@ -2519,6 +2587,60 @@ class JoinCondition(object): self.secondary_synchronize_pairs = \ self._deannotate_pairs(secondary_sync_pairs) + _track_overlapping_sync_targets = weakref.WeakKeyDictionary() + + def _warn_for_conflicting_sync_targets(self): + if not self.support_sync: + return + + # we would like to detect if we are synchronizing any column + # pairs in conflict with another relationship that wishes to sync + # an entirely different column to the same target. This is a + # very rare edge case so we will try to minimize the memory/overhead + # impact of this check + for from_, to_ in [ + (from_, to_) for (from_, to_) in self.synchronize_pairs + ] + [ + (from_, to_) for (from_, to_) in self.secondary_synchronize_pairs + ]: + # save ourselves a ton of memory and overhead by only + # considering columns that are subject to overlapping + # FK constraints at the core level. This condition can arise + # if multiple relationships overlap foreign() directly, but + # we're going to assume it's typically a ForeignKeyConstraint- + # level configuration that benefits from this warning. 
+ if len(to_.foreign_keys) < 2: + continue + + if to_ not in self._track_overlapping_sync_targets: + self._track_overlapping_sync_targets[to_] = \ + weakref.WeakKeyDictionary({self.prop: from_}) + else: + other_props = [] + prop_to_from = self._track_overlapping_sync_targets[to_] + for pr, fr_ in prop_to_from.items(): + if pr.mapper in mapperlib._mapper_registry and \ + fr_ is not from_ and \ + pr not in self.prop._reverse_property: + other_props.append((pr, fr_)) + + if other_props: + util.warn( + "relationship '%s' will copy column %s to column %s, " + "which conflicts with relationship(s): %s. " + "Consider applying " + "viewonly=True to read-only relationships, or provide " + "a primaryjoin condition marking writable columns " + "with the foreign() annotation." % ( + self.prop, + from_, to_, + ", ".join( + "'%s' (copies %s to %s)" % (pr, fr_, to_) + for (pr, fr_) in other_props) + ) + ) + self._track_overlapping_sync_targets[to_][self.prop] = from_ + @util.memoized_property def remote_columns(self): return self._gather_join_annotations("remote") @@ -2635,27 +2757,31 @@ class JoinCondition(object): def create_lazy_clause(self, reverse_direction=False): binds = util.column_dict() - lookup = collections.defaultdict(list) equated_columns = util.column_dict() - if reverse_direction and self.secondaryjoin is None: - for l, r in self.local_remote_pairs: - lookup[r].append((r, l)) - equated_columns[l] = r - else: - # replace all "local side" columns, which is - # anything that isn't marked "remote" + has_secondary = self.secondaryjoin is not None + + if has_secondary: + lookup = collections.defaultdict(list) for l, r in self.local_remote_pairs: lookup[l].append((l, r)) equated_columns[r] = l + elif not reverse_direction: + for l, r in self.local_remote_pairs: + equated_columns[r] = l + else: + for l, r in self.local_remote_pairs: + equated_columns[l] = r def col_to_bind(col): - if (reverse_direction and col in lookup) or \ - (not reverse_direction and "local" in col._annotations): - if col in lookup: - for tobind, equated in lookup[col]: - if equated in binds: - return None + + if ( + (not reverse_direction and 'local' in col._annotations) or + reverse_direction and ( + (has_secondary and col in lookup) or + (not has_secondary and 'remote' in col._annotations) + ) + ): if col not in binds: binds[col] = sql.bindparam( None, None, type_=col.type, unique=True) diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 71648d126..b3f2fa5db 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -1,5 +1,5 @@ # orm/scoping.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index db9d3a51d..4619027e5 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1,5 +1,5 @@ # orm/session.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -20,6 +20,8 @@ from .base import ( _class_to_mapper, _state_mapper, object_state, _none_set, state_str, instance_str ) +import itertools +from . import persistence from .unitofwork import UOWTransaction from . 
import state as statelib import sys @@ -224,10 +226,10 @@ class SessionTransaction(object): def _is_transaction_boundary(self): return self.nested or not self._parent - def connection(self, bindkey, **kwargs): + def connection(self, bindkey, execution_options=None, **kwargs): self._assert_active() bind = self.session.get_bind(bindkey, **kwargs) - return self._connection_for_bind(bind) + return self._connection_for_bind(bind, execution_options) def _begin(self, nested=False): self._assert_active() @@ -235,14 +237,21 @@ class SessionTransaction(object): self.session, self, nested=nested) def _iterate_parents(self, upto=None): - if self._parent is upto: - return (self,) - else: - if self._parent is None: + + current = self + result = () + while current: + result += (current, ) + if current._parent is upto: + break + elif current._parent is None: raise sa_exc.InvalidRequestError( "Transaction %s is not on the active transaction list" % ( upto)) - return (self,) + self._parent._iterate_parents(upto) + else: + current = current._parent + + return result def _take_snapshot(self): if not self._is_transaction_boundary: @@ -291,22 +300,27 @@ class SessionTransaction(object): if not self.nested and self.session.expire_on_commit: for s in self.session.identity_map.all_states(): s._expire(s.dict, self.session.identity_map._modified) - for s in self._deleted: - s.session_id = None + for s in list(self._deleted): + s._detach() self._deleted.clear() elif self.nested: self._parent._new.update(self._new) + self._parent._dirty.update(self._dirty) self._parent._deleted.update(self._deleted) self._parent._key_switches.update(self._key_switches) - def _connection_for_bind(self, bind): + def _connection_for_bind(self, bind, execution_options): self._assert_active() if bind in self._connections: + if execution_options: + util.warn( + "Connection is already established for the " + "given bind; execution_options ignored") return self._connections[bind][0] if self._parent: - conn = self._parent._connection_for_bind(bind) + conn = self._parent._connection_for_bind(bind, execution_options) if not self.nested: return conn else: @@ -319,6 +333,9 @@ class SessionTransaction(object): else: conn = bind.contextual_connect() + if execution_options: + conn = conn.execution_options(**execution_options) + if self.session.twophase and self._parent is None: transaction = conn.begin_twophase() elif self.nested: @@ -395,26 +412,29 @@ class SessionTransaction(object): for subtransaction in stx._iterate_parents(upto=self): subtransaction.close() + boundary = self if self._state in (ACTIVE, PREPARED): for transaction in self._iterate_parents(): if transaction._parent is None or transaction.nested: transaction._rollback_impl() transaction._state = DEACTIVE + boundary = transaction break else: transaction._state = DEACTIVE sess = self.session - if self.session._enable_transaction_accounting and \ + if sess._enable_transaction_accounting and \ not sess._is_clean(): + # if items were added, deleted, or mutated # here, we need to re-restore the snapshot util.warn( "Session's state has been changed on " "a non-active transaction - this state " "will be discarded.") - self._restore_snapshot(dirty_only=self.nested) + boundary._restore_snapshot(dirty_only=boundary.nested) self.close() if self._parent and _capture_exception: @@ -433,11 +453,13 @@ class SessionTransaction(object): self.session.dispatch.after_rollback(self.session) - def close(self): + def close(self, invalidate=False): self.session.transaction = self._parent if self._parent is 
None: for connection, transaction, autoclose in \ set(self._connections.values()): + if invalidate: + connection.invalidate() if autoclose: connection.close() else: @@ -482,7 +504,8 @@ class Session(_SessionClassMethods): '__contains__', '__iter__', 'add', 'add_all', 'begin', 'begin_nested', 'close', 'commit', 'connection', 'delete', 'execute', 'expire', 'expire_all', 'expunge', 'expunge_all', 'flush', 'get_bind', - 'is_modified', + 'is_modified', 'bulk_save_objects', 'bulk_insert_mappings', + 'bulk_update_mappings', 'merge', 'query', 'refresh', 'rollback', 'scalar') @@ -591,8 +614,8 @@ class Session(_SessionClassMethods): .. versionadded:: 0.9.0 :param query_cls: Class which should be used to create new Query - objects, as returned by the :meth:`~.Session.query` method. - Defaults to :class:`.Query`. + objects, as returned by the :meth:`~.Session.query` method. + Defaults to :class:`.Query`. :param twophase: When ``True``, all transactions will be started as a "two phase" transaction, i.e. using the "two phase" semantics @@ -788,6 +811,7 @@ class Session(_SessionClassMethods): def connection(self, mapper=None, clause=None, bind=None, close_with_result=False, + execution_options=None, **kw): """Return a :class:`.Connection` object corresponding to this :class:`.Session` object's transactional state. @@ -832,6 +856,18 @@ class Session(_SessionClassMethods): configured with ``autocommit=True`` and does not already have a transaction in progress. + :param execution_options: a dictionary of execution options that will + be passed to :meth:`.Connection.execution_options`, **when the + connection is first procured only**. If the connection is already + present within the :class:`.Session`, a warning is emitted and + the arguments are ignored. + + .. versionadded:: 0.9.9 + + .. seealso:: + + :ref:`session_transaction_isolation` + :param \**kw: Additional keyword arguments are sent to :meth:`get_bind()`, allowing additional arguments to be passed to custom @@ -842,13 +878,18 @@ class Session(_SessionClassMethods): bind = self.get_bind(mapper, clause=clause, **kw) return self._connection_for_bind(bind, - close_with_result=close_with_result) + close_with_result=close_with_result, + execution_options=execution_options) - def _connection_for_bind(self, engine, **kwargs): + def _connection_for_bind(self, engine, execution_options=None, **kw): if self.transaction is not None: - return self.transaction._connection_for_bind(engine) + return self.transaction._connection_for_bind( + engine, execution_options) else: - return engine.contextual_connect(**kwargs) + conn = engine.contextual_connect(**kw) + if execution_options: + conn = conn.execution_options(**execution_options) + return conn def execute(self, clause, params=None, mapper=None, bind=None, **kw): """Execute a SQL expression construct or string statement within @@ -997,10 +1038,46 @@ class Session(_SessionClassMethods): not use any connection resources until they are first needed. """ + self._close_impl(invalidate=False) + + def invalidate(self): + """Close this Session, using connection invalidation. + + This is a variant of :meth:`.Session.close` that will additionally + ensure that the :meth:`.Connection.invalidate` method will be called + on all :class:`.Connection` objects. This can be called when + the database is known to be in a state where the connections are + no longer safe to be used. 
+ + E.g.:: + + try: + sess = Session() + sess.add(User()) + sess.commit() + except gevent.Timeout: + sess.invalidate() + raise + except: + sess.rollback() + raise + + This clears all items and ends any transaction in progress. + + If this session were created with ``autocommit=False``, a new + transaction is immediately begun. Note that this new transaction does + not use any connection resources until they are first needed. + + .. versionadded:: 0.9.9 + + """ + self._close_impl(invalidate=True) + + def _close_impl(self, invalidate): self.expunge_all() if self.transaction is not None: for transaction in self.transaction._iterate_parents(): - transaction.close() + transaction.close(invalidate) def expunge_all(self): """Remove all object instances from this ``Session``. @@ -1409,6 +1486,7 @@ class Session(_SessionClassMethods): state._detach() elif self.transaction: self.transaction._deleted.pop(state, None) + state._detach() def _register_newly_persistent(self, states): for state in states: @@ -1753,7 +1831,7 @@ class Session(_SessionClassMethods): "function to send this object back to the transient state." % state_str(state) ) - self._before_attach(state) + self._before_attach(state, check_identity_map=False) self._deleted.pop(state, None) if discard_existing: self.identity_map.replace(state) @@ -1833,13 +1911,12 @@ class Session(_SessionClassMethods): self._attach(state, include_before=True) state._load_pending = True - def _before_attach(self, state): + def _before_attach(self, state, check_identity_map=True): if state.session_id != self.hash_key and \ self.dispatch.before_attach: self.dispatch.before_attach(self, state.obj()) - def _attach(self, state, include_before=False): - if state.key and \ + if check_identity_map and state.key and \ state.key in self.identity_map and \ not self.identity_map.contains_state(state): raise sa_exc.InvalidRequestError( @@ -1855,10 +1932,11 @@ class Session(_SessionClassMethods): "(this is '%s')" % (state_str(state), state.session_id, self.hash_key)) + def _attach(self, state, include_before=False): + if state.session_id != self.hash_key: - if include_before and \ - self.dispatch.before_attach: - self.dispatch.before_attach(self, state.obj()) + if include_before: + self._before_attach(state) state.session_id = self.hash_key if state.modified and state._strong_obj is None: state._strong_obj = state.obj() @@ -2043,6 +2121,226 @@ class Session(_SessionClassMethods): with util.safe_reraise(): transaction.rollback(_capture_exception=True) + def bulk_save_objects( + self, objects, return_defaults=False, update_changed_only=True): + """Perform a bulk save of the given list of objects. + + The bulk save feature allows mapped objects to be used as the + source of simple INSERT and UPDATE operations which can be more easily + grouped together into higher performing "executemany" + operations; the extraction of data from the objects is also performed + using a lower-latency process that ignores whether or not attributes + have actually been modified in the case of UPDATEs, and also ignores + SQL expressions. + + The objects as given are not added to the session and no additional + state is established on them, unless the ``return_defaults`` flag + is also set, in which case primary key attributes and server-side + default values will be populated. + + .. versionadded:: 1.0.0 + + .. warning:: + + The bulk save feature allows for a lower-latency INSERT/UPDATE + of rows at the expense of most other unit-of-work features. 
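# --- usage sketch: Session.bulk_save_objects() ---
# A minimal sketch, assuming a mapped class "User" and a Session
# factory "Session".  Objects with an identity key become UPDATEs,
# the rest become INSERTs, batched into executemany() calls; the
# objects are not associated with the Session afterwards.
sess = Session()
sess.bulk_save_objects(
    [User(name="user %d" % i) for i in range(10000)]
)
sess.commit()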
+ Features such as object management, relationship handling, + and SQL clause support are **silently omitted** in favor of raw + INSERT/UPDATES of records. + + **Please read the list of caveats at** :ref:`bulk_operations` + **before using this method, and fully test and confirm the + functionality of all code developed using these systems.** + + :param objects: a list of mapped object instances. The mapped + objects are persisted as is, and are **not** associated with the + :class:`.Session` afterwards. + + For each object, whether the object is sent as an INSERT or an + UPDATE is dependent on the same rules used by the :class:`.Session` + in traditional operation; if the object has the + :attr:`.InstanceState.key` + attribute set, then the object is assumed to be "detached" and + will result in an UPDATE. Otherwise, an INSERT is used. + + In the case of an UPDATE, statements are grouped based on which + attributes have changed, and are thus to be the subject of each + SET clause. If ``update_changed_only`` is False, then all + attributes present within each object are applied to the UPDATE + statement, which may help in allowing the statements to be grouped + together into a larger executemany(), and will also reduce the + overhead of checking history on attributes. + + :param return_defaults: when True, rows that are missing values which + generate defaults, namely integer primary key defaults and sequences, + will be inserted **one at a time**, so that the primary key value + is available. In particular this will allow joined-inheritance + and other multi-table mappings to insert correctly without the need + to provide primary key values ahead of time; however, + :paramref:`.Session.bulk_save_objects.return_defaults` **greatly + reduces the performance gains** of the method overall. + + :param update_changed_only: when True, UPDATE statements are rendered + based on those attributes in each state that have logged changes. + When False, all attributes present are rendered into the SET clause + with the exception of primary key attributes. + + .. seealso:: + + :ref:`bulk_operations` + + :meth:`.Session.bulk_insert_mappings` + + :meth:`.Session.bulk_update_mappings` + + """ + for (mapper, isupdate), states in itertools.groupby( + (attributes.instance_state(obj) for obj in objects), + lambda state: (state.mapper, state.key is not None) + ): + self._bulk_save_mappings( + mapper, states, isupdate, True, + return_defaults, update_changed_only) + + def bulk_insert_mappings(self, mapper, mappings, return_defaults=False): + """Perform a bulk insert of the given list of mapping dictionaries. + + The bulk insert feature allows plain Python dictionaries to be used as + the source of simple INSERT operations which can be more easily + grouped together into higher performing "executemany" + operations. Using dictionaries, there is no "history" or session + state management features in use, reducing latency when inserting + large numbers of simple rows. + + The values within the dictionaries as given are typically passed + without modification into Core :meth:`.Insert` constructs, after + organizing the values within them across the tables to which + the given mapper is mapped. + + .. versionadded:: 1.0.0 + + .. warning:: + + The bulk insert feature allows for a lower-latency INSERT + of rows at the expense of most other unit-of-work features. + Features such as object management, relationship handling, + and SQL clause support are **silently omitted** in favor of raw + INSERT of records. 
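# --- usage sketch: Session.bulk_insert_mappings() ---
# Same assumptions as the sketch above; plain dictionaries keyed by
# mapped attribute name, with no per-object history tracking.  The
# UPDATE counterpart, bulk_update_mappings() (documented below),
# takes the same form, using the primary key entries for the WHERE
# clause.
sess.bulk_insert_mappings(
    User,
    [dict(name="user %d" % i) for i in range(10000)]
)
sess.commit()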
+ + **Please read the list of caveats at** :ref:`bulk_operations` + **before using this method, and fully test and confirm the + functionality of all code developed using these systems.** + + :param mapper: a mapped class, or the actual :class:`.Mapper` object, + representing the single kind of object represented within the mapping + list. + + :param mappings: a list of dictionaries, each one containing the state + of the mapped row to be inserted, in terms of the attribute names + on the mapped class. If the mapping refers to multiple tables, + such as a joined-inheritance mapping, each dictionary must contain + all keys to be populated into all tables. + + :param return_defaults: when True, rows that are missing values which + generate defaults, namely integer primary key defaults and sequences, + will be inserted **one at a time**, so that the primary key value + is available. In particular this will allow joined-inheritance + and other multi-table mappings to insert correctly without the need + to provide primary + key values ahead of time; however, + :paramref:`.Session.bulk_insert_mappings.return_defaults` + **greatly reduces the performance gains** of the method overall. + If the rows + to be inserted only refer to a single table, then there is no + reason this flag should be set as the returned default information + is not used. + + + .. seealso:: + + :ref:`bulk_operations` + + :meth:`.Session.bulk_save_objects` + + :meth:`.Session.bulk_update_mappings` + + """ + self._bulk_save_mappings( + mapper, mappings, False, False, return_defaults, False) + + def bulk_update_mappings(self, mapper, mappings): + """Perform a bulk update of the given list of mapping dictionaries. + + The bulk update feature allows plain Python dictionaries to be used as + the source of simple UPDATE operations which can be more easily + grouped together into higher performing "executemany" + operations. Using dictionaries, there is no "history" or session + state management features in use, reducing latency when updating + large numbers of simple rows. + + .. versionadded:: 1.0.0 + + .. warning:: + + The bulk update feature allows for a lower-latency UPDATE + of rows at the expense of most other unit-of-work features. + Features such as object management, relationship handling, + and SQL clause support are **silently omitted** in favor of raw + UPDATES of records. + + **Please read the list of caveats at** :ref:`bulk_operations` + **before using this method, and fully test and confirm the + functionality of all code developed using these systems.** + + :param mapper: a mapped class, or the actual :class:`.Mapper` object, + representing the single kind of object represented within the mapping + list. + + :param mappings: a list of dictionaries, each one containing the state + of the mapped row to be updated, in terms of the attribute names + on the mapped class. If the mapping refers to multiple tables, + such as a joined-inheritance mapping, each dictionary may contain + keys corresponding to all tables. All those keys which are present + and are not part of the primary key are applied to the SET clause + of the UPDATE statement; the primary key values, which are required, + are applied to the WHERE clause. + + + .. 
seealso:: + + :ref:`bulk_operations` + + :meth:`.Session.bulk_insert_mappings` + + :meth:`.Session.bulk_save_objects` + + """ + self._bulk_save_mappings(mapper, mappings, True, False, False, False) + + def _bulk_save_mappings( + self, mapper, mappings, isupdate, isstates, + return_defaults, update_changed_only): + mapper = _class_to_mapper(mapper) + self._flushing = True + + transaction = self.begin( + subtransactions=True) + try: + if isupdate: + persistence._bulk_update( + mapper, mappings, transaction, + isstates, update_changed_only) + else: + persistence._bulk_insert( + mapper, mappings, transaction, isstates, return_defaults) + transaction.commit() + + except: + with util.safe_reraise(): + transaction.rollback(_capture_exception=True) + finally: + self._flushing = False + def is_modified(self, instance, include_collections=True, passive=True): """Return ``True`` if the given instance has locally @@ -2404,9 +2702,13 @@ def make_transient(instance): if s: s._expunge_state(state) - # remove expired state and - # deferred callables - state.callables.clear() + # remove expired state + state.expired_attributes.clear() + + # remove deferred callables + if state.callables: + del state.callables + if state.key: del state.key if state.deleted: @@ -2449,16 +2751,19 @@ def make_transient_to_detached(instance): def object_session(instance): - """Return the ``Session`` to which instance belongs. + """Return the :class:`.Session` to which the given instance belongs. - If the instance is not a mapped instance, an error is raised. + This is essentially the same as the :attr:`.InstanceState.session` + accessor. See that attribute for details. """ try: - return _state_session(attributes.instance_state(instance)) + state = attributes.instance_state(instance) except exc.NO_STATE: raise exc.UnmappedInstanceError(instance) + else: + return _state_session(state) _new_sessionid = util.counter() diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 4756f1707..6034e74de 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -1,5 +1,5 @@ # orm/state.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -60,12 +60,33 @@ class InstanceState(interfaces.InspectionAttr): _load_pending = False is_instance = True + callables = () + """A namespace where a per-state loader callable can be associated. + + In SQLAlchemy 1.0, this is only used for lazy loaders / deferred + loaders that were set up via query option. + + Previously, callables was used also to indicate expired attributes + by storing a link to the InstanceState itself in this dictionary. + This role is now handled by the expired_attributes set. + + """ + def __init__(self, obj, manager): self.class_ = obj.__class__ self.manager = manager self.obj = weakref.ref(obj, self._cleanup) self.committed_state = {} - self.callables = {} + self.expired_attributes = set() + + expired_attributes = None + """The set of keys which are 'expired' to be loaded by + the manager's deferred scalar loader, assuming no pending + changes. 
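# --- observational sketch: the new expired_attributes set ---
# A sketch, assuming a persistent "User" instance in a Session
# "sess"; inspect() on a mapped object returns its InstanceState,
# whose expired_attributes set now records expiration in place of
# the old self-referring entries in the callables dict.
from sqlalchemy import inspect

user = sess.query(User).first()
sess.expire(user)
print(inspect(user).expired_attributes)  # e.g. set(['id', 'name'])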
+ + see also the ``unmodified`` collection which is intersected + against this set when a refresh operation occurs.""" + @util.memoized_property def attrs(self): @@ -145,7 +166,16 @@ class InstanceState(interfaces.InspectionAttr): @util.dependencies("sqlalchemy.orm.session") def session(self, sessionlib): """Return the owning :class:`.Session` for this instance, - or ``None`` if none available.""" + or ``None`` if none available. + + Note that the result here can in some cases be *different* + from that of ``obj in session``; an object that's been deleted + will report as not ``in session``, however if the transaction is + still in progress, this attribute will still refer to that session. + Only when the transaction is completed does the object become + fully detached under normal circumstances. + + """ return sessionlib._state_session(self) @property @@ -219,11 +249,25 @@ class InstanceState(interfaces.InspectionAttr): del self.obj def _cleanup(self, ref): + """Weakref callback cleanup. + + This callable cleans out the state when it is being garbage + collected. + + this _cleanup **assumes** that there are no strong refs to us! + Will not work otherwise! + + """ instance_dict = self._instance_dict() if instance_dict is not None: - instance_dict.discard(self) + instance_dict._fast_discard(self) + del self._instance_dict + + # we can't possibly be in instance_dict._modified + # b.c. this is weakref cleanup only, that set + # is strong referencing! + # assert self not in instance_dict._modified - self.callables.clear() self.session_id = self._strong_obj = None del self.obj @@ -278,7 +322,7 @@ class InstanceState(interfaces.InspectionAttr): (k, self.__dict__[k]) for k in ( 'committed_state', '_pending_mutations', 'modified', 'expired', 'callables', 'key', 'parents', 'load_options', - 'class_', + 'class_', 'expired_attributes' ) if k in self.__dict__ ) if self.load_path: @@ -305,7 +349,18 @@ class InstanceState(interfaces.InspectionAttr): self.parents = state_dict.get('parents', {}) self.modified = state_dict.get('modified', False) self.expired = state_dict.get('expired', False) - self.callables = state_dict.get('callables', {}) + if 'callables' in state_dict: + self.callables = state_dict['callables'] + + try: + self.expired_attributes = state_dict['expired_attributes'] + except KeyError: + self.expired_attributes = set() + # 0.9 and earlier compat + for k in list(self.callables): + if self.callables[k] is self: + self.expired_attributes.add(k) + del self.callables[k] self.__dict__.update([ (k, state_dict[k]) for k in ( @@ -332,57 +387,73 @@ class InstanceState(interfaces.InspectionAttr): old = dict_.pop(key, None) if old is not None and self.manager[key].impl.collection: self.manager[key].impl._invalidate_collection(old) - self.callables.pop(key, None) + self.expired_attributes.discard(key) + if self.callables: + self.callables.pop(key, None) @classmethod - def _row_processor(cls, manager, fn, key): + def _instance_level_callable_processor(cls, manager, fn, key): impl = manager[key].impl if impl.collection: def _set_callable(state, dict_, row): + if 'callables' not in state.__dict__: + state.callables = {} old = dict_.pop(key, None) if old is not None: impl._invalidate_collection(old) state.callables[key] = fn else: def _set_callable(state, dict_, row): + if 'callables' not in state.__dict__: + state.callables = {} state.callables[key] = fn return _set_callable def _expire(self, dict_, modified_set): self.expired = True + if self.modified: modified_set.discard(self) + 
self.committed_state.clear() + self.modified = False - self.modified = False self._strong_obj = None - self.committed_state.clear() + if '_pending_mutations' in self.__dict__: + del self.__dict__['_pending_mutations'] - InstanceState._pending_mutations._reset(self) + if 'parents' in self.__dict__: + del self.__dict__['parents'] - # clear out 'parents' collection. not - # entirely clear how we can best determine - # which to remove, or not. - InstanceState.parents._reset(self) + self.expired_attributes.update( + [impl.key for impl in self.manager._scalar_loader_impls + if impl.expire_missing or impl.key in dict_] + ) - for key in self.manager: - impl = self.manager[key].impl - if impl.accepts_scalar_loader and \ - (impl.expire_missing or key in dict_): - self.callables[key] = self - old = dict_.pop(key, None) - if impl.collection and old is not None: - impl._invalidate_collection(old) + if self.callables: + for k in self.expired_attributes.intersection(self.callables): + del self.callables[k] + + for k in self.manager._collection_impl_keys.intersection(dict_): + collection = dict_.pop(k) + collection._sa_adapter.invalidated = True + + for key in self.manager._all_key_set.intersection(dict_): + del dict_[key] self.manager.dispatch.expire(self, None) def _expire_attributes(self, dict_, attribute_names): pending = self.__dict__.get('_pending_mutations', None) + callables = self.callables + for key in attribute_names: impl = self.manager[key].impl if impl.accepts_scalar_loader: - self.callables[key] = self + self.expired_attributes.add(key) + if callables and key in callables: + del callables[key] old = dict_.pop(key, None) if impl.collection and old is not None: impl._invalidate_collection(old) @@ -393,7 +464,7 @@ class InstanceState(interfaces.InspectionAttr): self.manager.dispatch.expire(self, attribute_names) - def __call__(self, state, passive): + def _load_expired(self, state, passive): """__call__ allows the InstanceState to act as a deferred callable for loading expired attributes, which is also serializable (picklable). @@ -412,8 +483,7 @@ class InstanceState(interfaces.InspectionAttr): # instance state didn't have an identity, # the attributes still might be in the callables # dict. ensure they are removed. - for k in toload.intersection(self.callables): - del self.callables[k] + self.expired_attributes.clear() return ATTR_WAS_SET @@ -448,18 +518,6 @@ class InstanceState(interfaces.InspectionAttr): if self.manager[attr].impl.accepts_scalar_loader ) - @property - def expired_attributes(self): - """Return the set of keys which are 'expired' to be loaded by - the manager's deferred scalar loader, assuming no pending - changes. - - see also the ``unmodified`` collection which is intersected - against this set when a refresh operation occurs. - - """ - return set([k for k, v in self.callables.items() if v is self]) - def _instance_dict(self): return None @@ -482,6 +540,7 @@ class InstanceState(interfaces.InspectionAttr): if (self.session_id and self._strong_obj is None) \ or not self.modified: + self.modified = True instance_dict = self._instance_dict() if instance_dict: instance_dict._modified.add(self) @@ -502,7 +561,6 @@ class InstanceState(interfaces.InspectionAttr): self.manager[attr.key], base.state_class_str(self) )) - self.modified = True def _commit(self, dict_, keys): """Commit attributes. 
@@ -519,10 +577,18 @@ class InstanceState(interfaces.InspectionAttr): self.expired = False - for key in set(self.callables).\ + self.expired_attributes.difference_update( + set(keys).intersection(dict_)) + + # the per-keys commit removes object-level callables, + # while that of commit_all does not. it's not clear + # if this behavior has a clear rationale, however tests do + # ensure this is what it does. + if self.callables: + for key in set(self.callables).\ intersection(keys).\ - intersection(dict_): - del self.callables[key] + intersection(dict_): + del self.callables[key] def _commit_all(self, dict_, instance_dict=None): """commit all attributes unconditionally. @@ -533,7 +599,8 @@ class InstanceState(interfaces.InspectionAttr): - all attributes are marked as "committed" - the "strong dirty reference" is removed - the "modified" flag is set to False - - any "expired" markers/callables for attributes loaded are removed. + - any "expired" markers for scalar attributes loaded are removed. + - lazy load callables for objects / collections *stay* Attributes marked as "expired" can potentially remain "expired" after this step if a value was not populated in state.dict. @@ -553,10 +620,7 @@ class InstanceState(interfaces.InspectionAttr): if '_pending_mutations' in state_dict: del state_dict['_pending_mutations'] - callables = state.callables - for key in list(callables): - if key in dict_ and callables[key] is state: - del callables[key] + state.expired_attributes.difference_update(dict_) if instance_dict and state.modified: instance_dict._modified.discard(state) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index cdb501c14..c03e133de 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -1,5 +1,5 @@ # orm/strategies.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -22,6 +22,7 @@ from . import properties from .interfaces import ( LoaderStrategy, StrategizedProperty ) +from .base import _SET_DEFERRED_EXPIRED, _DEFER_FOR_STATE from .session import _state_session import itertools @@ -105,6 +106,8 @@ class UninstrumentedColumnLoader(LoaderStrategy): if the argument is against the with_polymorphic selectable. 
""" + __slots__ = 'columns', + def __init__(self, parent): super(UninstrumentedColumnLoader, self).__init__(parent) self.columns = self.parent_property.columns @@ -128,6 +131,8 @@ class UninstrumentedColumnLoader(LoaderStrategy): class ColumnLoader(LoaderStrategy): """Provide loading behavior for a :class:`.ColumnProperty`.""" + __slots__ = 'columns', 'is_composite' + def __init__(self, parent): super(ColumnLoader, self).__init__(parent) self.columns = self.parent_property.columns @@ -135,12 +140,18 @@ class ColumnLoader(LoaderStrategy): def setup_query( self, context, entity, path, loadopt, - adapter, column_collection, **kwargs): + adapter, column_collection, memoized_populators, **kwargs): + for c in self.columns: if adapter: c = adapter.columns[c] column_collection.append(c) + fetch = self.columns[0] + if adapter: + fetch = adapter.columns[fetch] + memoized_populators[self.parent_property] = fetch + def init_class_attribute(self, mapper): self.is_class_level = True coltype = self.columns[0].type @@ -176,6 +187,8 @@ class ColumnLoader(LoaderStrategy): class DeferredColumnLoader(LoaderStrategy): """Provide loading behavior for a deferred :class:`.ColumnProperty`.""" + __slots__ = 'columns', 'group' + def __init__(self, parent): super(DeferredColumnLoader, self).__init__(parent) if hasattr(self.parent_property, 'composite_class'): @@ -187,22 +200,14 @@ class DeferredColumnLoader(LoaderStrategy): def create_row_processor( self, context, path, loadopt, mapper, result, adapter, populators): - col = self.columns[0] - if adapter: - col = adapter.columns[col] - # TODO: put a result-level contains here - getter = result._getter(col) - if getter: - self.parent_property._get_strategy_by_cls(ColumnLoader).\ - create_row_processor( - context, path, loadopt, mapper, result, - adapter, populators) - - elif not self.is_class_level: - set_deferred_for_local_state = InstanceState._row_processor( - mapper.class_manager, - LoadDeferredColumns(self.key), self.key) + # this path currently does not check the result + # for the column; this is because in most cases we are + # working just with the setup_query() directive which does + # not support this, and the behavior here should be consistent. 
+ if not self.is_class_level: + set_deferred_for_local_state = \ + self.parent_property._deferred_column_loader populators["new"].append((self.key, set_deferred_for_local_state)) else: populators["expire"].append((self.key, False)) @@ -218,14 +223,16 @@ class DeferredColumnLoader(LoaderStrategy): ) def setup_query( - self, context, entity, path, loadopt, adapter, - only_load_props=None, **kwargs): + self, context, entity, path, loadopt, + adapter, column_collection, memoized_populators, + only_load_props=None, **kw): if ( ( loadopt and 'undefer_pks' in loadopt.local_opts and - set(self.columns).intersection(self.parent.primary_key) + set(self.columns).intersection( + self.parent._should_undefer_in_wildcard) ) or ( @@ -240,7 +247,12 @@ class DeferredColumnLoader(LoaderStrategy): ): self.parent_property._get_strategy_by_cls(ColumnLoader).\ setup_query(context, entity, - path, loadopt, adapter, **kwargs) + path, loadopt, adapter, + column_collection, memoized_populators, **kw) + elif self.is_class_level: + memoized_populators[self.parent_property] = _SET_DEFERRED_EXPIRED + else: + memoized_populators[self.parent_property] = _DEFER_FOR_STATE def _load_for_state(self, state, passive): if not state.key: @@ -300,6 +312,8 @@ class LoadDeferredColumns(object): class AbstractRelationshipLoader(LoaderStrategy): """LoaderStrategies which deal with related objects.""" + __slots__ = 'mapper', 'target', 'uselist' + def __init__(self, parent): super(AbstractRelationshipLoader, self).__init__(parent) self.mapper = self.parent_property.mapper @@ -316,6 +330,8 @@ class NoLoader(AbstractRelationshipLoader): """ + __slots__ = () + def init_class_attribute(self, mapper): self.is_class_level = True @@ -337,12 +353,16 @@ class NoLoader(AbstractRelationshipLoader): @log.class_logger @properties.RelationshipProperty.strategy_for(lazy=True) @properties.RelationshipProperty.strategy_for(lazy="select") -class LazyLoader(AbstractRelationshipLoader): +class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): """Provide loading behavior for a :class:`.RelationshipProperty` with "lazy=True", that is, loads when first accessed. 
""" + __slots__ = ( + '_lazywhere', '_rev_lazywhere', 'use_get', '_bind_to_col', + '_equated_columns', '_rev_bind_to_col', '_rev_equated_columns') + def __init__(self, parent): super(LazyLoader, self).__init__(parent) join_condition = self.parent_property._join_condition @@ -373,7 +393,7 @@ class LazyLoader(AbstractRelationshipLoader): self._equated_columns[c] = self._equated_columns[col] self.logger.info("%s will use query.get() to " - "optimize instance loads" % self) + "optimize instance loads", self) def init_class_attribute(self, mapper): self.is_class_level = True @@ -401,78 +421,54 @@ class LazyLoader(AbstractRelationshipLoader): active_history=active_history ) - def lazy_clause( - self, state, reverse_direction=False, - alias_secondary=False, - adapt_source=None, - passive=None): - if state is None: - return self._lazy_none_clause( - reverse_direction, - adapt_source=adapt_source) - - if not reverse_direction: - criterion, bind_to_col = \ - self._lazywhere, \ - self._bind_to_col - else: - criterion, bind_to_col = \ - self._rev_lazywhere, \ - self._rev_bind_to_col + def _memoized_attr__simple_lazy_clause(self): + criterion, bind_to_col = ( + self._lazywhere, + self._bind_to_col + ) - if reverse_direction: - mapper = self.parent_property.mapper - else: - mapper = self.parent_property.parent + params = [] - o = state.obj() # strong ref - dict_ = attributes.instance_dict(o) + def visit_bindparam(bindparam): + bindparam.unique = False + if bindparam._identifying_key in bind_to_col: + params.append(( + bindparam.key, bind_to_col[bindparam._identifying_key], + None)) + else: + params.append((bindparam.key, None, bindparam.value)) - # use the "committed state" only if we're in a flush - # for this state. + criterion = visitors.cloned_traverse( + criterion, {}, {'bindparam': visit_bindparam} + ) - if passive and passive & attributes.LOAD_AGAINST_COMMITTED: - def visit_bindparam(bindparam): - if bindparam._identifying_key in bind_to_col: - bindparam.callable = \ - lambda: mapper._get_committed_state_attr_by_column( - state, dict_, - bind_to_col[bindparam._identifying_key]) - else: - def visit_bindparam(bindparam): - if bindparam._identifying_key in bind_to_col: - bindparam.callable = \ - lambda: mapper._get_state_attr_by_column( - state, dict_, - bind_to_col[bindparam._identifying_key]) - - if self.parent_property.secondary is not None and alias_secondary: - criterion = sql_util.ClauseAdapter( - self.parent_property.secondary.alias()).\ - traverse(criterion) + return criterion, params - criterion = visitors.cloned_traverse( - criterion, {}, {'bindparam': visit_bindparam}) + def _generate_lazy_clause(self, state, passive): + criterion, param_keys = self._simple_lazy_clause - if adapt_source: - criterion = adapt_source(criterion) - return criterion + if state is None: + return sql_util.adapt_criterion_to_null( + criterion, [key for key, ident, value in param_keys]) - def _lazy_none_clause(self, reverse_direction=False, adapt_source=None): - if not reverse_direction: - criterion, bind_to_col = \ - self._lazywhere, \ - self._bind_to_col - else: - criterion, bind_to_col = \ - self._rev_lazywhere, \ - self._rev_bind_to_col + mapper = self.parent_property.parent - criterion = sql_util.adapt_criterion_to_null(criterion, bind_to_col) + o = state.obj() # strong ref + dict_ = attributes.instance_dict(o) - if adapt_source: - criterion = adapt_source(criterion) - return criterion + params = {} + for key, ident, value in param_keys: + if ident is not None: + if passive and passive & 
attributes.LOAD_AGAINST_COMMITTED: + value = mapper._get_committed_state_attr_by_column( + state, dict_, ident) + else: + value = mapper._get_state_attr_by_column( + state, dict_, ident) + + params[key] = value + + return criterion, params def _load_for_state(self, state, passive): if not state.key and ( @@ -549,10 +545,9 @@ class LazyLoader(AbstractRelationshipLoader): @util.dependencies("sqlalchemy.orm.strategy_options") def _emit_lazyload( - self, strategy_options, session, state, - ident_key, passive): - q = session.query(self.mapper)._adapt_all_clauses() + self, strategy_options, session, state, ident_key, passive): + q = session.query(self.mapper)._adapt_all_clauses() if self.parent_property.secondary is not None: q = q.select_from(self.mapper, self.parent_property.secondary) @@ -583,17 +578,15 @@ class LazyLoader(AbstractRelationshipLoader): rev._use_get and \ not isinstance(rev.strategy, LazyLoader): q = q.options( - strategy_options.Load(rev.parent). - lazyload(rev.key)) + strategy_options.Load(rev.parent).lazyload(rev.key)) - lazy_clause = self.lazy_clause(state, passive=passive) + lazy_clause, params = self._generate_lazy_clause( + state, passive=passive) - if pending: - bind_values = sql_util.bind_values(lazy_clause) - if orm_util._none_set.intersection(bind_values): - return None + if pending and orm_util._none_set.intersection(params.values()): + return None - q = q.filter(lazy_clause) + q = q.filter(lazy_clause).params(params) result = q.all() if self.uselist: @@ -624,9 +617,9 @@ class LazyLoader(AbstractRelationshipLoader): # "lazyload" option on a "no load" # attribute - "eager" attributes always have a # class-level lazyloader installed. - set_lazy_callable = InstanceState._row_processor( + set_lazy_callable = InstanceState._instance_level_callable_processor( mapper.class_manager, - LoadLazyAttribute(key), key) + LoadLazyAttribute(key, self._strategy_keys[0]), key) populators["new"].append((self.key, set_lazy_callable)) elif context.populate_existing or mapper.always_refresh: @@ -647,20 +640,23 @@ class LazyLoader(AbstractRelationshipLoader): class LoadLazyAttribute(object): """serializable loader object used by LazyLoader""" - def __init__(self, key): + def __init__(self, key, strategy_key=(('lazy', 'select'),)): self.key = key + self.strategy_key = strategy_key def __call__(self, state, passive=attributes.PASSIVE_OFF): key = self.key instance_mapper = state.manager.mapper prop = instance_mapper._props[key] - strategy = prop._strategies[LazyLoader] + strategy = prop._strategies[self.strategy_key] return strategy._load_for_state(state, passive) @properties.RelationshipProperty.strategy_for(lazy="immediate") class ImmediateLoader(AbstractRelationshipLoader): + __slots__ = () + def init_class_attribute(self, mapper): self.parent_property.\ _get_strategy_by_cls(LazyLoader).\ @@ -684,6 +680,8 @@ class ImmediateLoader(AbstractRelationshipLoader): @log.class_logger @properties.RelationshipProperty.strategy_for(lazy="subquery") class SubqueryLoader(AbstractRelationshipLoader): + __slots__ = 'join_depth', + def __init__(self, parent): super(SubqueryLoader, self).__init__(parent) self.join_depth = self.parent_property.join_depth @@ -1005,6 +1003,12 @@ class SubqueryLoader(AbstractRelationshipLoader): if subq is None: return + assert subq.session is context.session, ( + "Subquery session doesn't refer to that of " + "our context. Are there broken context caching " + "schemes being used?" 
+ ) + local_cols = self.parent_property.local_columns # cache the loaded collections in the context @@ -1069,6 +1073,9 @@ class JoinedLoader(AbstractRelationshipLoader): using joined eager loading. """ + + __slots__ = 'join_depth', + def __init__(self, parent): super(JoinedLoader, self).__init__(parent) self.join_depth = self.parent_property.join_depth @@ -1130,16 +1137,12 @@ class JoinedLoader(AbstractRelationshipLoader): path = path[self.mapper] - for value in self.mapper._iterate_polymorphic_properties( - mappers=with_polymorphic): - value.setup( - context, - entity, - path, - clauses, - parentmapper=self.mapper, - column_collection=add_to_collection, - chained_from_outerjoin=chained_from_outerjoin) + loading._setup_entity_query( + context, self.mapper, entity, + path, clauses, add_to_collection, + with_polymorphic=with_polymorphic, + parentmapper=self.mapper, + chained_from_outerjoin=chained_from_outerjoin) if with_poly_info is not None and \ None in set(context.secondary_columns): @@ -1246,7 +1249,7 @@ class JoinedLoader(AbstractRelationshipLoader): anonymize_labels=True) assert clauses.aliased_class is not None - if self.parent_property.direction != interfaces.MANYTOONE: + if self.parent_property.uselist: context.multi_row_eager_loaders = True innerjoin = ( @@ -1329,34 +1332,25 @@ class JoinedLoader(AbstractRelationshipLoader): assert clauses.aliased_class is not None - join_to_outer = innerjoin and isinstance(towrap, sql.Join) and \ - towrap.isouter - - if chained_from_outerjoin and \ - join_to_outer and innerjoin != 'unnested': - inner = orm_util.join( - towrap.right, - clauses.aliased_class, - onclause, - isouter=False - ) + attach_on_outside = ( + not chained_from_outerjoin or + not innerjoin or innerjoin == 'unnested') - eagerjoin = orm_util.join( - towrap.left, - inner, - towrap.onclause, - isouter=True - ) - eagerjoin._target_adapter = inner._target_adapter - else: - if chained_from_outerjoin: - innerjoin = False - eagerjoin = orm_util.join( + if attach_on_outside: + # this is the "classic" eager join case. 
+ eagerjoin = orm_util._ORMJoin( towrap, clauses.aliased_class, onclause, - isouter=not innerjoin + isouter=not innerjoin or ( + chained_from_outerjoin and isinstance(towrap, sql.Join) + ), _left_memo=self.parent, _right_memo=self.mapper ) + else: + # all other cases are innerjoin=='nested' approach + eagerjoin = self._splice_nested_inner_join( + path, towrap, clauses, onclause) + context.eager_joins[entity_key] = eagerjoin # send a hint to the Query as to where it may "splice" this join @@ -1386,6 +1380,66 @@ class JoinedLoader(AbstractRelationshipLoader): ) ) + def _splice_nested_inner_join( + self, path, join_obj, clauses, onclause, splicing=False): + + if splicing is False: + # first call is always handed a join object + # from the outside + assert isinstance(join_obj, orm_util._ORMJoin) + elif isinstance(join_obj, sql.selectable.FromGrouping): + return self._splice_nested_inner_join( + path, join_obj.element, clauses, onclause, splicing + ) + elif not isinstance(join_obj, orm_util._ORMJoin): + if path[-2] is splicing: + return orm_util._ORMJoin( + join_obj, clauses.aliased_class, + onclause, isouter=False, + _left_memo=splicing, + _right_memo=path[-1].mapper + ) + else: + # only here if splicing == True + return None + + target_join = self._splice_nested_inner_join( + path, join_obj.right, clauses, + onclause, join_obj._right_memo) + if target_join is None: + right_splice = False + target_join = self._splice_nested_inner_join( + path, join_obj.left, clauses, + onclause, join_obj._left_memo) + if target_join is None: + # should only return None when recursively called, + # e.g. splicing==True + assert splicing is not False, \ + "assertion failed attempting to produce joined eager loads" + return None + else: + right_splice = True + + if right_splice: + # for a right splice, attempt to flatten out + # a JOIN b JOIN c JOIN .. 
to avoid needless + # parenthesis nesting + if not join_obj.isouter and not target_join.isouter: + eagerjoin = join_obj._splice_into_center(target_join) + else: + eagerjoin = orm_util._ORMJoin( + join_obj.left, target_join, + join_obj.onclause, isouter=join_obj.isouter, + _left_memo=join_obj._left_memo) + else: + eagerjoin = orm_util._ORMJoin( + target_join, join_obj.right, + join_obj.onclause, isouter=join_obj.isouter, + _right_memo=join_obj._right_memo) + + eagerjoin._target_adapter = target_join._target_adapter + return eagerjoin + def _create_eager_adapter(self, context, result, adapter, path, loadopt): user_defined_adapter = self._init_user_defined_eager_proc( loadopt, context) if loadopt else False @@ -1431,7 +1485,7 @@ class JoinedLoader(AbstractRelationshipLoader): if eager_adapter is not False: key = self.key - _instance = loading.instance_processor( + _instance = loading._instance_processor( self.mapper, context, result, diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 4f986193e..cb7a5fef7 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -1,4 +1,4 @@ -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -161,11 +161,14 @@ class Load(Generative, MapperOption): ext_info = inspect(ac) path_element = ext_info.mapper + existing = path.entity_path[prop].get( + self.context, "path_with_polymorphic") if not ext_info.is_aliased_class: ac = orm_util.with_polymorphic( ext_info.mapper.base_mapper, ext_info.mapper, aliased=True, - _use_mapper_path=True) + _use_mapper_path=True, + _existing_alias=existing) path.entity_path[prop].set( self.context, "path_with_polymorphic", inspect(ac)) path = path[prop][path_element] @@ -176,6 +179,9 @@ class Load(Generative, MapperOption): path = path.entity_path return path + def __str__(self): + return "Load(strategy=%r)" % self.strategy + def _coerce_strat(self, strategy): if strategy is not None: strategy = tuple(sorted(strategy.items())) @@ -358,6 +364,7 @@ class _UnboundLoad(Load): return None token = start_path[0] + if isinstance(token, util.string_types): entity = self._find_entity_basestring(query, token, raiseerr) elif isinstance(token, PropComparator): @@ -401,10 +408,18 @@ class _UnboundLoad(Load): # prioritize "first class" options over those # that were "links in the chain", e.g. 
"x" and "y" in # someload("x.y.z") versus someload("x") / someload("x.y") - if self._is_chain_link: - effective_path.setdefault(context, "loader", loader) + + if effective_path.is_token: + for path in effective_path.generate_for_superclasses(): + if self._is_chain_link: + path.setdefault(context, "loader", loader) + else: + path.set(context, "loader", loader) else: - effective_path.set(context, "loader", loader) + if self._is_chain_link: + effective_path.setdefault(context, "loader", loader) + else: + effective_path.set(context, "loader", loader) def _find_entity_prop_comparator(self, query, token, mapper, raiseerr): if _is_aliased_class(mapper): diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index e1ef85c1d..e9a745cc0 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -1,5 +1,5 @@ # orm/sync.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -45,6 +45,23 @@ def populate(source, source_mapper, dest, dest_mapper, uowcommit.attributes[("pk_cascaded", dest, r)] = True +def bulk_populate_inherit_keys( + source_dict, source_mapper, synchronize_pairs): + # a simplified version of populate() used by bulk insert mode + for l, r in synchronize_pairs: + try: + prop = source_mapper._columntoproperty[l] + value = source_dict[prop.key] + except exc.UnmappedColumnError: + _raise_col_to_prop(False, source_mapper, l, source_mapper, r) + + try: + prop = source_mapper._columntoproperty[r] + source_dict[prop.key] = value + except exc.UnmappedColumnError: + _raise_col_to_prop(True, source_mapper, l, source_mapper, r) + + def clear(dest, dest_mapper, synchronize_pairs): for l, r in synchronize_pairs: if r.primary_key and \ diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 71e61827b..1ef0d24ca 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -1,5 +1,5 @@ # orm/unitofwork.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -16,6 +16,7 @@ organizes them in order of dependency, and executes. from .. import util, event from ..util import topological from . import attributes, persistence, util as orm_util +import itertools def track_cascade_events(descriptor, prop): @@ -379,14 +380,19 @@ class UOWTransaction(object): execute() method has succeeded and the transaction has been committed. 
""" + if not self.states: + return + states = set(self.states) isdel = set( s for (s, (isdelete, listonly)) in self.states.items() if isdelete ) other = states.difference(isdel) - self.session._remove_newly_deleted(isdel) - self.session._register_newly_persistent(other) + if isdel: + self.session._remove_newly_deleted(isdel) + if other: + self.session._register_newly_persistent(other) class IterateMappersMixin(object): diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 8d40ae21c..b3f3bc5fa 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1,5 +1,5 @@ # orm/util.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -30,21 +30,19 @@ class CascadeOptions(frozenset): 'all', 'none', 'delete-orphan']) _allowed_cascades = all_cascades - def __new__(cls, arg): - values = set([ - c for c - in re.split('\s*,\s*', arg or "") - if c - ]) + __slots__ = ( + 'save_update', 'delete', 'refresh_expire', 'merge', + 'expunge', 'delete_orphan') + def __new__(cls, value_list): + if isinstance(value_list, util.string_types) or value_list is None: + return cls.from_string(value_list) + values = set(value_list) if values.difference(cls._allowed_cascades): raise sa_exc.ArgumentError( "Invalid cascade option(s): %s" % ", ".join([repr(x) for x in - sorted( - values.difference(cls._allowed_cascades) - )]) - ) + sorted(values.difference(cls._allowed_cascades))])) if "all" in values: values.update(cls._add_w_all_cascades) @@ -70,6 +68,15 @@ class CascadeOptions(frozenset): ",".join([x for x in sorted(self)]) ) + @classmethod + def from_string(cls, arg): + values = [ + c for c + in re.split('\s*,\s*', arg or "") + if c + ] + return cls(values) + def _validator_events( desc, key, validator, include_removes, include_backrefs): @@ -538,8 +545,13 @@ class AliasedInsp(InspectionAttr): mapper, self) def __repr__(self): - return '<AliasedInsp at 0x%x; %s>' % ( - id(self), self.class_.__name__) + if self.with_polymorphic_mappers: + with_poly = "(%s)" % ", ".join( + mp.class_.__name__ for mp in self.with_polymorphic_mappers) + else: + with_poly = "" + return '<AliasedInsp at 0x%x; %s%s>' % ( + id(self), self.class_.__name__, with_poly) inspection._inspects(AliasedClass)(lambda target: target._aliased_insp) @@ -643,7 +655,8 @@ def aliased(element, alias=None, name=None, flat=False, adapt_on_names=False): def with_polymorphic(base, classes, selectable=False, flat=False, polymorphic_on=None, aliased=False, - innerjoin=False, _use_mapper_path=False): + innerjoin=False, _use_mapper_path=False, + _existing_alias=None): """Produce an :class:`.AliasedClass` construct which specifies columns for descendant mappers of the given base. 
@@ -708,6 +721,16 @@ def with_polymorphic(base, classes, selectable=False, only be specified if querying for one specific subtype only """ primary_mapper = _class_to_mapper(base) + if _existing_alias: + assert _existing_alias.mapper is primary_mapper + classes = util.to_set(classes) + new_classes = set([ + mp.class_ for mp in + _existing_alias.with_polymorphic_mappers]) + if classes == new_classes: + return _existing_alias + else: + classes = classes.union(new_classes) mappers, selectable = primary_mapper.\ _with_polymorphic_args(classes, selectable, innerjoin=innerjoin) @@ -753,7 +776,10 @@ class _ORMJoin(expression.Join): __visit_name__ = expression.Join.__visit_name__ - def __init__(self, left, right, onclause=None, isouter=False): + def __init__( + self, + left, right, onclause=None, isouter=False, + _left_memo=None, _right_memo=None): left_info = inspection.inspect(left) left_orm_info = getattr(left, '_joined_from_info', left_info) @@ -763,6 +789,9 @@ class _ORMJoin(expression.Join): self._joined_from_info = right_info + self._left_memo = _left_memo + self._right_memo = _right_memo + if isinstance(onclause, util.string_types): onclause = getattr(left_orm_info.entity, onclause) @@ -814,6 +843,28 @@ class _ORMJoin(expression.Join): single_crit = right_info._adapter.traverse(single_crit) self.onclause = self.onclause & single_crit + def _splice_into_center(self, other): + """Splice a join into the center. + + Given join(a, b) and join(b, c), return join(a, b).join(c) + + """ + assert self.right is other.left + + left = _ORMJoin( + self.left, other.left, + self.onclause, isouter=self.isouter, + _left_memo=self._left_memo, + _right_memo=other._left_memo + ) + + return _ORMJoin( + left, + other.right, + other.onclause, isouter=other.isouter, + _right_memo=other._right_memo + ) + def join(self, right, onclause=None, isouter=False, join_to_left=None): return _ORMJoin(self, right, onclause, isouter) diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index a174df784..ccb4f1e6a 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -1,5 +1,5 @@ # sqlalchemy/pool.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -230,6 +230,7 @@ class Pool(log.Identified): % reset_on_return) self.echo = echo + if _dispatch: self.dispatch._update(_dispatch, only_propagate=False) if _dialect: @@ -528,6 +529,7 @@ class _ConnectionRecord(object): return self.connection def __close(self): + self.finalize_callback.clear() self.__pool._close_connection(self.connection) def __connect(self): @@ -917,9 +919,9 @@ class QueuePool(Pool): on returning a connection. Defaults to 30. :param \**kw: Other keyword arguments including - :paramref:`.Pool.recycle`, :paramref:`.Pool.echo`, - :paramref:`.Pool.reset_on_return` and others are passed to the - :class:`.Pool` constructor. + :paramref:`.Pool.recycle`, :paramref:`.Pool.echo`, + :paramref:`.Pool.reset_on_return` and others are passed to the + :class:`.Pool` constructor. 
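# --- usage sketch: QueuePool sizing via create_engine() ---
# QueuePool is the default pool for create_engine(); the parameters
# documented above are normally passed through engine creation.  The
# URL is a placeholder.
from sqlalchemy import create_engine

engine = create_engine(
    "postgresql://scott:tiger@localhost/test",
    pool_size=5, max_overflow=10, pool_timeout=30)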
""" Pool.__init__(self, creator, **kw) diff --git a/lib/sqlalchemy/processors.py b/lib/sqlalchemy/processors.py index 3794b01f5..6575fad17 100644 --- a/lib/sqlalchemy/processors.py +++ b/lib/sqlalchemy/processors.py @@ -1,5 +1,5 @@ # sqlalchemy/processors.py -# Copyright (C) 2010-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2010-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com # diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 4b6ad1988..327498fc5 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -1,5 +1,5 @@ # schema.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -35,6 +35,7 @@ from .sql.schema import ( UniqueConstraint, _get_table_key, ColumnCollectionConstraint, + ColumnCollectionMixin ) @@ -58,5 +59,7 @@ from .sql.ddl import ( DDLBase, DDLElement, _CreateDropBase, - _DDLCompiles + _DDLCompiles, + sort_tables, + sort_tables_and_constraints ) diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 351e08d0b..e8b70061d 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -1,5 +1,5 @@ # sql/__init__.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 3df4257d4..8fec5039b 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -1,5 +1,5 @@ # sql/annotation.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 2d06109b9..eed079238 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -1,5 +1,5 @@ # sql/base.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -449,10 +449,12 @@ class ColumnCollection(util.OrderedProperties): """ + __slots__ = '_all_col_set', '_all_columns' + def __init__(self, *columns): super(ColumnCollection, self).__init__() - self.__dict__['_all_col_set'] = util.column_set() - self.__dict__['_all_columns'] = [] + object.__setattr__(self, '_all_col_set', util.column_set()) + object.__setattr__(self, '_all_columns', []) for c in columns: self.add(c) @@ -576,13 +578,14 @@ class ColumnCollection(util.OrderedProperties): return util.OrderedProperties.__contains__(self, other) def __getstate__(self): - return {'_data': self.__dict__['_data'], - '_all_columns': self.__dict__['_all_columns']} + return {'_data': self._data, + '_all_columns': self._all_columns} def __setstate__(self, state): - self.__dict__['_data'] = state['_data'] - self.__dict__['_all_columns'] = state['_all_columns'] - self.__dict__['_all_col_set'] = util.column_set(state['_all_columns']) + object.__setattr__(self, '_data', state['_data']) + object.__setattr__(self, '_all_columns', state['_all_columns']) + object.__setattr__( + self, '_all_col_set', 
util.column_set(state['_all_columns'])) def contains_column(self, col): # this has to be done via set() membership @@ -596,8 +599,8 @@ class ColumnCollection(util.OrderedProperties): class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection): def __init__(self, data, colset, all_columns): util.ImmutableProperties.__init__(self, data) - self.__dict__['_all_col_set'] = colset - self.__dict__['_all_columns'] = all_columns + object.__setattr__(self, '_all_col_set', colset) + object.__setattr__(self, '_all_columns', all_columns) extend = remove = util.ImmutableProperties._immutable diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index a6c30b7dc..755193552 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1,5 +1,5 @@ # sql/compiler.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -23,6 +23,7 @@ To generate user-defined SQL strings, see """ +import contextlib import re from . import schema, sqltypes, operators, functions, visitors, \ elements, selectable, crud @@ -82,6 +83,7 @@ OPERATORS = { operators.eq: ' = ', operators.concat_op: ' || ', operators.match_op: ' MATCH ', + operators.notmatch_op: ' NOT MATCH ', operators.in_op: ' IN ', operators.notin_op: ' NOT IN ', operators.comma_op: ', ', @@ -247,15 +249,16 @@ class Compiled(object): return self.execute(*multiparams, **params).scalar() -class TypeCompiler(object): - +class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)): """Produces DDL specification for TypeEngine objects.""" + ensure_kwarg = 'visit_\w+' + def __init__(self, dialect): self.dialect = dialect - def process(self, type_): - return type_._compiler_dispatch(self) + def process(self, type_, **kw): + return type_._compiler_dispatch(self, **kw) class _CompileLabel(visitors.Visitable): @@ -359,7 +362,12 @@ class SQLCompiler(Compiled): # column/label name, ColumnElement object (if any) and # TypeEngine. ResultProxy uses this for type processing and # column targeting - self.result_map = {} + self._result_columns = [] + + # if False, means we can't be sure the list of entries + # in _result_columns is actually the rendered order. This + # gets flipped when we use TextAsFrom, for example. 
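# --- usage sketch: the TextAsFrom case mentioned below ---
# text().columns() produces a TextAsFrom, for which the compiler
# cannot assume the entries in _result_columns match the rendered
# order.  A sketch, assuming a Table "users" and a Connection "conn".
from sqlalchemy import text

stmt = text("SELECT id, name FROM users").columns(
    users.c.id, users.c.name)
rows = conn.execute(stmt).fetchall()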
+ self._ordered_columns = True # true if the paramstyle is positional self.positional = dialect.positional @@ -400,6 +408,26 @@ class SQLCompiler(Compiled): if self.positional: self.cte_positional = {} + @contextlib.contextmanager + def _nested_result(self): + """special API to support the use case of 'nested result sets'""" + result_columns, ordered_columns = ( + self._result_columns, self._ordered_columns) + self._result_columns, self._ordered_columns = [], False + + try: + if self.stack: + entry = self.stack[-1] + entry['need_result_map_for_nested'] = True + else: + entry = None + yield self._result_columns, self._ordered_columns + finally: + if entry: + entry.pop('need_result_map_for_nested') + self._result_columns, self._ordered_columns = ( + result_columns, ordered_columns) + def _apply_numbered_params(self): poscount = itertools.count(1) self.string = re.sub( @@ -478,6 +506,11 @@ class SQLCompiler(Compiled): compiled object, for those values that are present.""" return self.construct_params(_check=False) + @util.dependencies("sqlalchemy.engine.result") + def _create_result_map(self, result): + """utility method used for unit tests only.""" + return result.ResultMetaData._create_result_map(self._result_columns) + def default_from(self): """Called when a SELECT statement has no froms, and no FROM clause is to be appended. @@ -527,7 +560,6 @@ class SQLCompiler(Compiled): selectable = self.stack[-1]['selectable'] with_cols, only_froms = selectable._label_resolve_dict - try: if within_columns_clause: col = only_froms[element.element] @@ -637,8 +669,9 @@ class SQLCompiler(Compiled): def visit_index(self, index, **kwargs): return index.name - def visit_typeclause(self, typeclause, **kwargs): - return self.dialect.type_compiler.process(typeclause.type) + def visit_typeclause(self, typeclause, **kw): + kw['type_expression'] = typeclause + return self.dialect.type_compiler.process(typeclause.type, **kw) def post_process_text(self, text): return text @@ -659,22 +692,22 @@ class SQLCompiler(Compiled): self.post_process_text(textclause.text)) ) - def visit_text_as_from(self, taf, iswrapper=False, - compound_index=0, force_result_map=False, + def visit_text_as_from(self, taf, + compound_index=None, asfrom=False, parens=True, **kw): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - populate_result_map = force_result_map or ( - compound_index == 0 and ( - toplevel or - entry['iswrapper'] - ) - ) + populate_result_map = toplevel or \ + ( + compound_index == 0 and entry.get( + 'need_result_map_for_compound', False) + ) or entry.get('need_result_map_for_nested', False) if populate_result_map: + self._ordered_columns = False for c in taf.column_args: self.process(c, within_columns_clause=True, add_to_result_map=self._add_to_result_map) @@ -787,13 +820,16 @@ class SQLCompiler(Compiled): parens=True, compound_index=0, **kwargs): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] + need_result_map = toplevel or \ + (compound_index == 0 + and entry.get('need_result_map_for_compound', False)) self.stack.append( { 'correlate_froms': entry['correlate_froms'], - 'iswrapper': toplevel, 'asfrom_froms': entry['asfrom_froms'], - 'selectable': cs + 'selectable': cs, + 'need_result_map_for_compound': need_result_map }) keyword = self.compound_keywords.get(cs.keyword) @@ -813,10 +849,9 @@ class SQLCompiler(Compiled): text += self.order_by_clause(cs, **kwargs) text += (cs._limit_clause is not None or cs._offset_clause is not None) and 
\ - self.limit_clause(cs) or "" + self.limit_clause(cs, **kwargs) or "" - if self.ctes and \ - compound_index == 0 and toplevel: + if self.ctes and toplevel: text = self._render_cte_clause() + text self.stack.pop(-1) @@ -862,14 +897,18 @@ class SQLCompiler(Compiled): else: return "%s = 0" % self.process(element.element, **kw) - def visit_binary(self, binary, **kw): + def visit_notmatch_op_binary(self, binary, operator, **kw): + return "NOT %s" % self.visit_binary( + binary, override_operator=operators.match_op) + + def visit_binary(self, binary, override_operator=None, **kw): # don't allow "? = ?" to render if self.ansi_bind_rules and \ isinstance(binary.left, elements.BindParameter) and \ isinstance(binary.right, elements.BindParameter): kw['literal_binds'] = True - operator_ = binary.operator + operator_ = override_operator or binary.operator disp = getattr(self, "visit_%s_binary" % operator_.__name__, None) if disp: return disp(binary, operator_, **kw) @@ -1188,12 +1227,16 @@ class SQLCompiler(Compiled): self, asfrom=True, **kwargs ) + if cte._suffixes: + text += " " + self._generate_prefixes( + cte, cte._suffixes, **kwargs) + self.ctes[cte] = text if asfrom: if cte_alias_name: text = self.preparer.format_alias(cte, cte_alias_name) - text += " AS " + cte_name + text += self.get_render_as_alias_suffix(cte_name) else: return self.preparer.format_alias(cte, cte_name) return text @@ -1212,8 +1255,8 @@ class SQLCompiler(Compiled): elif asfrom: ret = alias.original._compiler_dispatch(self, asfrom=True, **kwargs) + \ - " AS " + \ - self.preparer.format_alias(alias, alias_name) + self.get_render_as_alias_suffix( + self.preparer.format_alias(alias, alias_name)) if fromhints and alias in fromhints: ret = self.format_from_hint_text(ret, alias, @@ -1223,19 +1266,14 @@ class SQLCompiler(Compiled): else: return alias.original._compiler_dispatch(self, **kwargs) + def get_render_as_alias_suffix(self, alias_name_text): + return " AS " + alias_name_text + def _add_to_result_map(self, keyname, name, objects, type_): if not self.dialect.case_sensitive: keyname = keyname.lower() - if keyname in self.result_map: - # conflicting keyname, just double up the list - # of objects. this will cause an "ambiguous name" - # error if an attempt is made by the result set to - # access. 
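The new visit_notmatch_op_binary hook, together with the override_operator argument to visit_binary and the notmatch_op negation wired up in default_comparator further down, renders a negated MATCH as NOT applied to the positive form. Observable with the generic dialect, which spells the operator as MATCH:

    from sqlalchemy import column, table

    docs = table("docs", column("body"))

    expr = docs.c.body.match("database")
    print(expr)   # docs.body MATCH :body_1
    print(~expr)  # NOT docs.body MATCH :body_1, via visit_notmatch_op_binary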
- e_name, e_obj, e_type = self.result_map[keyname] - self.result_map[keyname] = e_name, e_obj + objects, e_type - else: - self.result_map[keyname] = name, objects, type_ + self._result_columns.append((keyname, name, objects, type_)) def _label_select_column(self, select, column, populate_result_map, @@ -1425,12 +1463,13 @@ class SQLCompiler(Compiled): (inner_col[c._key_label], c) for c in select.inner_columns ) - for key, (name, objs, typ) in list(self.result_map.items()): - objs = tuple([d.get(col, col) for col in objs]) - self.result_map[key] = (name, objs, typ) + + self._result_columns = [ + (key, name, tuple([d.get(col, col) for col in objs]), typ) + for key, name, objs, typ in self._result_columns + ] _default_stack_entry = util.immutabledict([ - ('iswrapper', False), ('correlate_froms', frozenset()), ('asfrom_froms', frozenset()) ]) @@ -1458,10 +1497,10 @@ class SQLCompiler(Compiled): return froms def visit_select(self, select, asfrom=False, parens=True, - iswrapper=False, fromhints=None, + fromhints=None, compound_index=0, - force_result_map=False, nested_join_translation=False, + select_wraps_for=None, **kwargs): needs_nested_translation = \ @@ -1475,21 +1514,19 @@ class SQLCompiler(Compiled): select) text = self.visit_select( transformed_select, asfrom=asfrom, parens=parens, - iswrapper=iswrapper, fromhints=fromhints, + fromhints=fromhints, compound_index=compound_index, - force_result_map=force_result_map, nested_join_translation=True, **kwargs ) toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - populate_result_map = force_result_map or ( - compound_index == 0 and ( - toplevel or - entry['iswrapper'] - ) - ) + populate_result_map = toplevel or \ + ( + compound_index == 0 and entry.get( + 'need_result_map_for_compound', False) + ) or entry.get('need_result_map_for_nested', False) if needs_nested_translation: if populate_result_map: @@ -1497,7 +1534,7 @@ class SQLCompiler(Compiled): select, transformed_select) return text - froms = self._setup_select_stack(select, entry, asfrom, iswrapper) + froms = self._setup_select_stack(select, entry, asfrom) column_clause_args = kwargs.copy() column_clause_args.update({ @@ -1523,16 +1560,34 @@ class SQLCompiler(Compiled): # the actual list of columns to print in the SELECT column list. 
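_add_to_result_map now appends to the ordered _result_columns list instead of collapsing duplicate keys into a dict slot; ambiguity is deferred to the result-set layer, which only complains if the duplicated name is actually accessed. _result_columns is a private attribute, but inspecting it illustrates the new shape:

    from sqlalchemy import column, select, table

    users = table("users", column("id"))
    addresses = table("addresses", column("id"))

    compiled = select([users.c.id, addresses.c.id]).compile()
    # one (keyname, name, objects, type_) tuple per rendered column;
    # the duplicate "id" key is preserved positionally
    for keyname, name, objects, type_ in compiled._result_columns:
        print(keyname, objects)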
inner_columns = [ c for c in [ - self._label_select_column(select, - column, - populate_result_map, asfrom, - column_clause_args, - name=name) + self._label_select_column( + select, + column, + populate_result_map, asfrom, + column_clause_args, + name=name) for name, column in select._columns_plus_names ] if c is not None ] + if populate_result_map and select_wraps_for is not None: + # if this select is a compiler-generated wrapper, + # rewrite the targeted columns in the result map + wrapped_inner_columns = set(select_wraps_for.inner_columns) + translate = dict( + (outer, inner.pop()) for outer, inner in [ + ( + outer, + outer.proxy_set.intersection(wrapped_inner_columns)) + for outer in select.inner_columns + ] if inner + ) + self._result_columns = [ + (key, name, tuple(translate.get(o, o) for o in obj), type_) + for key, name, obj, type_ in self._result_columns + ] + text = self._compose_select_body( text, select, inner_columns, froms, byfrom, kwargs) @@ -1545,10 +1600,13 @@ class SQLCompiler(Compiled): if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) - if self.ctes and \ - compound_index == 0 and toplevel: + if self.ctes and toplevel: text = self._render_cte_clause() + text + if select._suffixes: + text += " " + self._generate_prefixes( + select, select._suffixes, **kwargs) + self.stack.pop(-1) if asfrom and parens: @@ -1569,7 +1627,7 @@ class SQLCompiler(Compiled): hint_text = self.get_select_hint_text(byfrom) return hint_text, byfrom - def _setup_select_stack(self, select, entry, asfrom, iswrapper): + def _setup_select_stack(self, select, entry, asfrom): correlate_froms = entry['correlate_froms'] asfrom_froms = entry['asfrom_froms'] @@ -1588,7 +1646,6 @@ class SQLCompiler(Compiled): new_entry = { 'asfrom_froms': new_correlate_froms, - 'iswrapper': iswrapper, 'correlate_froms': all_correlate_froms, 'selectable': select, } @@ -1729,6 +1786,11 @@ class SQLCompiler(Compiled): ) def visit_insert(self, insert_stmt, **kw): + self.stack.append( + {'correlate_froms': set(), + "asfrom_froms": set(), + "selectable": insert_stmt}) + self.isinsert = True crud_params = crud._get_crud_params(self, insert_stmt, **kw) @@ -1812,6 +1874,8 @@ class SQLCompiler(Compiled): if self.returning and not self.returning_precedes_values: text += " " + returning_clause + self.stack.pop(-1) + return text def update_limit_clause(self, update_stmt): @@ -1847,7 +1911,6 @@ class SQLCompiler(Compiled): def visit_update(self, update_stmt, **kw): self.stack.append( {'correlate_froms': set([update_stmt.table]), - "iswrapper": False, "asfrom_froms": set([update_stmt.table]), "selectable": update_stmt}) @@ -1933,7 +1996,6 @@ class SQLCompiler(Compiled): def visit_delete(self, delete_stmt, **kw): self.stack.append({'correlate_froms': set([delete_stmt.table]), - "iswrapper": False, "asfrom_froms": set([delete_stmt.table]), "selectable": delete_stmt}) self.isdelete = True @@ -2078,7 +2140,9 @@ class DDLCompiler(Compiled): (table.description, column.name, ce.args[0]) )) - const = self.create_table_constraints(table) + const = self.create_table_constraints( + table, _include_foreign_key_constraints= + create.include_foreign_key_constraints) if const: text += ", \n\t" + const @@ -2102,7 +2166,9 @@ class DDLCompiler(Compiled): return text - def create_table_constraints(self, table): + def create_table_constraints( + self, table, + _include_foreign_key_constraints=None): # On some DB order is significant: visit PK first, then the # other constraints (engine.ReflectionTest.testbasic failed on FB2) @@ -2110,8 
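The select._suffixes rendering added above is the compiler half of the HasSuffixes mixin imported in expression.py further down. A short usage sketch; the suffix text is arbitrary, dialect-specific SQL:

    from sqlalchemy import column, select, table

    jobs = table("jobs", column("id"))

    # appends raw text after the statement body, the mirror image of
    # prefix_with()
    stmt = select([jobs.c.id]).suffix_with("FOR UPDATE SKIP LOCKED")
    print(stmt)  # SELECT jobs.id FROM jobs FOR UPDATE SKIP LOCKED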
+2176,15 @@ class DDLCompiler(Compiled): if table.primary_key: constraints.append(table.primary_key) + all_fkcs = table.foreign_key_constraints + if _include_foreign_key_constraints is not None: + omit_fkcs = all_fkcs.difference(_include_foreign_key_constraints) + else: + omit_fkcs = set() + constraints.extend([c for c in table._sorted_constraints - if c is not table.primary_key]) + if c is not table.primary_key and + c not in omit_fkcs]) return ", \n\t".join( p for p in @@ -2206,15 +2279,26 @@ class DDLCompiler(Compiled): self.preparer.format_sequence(drop.element) def visit_drop_constraint(self, drop): + constraint = drop.element + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + else: + formatted_name = None + + if formatted_name is None: + raise exc.CompileError( + "Can't emit DROP CONSTRAINT for constraint %r; " + "it has no name" % drop.element) return "ALTER TABLE %s DROP CONSTRAINT %s%s" % ( self.preparer.format_table(drop.element.table), - self.preparer.format_constraint(drop.element), + formatted_name, drop.cascade and " CASCADE" or "" ) def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + \ - self.dialect.type_compiler.process(column.type) + self.dialect.type_compiler.process( + column.type, type_expression=column) default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -2231,7 +2315,8 @@ class DDLCompiler(Compiled): if isinstance(column.server_default.arg, util.string_types): return "'%s'" % column.server_default.arg else: - return self.sql_compiler.process(column.server_default.arg) + return self.sql_compiler.process( + column.server_default.arg, literal_binds=True) else: return None @@ -2278,14 +2363,14 @@ class DDLCompiler(Compiled): formatted_name = self.preparer.format_constraint(constraint) if formatted_name is not None: text += "CONSTRAINT %s " % formatted_name - remote_table = list(constraint._elements.values())[0].column.table + remote_table = list(constraint.elements)[0].column.table text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( ', '.join(preparer.quote(f.parent.name) - for f in constraint._elements.values()), + for f in constraint.elements), self.define_constraint_remote_table( constraint, remote_table, preparer), ', '.join(preparer.quote(f.column.name) - for f in constraint._elements.values()) + for f in constraint.elements) ) text += self.define_constraint_match(constraint) text += self.define_constraint_cascades(constraint) @@ -2338,13 +2423,13 @@ class DDLCompiler(Compiled): class GenericTypeCompiler(TypeCompiler): - def visit_FLOAT(self, type_): + def visit_FLOAT(self, type_, **kw): return "FLOAT" - def visit_REAL(self, type_): + def visit_REAL(self, type_, **kw): return "REAL" - def visit_NUMERIC(self, type_): + def visit_NUMERIC(self, type_, **kw): if type_.precision is None: return "NUMERIC" elif type_.scale is None: @@ -2355,7 +2440,7 @@ class GenericTypeCompiler(TypeCompiler): {'precision': type_.precision, 'scale': type_.scale} - def visit_DECIMAL(self, type_): + def visit_DECIMAL(self, type_, **kw): if type_.precision is None: return "DECIMAL" elif type_.scale is None: @@ -2366,31 +2451,31 @@ class GenericTypeCompiler(TypeCompiler): {'precision': type_.precision, 'scale': type_.scale} - def visit_INTEGER(self, type_): + def visit_INTEGER(self, type_, **kw): return "INTEGER" - def visit_SMALLINT(self, type_): + def visit_SMALLINT(self, type_, **kw): return "SMALLINT" - def visit_BIGINT(self, type_): 
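The reworked visit_drop_constraint replaces the old rendering of a literal "DROP CONSTRAINT None" for unnamed constraints with an explicit compile-time error:

    from sqlalchemy import CheckConstraint, Column, Integer, MetaData, Table
    from sqlalchemy.exc import CompileError
    from sqlalchemy.schema import DropConstraint

    m = MetaData()
    t = Table("t", m, Column("x", Integer), CheckConstraint("x > 0"))
    ck = next(c for c in t.constraints if isinstance(c, CheckConstraint))

    try:
        str(DropConstraint(ck))  # the constraint has no name
    except CompileError as err:
        print(err)  # Can't emit DROP CONSTRAINT for constraint ...; it has no name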
+ def visit_BIGINT(self, type_, **kw): return "BIGINT" - def visit_TIMESTAMP(self, type_): + def visit_TIMESTAMP(self, type_, **kw): return 'TIMESTAMP' - def visit_DATETIME(self, type_): + def visit_DATETIME(self, type_, **kw): return "DATETIME" - def visit_DATE(self, type_): + def visit_DATE(self, type_, **kw): return "DATE" - def visit_TIME(self, type_): + def visit_TIME(self, type_, **kw): return "TIME" - def visit_CLOB(self, type_): + def visit_CLOB(self, type_, **kw): return "CLOB" - def visit_NCLOB(self, type_): + def visit_NCLOB(self, type_, **kw): return "NCLOB" def _render_string_type(self, type_, name): @@ -2402,91 +2487,91 @@ class GenericTypeCompiler(TypeCompiler): text += ' COLLATE "%s"' % type_.collation return text - def visit_CHAR(self, type_): + def visit_CHAR(self, type_, **kw): return self._render_string_type(type_, "CHAR") - def visit_NCHAR(self, type_): + def visit_NCHAR(self, type_, **kw): return self._render_string_type(type_, "NCHAR") - def visit_VARCHAR(self, type_): + def visit_VARCHAR(self, type_, **kw): return self._render_string_type(type_, "VARCHAR") - def visit_NVARCHAR(self, type_): + def visit_NVARCHAR(self, type_, **kw): return self._render_string_type(type_, "NVARCHAR") - def visit_TEXT(self, type_): + def visit_TEXT(self, type_, **kw): return self._render_string_type(type_, "TEXT") - def visit_BLOB(self, type_): + def visit_BLOB(self, type_, **kw): return "BLOB" - def visit_BINARY(self, type_): + def visit_BINARY(self, type_, **kw): return "BINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_VARBINARY(self, type_): + def visit_VARBINARY(self, type_, **kw): return "VARBINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_BOOLEAN(self, type_): + def visit_BOOLEAN(self, type_, **kw): return "BOOLEAN" - def visit_large_binary(self, type_): - return self.visit_BLOB(type_) + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_, **kw) - def visit_boolean(self, type_): - return self.visit_BOOLEAN(type_) + def visit_boolean(self, type_, **kw): + return self.visit_BOOLEAN(type_, **kw) - def visit_time(self, type_): - return self.visit_TIME(type_) + def visit_time(self, type_, **kw): + return self.visit_TIME(type_, **kw) - def visit_datetime(self, type_): - return self.visit_DATETIME(type_) + def visit_datetime(self, type_, **kw): + return self.visit_DATETIME(type_, **kw) - def visit_date(self, type_): - return self.visit_DATE(type_) + def visit_date(self, type_, **kw): + return self.visit_DATE(type_, **kw) - def visit_big_integer(self, type_): - return self.visit_BIGINT(type_) + def visit_big_integer(self, type_, **kw): + return self.visit_BIGINT(type_, **kw) - def visit_small_integer(self, type_): - return self.visit_SMALLINT(type_) + def visit_small_integer(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) - def visit_integer(self, type_): - return self.visit_INTEGER(type_) + def visit_integer(self, type_, **kw): + return self.visit_INTEGER(type_, **kw) - def visit_real(self, type_): - return self.visit_REAL(type_) + def visit_real(self, type_, **kw): + return self.visit_REAL(type_, **kw) - def visit_float(self, type_): - return self.visit_FLOAT(type_) + def visit_float(self, type_, **kw): + return self.visit_FLOAT(type_, **kw) - def visit_numeric(self, type_): - return self.visit_NUMERIC(type_) + def visit_numeric(self, type_, **kw): + return self.visit_NUMERIC(type_, **kw) - def visit_string(self, type_): - return self.visit_VARCHAR(type_) + def visit_string(self, type_, **kw): + return 
self.visit_VARCHAR(type_, **kw) - def visit_unicode(self, type_): - return self.visit_VARCHAR(type_) + def visit_unicode(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) - def visit_text(self, type_): - return self.visit_TEXT(type_) + def visit_text(self, type_, **kw): + return self.visit_TEXT(type_, **kw) - def visit_unicode_text(self, type_): - return self.visit_TEXT(type_) + def visit_unicode_text(self, type_, **kw): + return self.visit_TEXT(type_, **kw) - def visit_enum(self, type_): - return self.visit_VARCHAR(type_) + def visit_enum(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) - def visit_null(self, type_): + def visit_null(self, type_, **kw): raise exc.CompileError("Can't generate DDL for %r; " "did you forget to specify a " "type on this Column?" % type_) - def visit_type_decorator(self, type_): - return self.process(type_.type_engine(self.dialect)) + def visit_type_decorator(self, type_, **kw): + return self.process(type_.type_engine(self.dialect), **kw) - def visit_user_defined(self, type_): - return type_.get_col_spec() + def visit_user_defined(self, type_, **kw): + return type_.get_col_spec(**kw) class IdentifierPreparer(object): diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 831d05be1..2e39f6b36 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -1,5 +1,5 @@ # sql/crud.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -116,11 +116,12 @@ def _get_crud_params(compiler, stmt, **kw): def _create_bind_param( - compiler, col, value, process=True, required=False, name=None): + compiler, col, value, process=True, + required=False, name=None): if name is None: name = col.key - bindparam = elements.BindParameter(name, value, - type_=col.type, required=required) + bindparam = elements.BindParameter( + name, value, type_=col.type, required=required) bindparam._is_crud = True if process: bindparam = bindparam._compiler_dispatch(compiler) @@ -300,13 +301,45 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): compiler.returning.append(c) else: values.append( - (c, _create_bind_param(compiler, c, None)) + (c, _create_prefetch_bind_param(compiler, c)) ) - compiler.prefetch.append(c) + else: compiler.returning.append(c) +def _create_prefetch_bind_param(compiler, c, process=True, name=None): + param = _create_bind_param(compiler, c, None, process=process, name=name) + compiler.prefetch.append(c) + return param + + +class _multiparam_column(elements.ColumnElement): + def __init__(self, original, index): + self.key = "%s_%d" % (original.key, index + 1) + self.original = original + self.default = original.default + + def __eq__(self, other): + return isinstance(other, _multiparam_column) and \ + other.key == self.key and \ + other.original == self.original + + +def _process_multiparam_default_bind(compiler, c, index, kw): + + if not c.default: + raise exc.CompileError( + "INSERT value for column %s is explicitly rendered as a bound " + "parameter in the VALUES clause; " + "a Python-side value or SQL expression is required" % c) + elif c.default.is_clause_element: + return compiler.process(c.default.arg.self_group(), **kw) + else: + col = _multiparam_column(c, index) + return _create_prefetch_bind_param(compiler, col) + + def _append_param_insert_pk(compiler, stmt, c, values, kw): if ( (c.default is not None and @@ -318,11 +351,9 @@
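With every type-compiler visit method accepting **kw, the type_expression keyword introduced in visit_typeclause and get_column_specification reaches user-defined types as well. A sketch of a type that inspects it during DDL; PointType is a made-up example:

    from sqlalchemy import Column, MetaData, Table
    from sqlalchemy.schema import CreateTable
    from sqlalchemy.types import UserDefinedType

    class PointType(UserDefinedType):
        def get_col_spec(self, **kw):
            # as of this change, DDL compilation may pass
            # type_expression=<the Column being rendered>
            col = kw.get("type_expression")
            return "POINT /* column: %s */" % (col.name if col is not None else "?")

    m = MetaData()
    shapes = Table("shapes", m, Column("center", PointType()))
    print(CreateTable(shapes))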
def _append_param_insert_pk(compiler, stmt, c, values, kw): preexecute_autoincrement_sequences) ): values.append( - (c, _create_bind_param(compiler, c, None)) + (c, _create_prefetch_bind_param(compiler, c)) ) - compiler.prefetch.append(c) - def _append_param_insert_hasdefault( compiler, stmt, c, implicit_return_defaults, values, kw): @@ -350,9 +381,8 @@ def _append_param_insert_hasdefault( compiler.postfetch.append(c) else: values.append( - (c, _create_bind_param(compiler, c, None)) + (c, _create_prefetch_bind_param(compiler, c)) ) - compiler.prefetch.append(c) def _append_param_insert_select_hasdefault( @@ -369,9 +399,8 @@ def _append_param_insert_select_hasdefault( values.append((c, proc)) else: values.append( - (c, _create_bind_param(compiler, c, None, process=False)) + (c, _create_prefetch_bind_param(compiler, c, process=False)) ) - compiler.prefetch.append(c) def _append_param_update( @@ -390,9 +419,8 @@ def _append_param_update( compiler.postfetch.append(c) else: values.append( - (c, _create_bind_param(compiler, c, None)) + (c, _create_prefetch_bind_param(compiler, c)) ) - compiler.prefetch.append(c) elif c.server_onupdate is not None: if implicit_return_defaults and \ c in implicit_return_defaults: @@ -445,12 +473,9 @@ def _get_multitable_params( compiler.postfetch.append(c) else: values.append( - (c, _create_bind_param( - compiler, c, None, name=_col_bind_name(c) - ) - ) + (c, _create_prefetch_bind_param( + compiler, c, name=_col_bind_name(c))) ) - compiler.prefetch.append(c) elif c.server_onupdate is not None: compiler.postfetch.append(c) @@ -469,7 +494,8 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw): ) if elements._is_literal(row[c.key]) else compiler.process( row[c.key].self_group(), **kw)) - if c.key in row else param + if c.key in row else + _process_multiparam_default_bind(compiler, c, i, kw) ) for (c, param) in values_0 ] diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 1f2c448ea..3834f25f4 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -1,5 +1,5 @@ # sql/ddl.py -# Copyright (C) 2009-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2009-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -12,7 +12,6 @@ to invoke them for a create/drop call. from .. import util from .elements import ClauseElement -from .visitors import traverse from .base import Executable, _generative, SchemaVisitor, _bind_or_error from ..util import topological from .. import event @@ -370,7 +369,7 @@ class DDL(DDLElement): :class:`.DDLEvents` - :mod:`sqlalchemy.event` + :ref:`event_toplevel` """ @@ -464,19 +463,28 @@ class CreateTable(_CreateDropBase): __visit_name__ = "create_table" - def __init__(self, element, on=None, bind=None): + def __init__( + self, element, on=None, bind=None, + include_foreign_key_constraints=None): """Create a :class:`.CreateTable` construct. :param element: a :class:`.Table` that's the subject of the CREATE :param on: See the description for 'on' in :class:`.DDL`. :param bind: See the description for 'bind' in :class:`.DDL`. + :param include_foreign_key_constraints: optional sequence of + :class:`.ForeignKeyConstraint` objects that will be included + inline within the CREATE construct; if omitted, all foreign key + constraints that do not specify use_alter=True are included. + + .. 
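_multiparam_column and _process_multiparam_default_bind are what allow a multiple-VALUES insert().values([...]) to fall back to column defaults for keys a row omits, invoking Python-side callables once per row at execute time (the dml.py versionchanged note below describes the same behavior). A sketch:

    import itertools

    from sqlalchemy import Column, Integer, MetaData, String, Table

    counter = itertools.count(1)
    m = MetaData()
    t = Table(
        "t", m,
        Column("id", Integer, primary_key=True, default=lambda: next(counter)),
        Column("data", String),
    )

    # "id" is absent from each VALUES entry, so every row receives its own
    # prefetched bind parameter; the callable fires per row on execution
    stmt = t.insert().values([{"data": "d1"}, {"data": "d2"}, {"data": "d3"}])
    print(stmt)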
versionadded:: 1.0.0 """ super(CreateTable, self).__init__(element, on=on, bind=bind) self.columns = [CreateColumn(column) for column in element.columns ] + self.include_foreign_key_constraints = include_foreign_key_constraints class _DropView(_CreateDropBase): @@ -696,8 +704,10 @@ class SchemaGenerator(DDLBase): tables = self.tables else: tables = list(metadata.tables.values()) - collection = [t for t in sort_tables(tables) - if self._can_create_table(t)] + + collection = sort_tables_and_constraints( + [t for t in tables if self._can_create_table(t)]) + seq_coll = [s for s in metadata._sequences.values() if s.column is None and self._can_create_sequence(s)] @@ -709,35 +719,62 @@ class SchemaGenerator(DDLBase): for seq in seq_coll: self.traverse_single(seq, create_ok=True) - for table in collection: - self.traverse_single(table, create_ok=True) + for table, fkcs in collection: + if table is not None: + self.traverse_single( + table, create_ok=True, + include_foreign_key_constraints=fkcs, + _is_metadata_operation=True) + else: + for fkc in fkcs: + self.traverse_single(fkc) metadata.dispatch.after_create(metadata, self.connection, tables=collection, checkfirst=self.checkfirst, _ddl_runner=self) - def visit_table(self, table, create_ok=False): + def visit_table( + self, table, create_ok=False, + include_foreign_key_constraints=None, + _is_metadata_operation=False): if not create_ok and not self._can_create_table(table): return - table.dispatch.before_create(table, self.connection, - checkfirst=self.checkfirst, - _ddl_runner=self) + table.dispatch.before_create( + table, self.connection, + checkfirst=self.checkfirst, + _ddl_runner=self, + _is_metadata_operation=_is_metadata_operation) for column in table.columns: if column.default is not None: self.traverse_single(column.default) - self.connection.execute(CreateTable(table)) + if not self.dialect.supports_alter: + # e.g., don't omit any foreign key constraints + include_foreign_key_constraints = None + + self.connection.execute( + CreateTable( + table, + include_foreign_key_constraints=include_foreign_key_constraints + )) if hasattr(table, 'indexes'): for index in table.indexes: self.traverse_single(index) - table.dispatch.after_create(table, self.connection, - checkfirst=self.checkfirst, - _ddl_runner=self) + table.dispatch.after_create( + table, self.connection, + checkfirst=self.checkfirst, + _ddl_runner=self, + _is_metadata_operation=_is_metadata_operation) + + def visit_foreign_key_constraint(self, constraint): + if not self.dialect.supports_alter: + return + self.connection.execute(AddConstraint(constraint)) def visit_sequence(self, sequence, create_ok=False): if not create_ok and not self._can_create_sequence(sequence): @@ -765,11 +802,33 @@ class SchemaDropper(DDLBase): else: tables = list(metadata.tables.values()) - collection = [ - t - for t in reversed(sort_tables(tables)) - if self._can_drop_table(t) - ] + try: + collection = reversed( + sort_tables_and_constraints( + [t for t in tables if self._can_drop_table(t)], + filter_fn= + lambda constraint: True if not self.dialect.supports_alter + else False if constraint.name is None + else None + ) + ) + except exc.CircularDependencyError as err2: + util.raise_from_cause( + exc.CircularDependencyError( + err2.args[0], + err2.cycles, err2.edges, + msg="Can't sort tables for DROP; an " + "unresolvable foreign key " + "dependency exists between tables: %s. 
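include_foreign_key_constraints is the parameter SchemaGenerator.visit_table now threads into CreateTable so that constraints involved in cycles can be issued later as ALTER statements; it can equally be used directly:

    from sqlalchemy import Column, ForeignKey, Integer, MetaData, Table
    from sqlalchemy.schema import CreateTable

    m = MetaData()
    parent = Table("parent", m, Column("id", Integer, primary_key=True))
    child = Table(
        "child", m,
        Column("id", Integer, primary_key=True),
        Column("pid", Integer, ForeignKey("parent.id")),
    )

    # render the table with no inline FK constraints; SchemaGenerator would
    # then emit them separately via AddConstraint
    print(CreateTable(child, include_foreign_key_constraints=[]))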
Please ensure " + "that the ForeignKey and ForeignKeyConstraint objects " + "involved in the cycle have " + "names so that they can be dropped using DROP CONSTRAINT." + % ( + ", ".join(sorted([t.fullname for t in err2.cycles])) + ) + + ) + ) seq_coll = [ s @@ -781,8 +840,13 @@ metadata, self.connection, tables=collection, checkfirst=self.checkfirst, _ddl_runner=self) - for table in collection: - self.traverse_single(table, drop_ok=True) + for table, fkcs in collection: + if table is not None: + self.traverse_single( + table, drop_ok=True, _is_metadata_operation=True) + else: + for fkc in fkcs: + self.traverse_single(fkc) for seq in seq_coll: self.traverse_single(seq, drop_ok=True) @@ -812,13 +876,15 @@ def visit_index(self, index): self.connection.execute(DropIndex(index)) - def visit_table(self, table, drop_ok=False): + def visit_table(self, table, drop_ok=False, _is_metadata_operation=False): if not drop_ok and not self._can_drop_table(table): return - table.dispatch.before_drop(table, self.connection, - checkfirst=self.checkfirst, - _ddl_runner=self) + table.dispatch.before_drop( + table, self.connection, + checkfirst=self.checkfirst, + _ddl_runner=self, + _is_metadata_operation=_is_metadata_operation) for column in table.columns: if column.default is not None: @@ -826,9 +892,16 @@ self.connection.execute(DropTable(table)) - table.dispatch.after_drop(table, self.connection, - checkfirst=self.checkfirst, - _ddl_runner=self) + table.dispatch.after_drop( + table, self.connection, + checkfirst=self.checkfirst, + _ddl_runner=self, + _is_metadata_operation=_is_metadata_operation) + + def visit_foreign_key_constraint(self, constraint): + if not self.dialect.supports_alter: + return + self.connection.execute(DropConstraint(constraint)) def visit_sequence(self, sequence, drop_ok=False): if not drop_ok and not self._can_drop_sequence(sequence): @@ -837,32 +910,159 @@ def sort_tables(tables, skip_fn=None, extra_dependencies=None): - """sort a collection of Table objects in order of - their foreign-key dependency.""" + """sort a collection of :class:`.Table` objects based on dependency. - tables = list(tables) - tuples = [] - if extra_dependencies is not None: - tuples.extend(extra_dependencies) + This is a dependency-ordered sort which will emit :class:`.Table` + objects such that they will follow the :class:`.Table` objects they depend on. + Tables are dependent on one another based on the presence of + :class:`.ForeignKeyConstraint` objects as well as explicit dependencies + added by :meth:`.Table.add_is_dependent_on`. - def visit_foreign_key(fkey): - if fkey.use_alter: - return - elif skip_fn and skip_fn(fkey): - return - parent_table = fkey.column.table - if parent_table in tables: - child_table = fkey.parent.table - if parent_table is not child_table: - tuples.append((parent_table, child_table)) + .. warning:: + + The :func:`.sort_tables` function cannot by itself accommodate + automatic resolution of dependency cycles between tables, which + are usually caused by mutually dependent foreign key constraints. + To resolve these cycles, either the + :paramref:`.ForeignKeyConstraint.use_alter` parameter may be applied + to those constraints, or use the + :func:`.sql.sort_tables_and_constraints` function which will break + out foreign key constraints involved in cycles separately. + + :param tables: a sequence of :class:`.Table` objects.
+ :param skip_fn: optional callable which will be passed a + :class:`.ForeignKey` object; if it returns True, this + constraint will not be considered as a dependency. Note this is + **different** from the same parameter in + :func:`.sort_tables_and_constraints`, which is + instead passed the owning :class:`.ForeignKeyConstraint` object. + + :param extra_dependencies: a sequence of 2-tuples of tables which will + also be considered as dependent on each other. + + .. seealso:: + + :func:`.sort_tables_and_constraints` + + :meth:`.MetaData.sorted_tables` - uses this function to sort + + + """ + + if skip_fn is not None: + def _skip_fn(fkc): + for fk in fkc.elements: + if skip_fn(fk): + return True + else: + return None + else: + _skip_fn = None + + return [ + t for (t, fkcs) in + sort_tables_and_constraints( + tables, filter_fn=_skip_fn, extra_dependencies=extra_dependencies) + if t is not None + ] + + +def sort_tables_and_constraints( + tables, filter_fn=None, extra_dependencies=None): + """sort a collection of :class:`.Table` / :class:`.ForeignKeyConstraint` + objects. + + This is a dependency-ordered sort which will emit tuples of + ``(Table, [ForeignKeyConstraint, ...])`` such that each + :class:`.Table` follows the :class:`.Table` objects it depends on. + Remaining :class:`.ForeignKeyConstraint` objects that are separate due to + dependency rules not satisfied by the sort are emitted afterwards + as ``(None, [ForeignKeyConstraint ...])``. + + Tables are dependent on one another based on the presence of + :class:`.ForeignKeyConstraint` objects, explicit dependencies + added by :meth:`.Table.add_is_dependent_on`, as well as dependencies + stated here using the :paramref:`~.sort_tables_and_constraints.filter_fn` + and/or :paramref:`~.sort_tables_and_constraints.extra_dependencies` + parameters. + + :param tables: a sequence of :class:`.Table` objects. + + :param filter_fn: optional callable which will be passed a + :class:`.ForeignKeyConstraint` object, and returns a value based on + whether this constraint should definitely be included or excluded as + an inline constraint, or neither. If it returns False, the constraint + will definitely be included as a dependency that cannot be subject + to ALTER; if True, it will **only** be included as an ALTER result at + the end. Returning None means the constraint is included in the + table-based result unless it is detected as part of a dependency cycle. + + :param extra_dependencies: a sequence of 2-tuples of tables which will + also be considered as dependent on each other. + + .. versionadded:: 1.0.0 + + ..
seealso:: + + :func:`.sort_tables` + + + """ + + fixed_dependencies = set() + mutable_dependencies = set() + + if extra_dependencies is not None: + fixed_dependencies.update(extra_dependencies) + + remaining_fkcs = set() for table in tables: - traverse(table, - {'schema_visitor': True}, - {'foreign_key': visit_foreign_key}) + for fkc in table.foreign_key_constraints: + if fkc.use_alter is True: + remaining_fkcs.add(fkc) + continue + + if filter_fn: + filtered = filter_fn(fkc) + + if filtered is True: + remaining_fkcs.add(fkc) + continue - tuples.extend( - [parent, table] for parent in table._extra_dependencies + dependent_on = fkc.referred_table + if dependent_on is not table: + mutable_dependencies.add((dependent_on, table)) + + fixed_dependencies.update( + (parent, table) for parent in table._extra_dependencies + ) + + try: + candidate_sort = list( + topological.sort( + fixed_dependencies.union(mutable_dependencies), tables + ) + ) + except exc.CircularDependencyError as err: + for edge in err.edges: + if edge in mutable_dependencies: + table = edge[1] + can_remove = [ + fkc for fkc in table.foreign_key_constraints + if filter_fn is None or filter_fn(fkc) is not False] + remaining_fkcs.update(can_remove) + for fkc in can_remove: + dependent_on = fkc.referred_table + if dependent_on is not table: + mutable_dependencies.discard((dependent_on, table)) + candidate_sort = list( + topological.sort( + fixed_dependencies.union(mutable_dependencies), tables + ) ) - return list(topological.sort(tuples, tables)) + return [ + (table, table.foreign_key_constraints.difference(remaining_fkcs)) + for table in candidate_sort + ] + [(None, list(remaining_fkcs))] diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 4f53e2979..e77ad765c 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -1,5 +1,5 @@ # sql/default_comparator.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -9,8 +9,8 @@ """ from .. import exc, util -from . import operators from . import type_api +from . import operators from .elements import BindParameter, True_, False_, BinaryExpression, \ Null, _const_expr, _clause_element_as_expr, \ ClauseList, ColumnElement, TextClause, UnaryExpression, \ @@ -18,294 +18,270 @@ from .elements import BindParameter, True_, False_, BinaryExpression, \ from .selectable import SelectBase, Alias, Selectable, ScalarSelect -class _DefaultColumnComparator(operators.ColumnOperators): - """Defines comparison and math operations. - - See :class:`.ColumnOperators` and :class:`.Operators` for descriptions - of all operations. - - """ - - @util.memoized_property - def type(self): - return self.expr.type - - def operate(self, op, *other, **kwargs): - o = self.operators[op.__name__] - return o[0](self, self.expr, op, *(other + o[1:]), **kwargs) - - def reverse_operate(self, op, other, **kwargs): - o = self.operators[op.__name__] - return o[0](self, self.expr, op, other, - reverse=True, *o[1:], **kwargs) - - def _adapt_expression(self, op, other_comparator): - """evaluate the return type of <self> <op> <othertype>, - and apply any adaptations to the given operator. - - This method determines the type of a resulting binary expression - given two source types and an operator. 
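Putting the new sort to work: sort_tables_and_constraints yields (Table, [ForeignKeyConstraint, ...]) tuples followed by a final (None, [...]) entry for ALTER-only constraints. For a mutually dependent pair of tables:

    from sqlalchemy import Column, ForeignKey, Integer, MetaData, Table
    from sqlalchemy.schema import sort_tables_and_constraints

    m = MetaData()
    a = Table(
        "a", m,
        Column("id", Integer, primary_key=True),
        Column("b_id", Integer, ForeignKey("b.id")),
    )
    b = Table(
        "b", m,
        Column("id", Integer, primary_key=True),
        Column("a_id", Integer, ForeignKey("a.id", use_alter=True, name="fk_b_a")),
    )

    for tbl, fkcs in sort_tables_and_constraints(m.tables.values()):
        if tbl is not None:
            print("table:", tbl.name)                         # b first, then a
        else:
            print("ALTER-only:", [fkc.name for fkc in fkcs])  # ['fk_b_a']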
For example, two - :class:`.Column` objects, both of the type :class:`.Integer`, will - produce a :class:`.BinaryExpression` that also has the type - :class:`.Integer` when compared via the addition (``+``) operator. - However, using the addition operator with an :class:`.Integer` - and a :class:`.Date` object will produce a :class:`.Date`, assuming - "days delta" behavior by the database (in reality, most databases - other than Postgresql don't accept this particular operation). - - The method returns a tuple of the form <operator>, <type>. - The resulting operator and type will be those applied to the - resulting :class:`.BinaryExpression` as the final operator and the - right-hand side of the expression. - - Note that only a subset of operators make usage of - :meth:`._adapt_expression`, - including math operators and user-defined operators, but not - boolean comparison or special SQL keywords like MATCH or BETWEEN. - - """ - return op, other_comparator.type - - def _boolean_compare(self, expr, op, obj, negate=None, reverse=False, - _python_is_types=(util.NoneType, bool), - **kwargs): - - if isinstance(obj, _python_is_types + (Null, True_, False_)): - - # allow x ==/!= True/False to be treated as a literal. - # this comes out to "== / != true/false" or "1/0" if those - # constants aren't supported and works on all platforms - if op in (operators.eq, operators.ne) and \ - isinstance(obj, (bool, True_, False_)): - return BinaryExpression(expr, - _literal_as_text(obj), - op, - type_=type_api.BOOLEANTYPE, - negate=negate, modifiers=kwargs) - else: - # all other None/True/False uses IS, IS NOT - if op in (operators.eq, operators.is_): - return BinaryExpression(expr, _const_expr(obj), - operators.is_, - negate=operators.isnot) - elif op in (operators.ne, operators.isnot): - return BinaryExpression(expr, _const_expr(obj), - operators.isnot, - negate=operators.is_) - else: - raise exc.ArgumentError( - "Only '=', '!=', 'is_()', 'isnot()' operators can " - "be used with None/True/False") - else: - obj = self._check_literal(expr, op, obj) +def _boolean_compare(expr, op, obj, negate=None, reverse=False, + _python_is_types=(util.NoneType, bool), + result_type = None, + **kwargs): - if reverse: - return BinaryExpression(obj, - expr, - op, - type_=type_api.BOOLEANTYPE, - negate=negate, modifiers=kwargs) - else: + if result_type is None: + result_type = type_api.BOOLEANTYPE + + if isinstance(obj, _python_is_types + (Null, True_, False_)): + + # allow x ==/!= True/False to be treated as a literal. 
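The _adapt_expression protocol described in the docstring being removed here still drives result types for math operators through _binary_operate below: the operator and the right-hand comparator negotiate the type of the resulting BinaryExpression. From the outside:

    from sqlalchemy import Integer, String, column

    num = column("x", Integer)
    txt = column("s", String)

    print((num + 5).type)      # INTEGER, negotiated via _adapt_expression
    print((txt + "abc").type)  # VARCHAR; string concatenation keeps the type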
+ # this comes out to "== / != true/false" or "1/0" if those + # constants aren't supported and works on all platforms + if op in (operators.eq, operators.ne) and \ + isinstance(obj, (bool, True_, False_)): return BinaryExpression(expr, - obj, + _literal_as_text(obj), op, - type_=type_api.BOOLEANTYPE, + type_=result_type, negate=negate, modifiers=kwargs) - - def _binary_operate(self, expr, op, obj, reverse=False, result_type=None, - **kw): - obj = self._check_literal(expr, op, obj) - - if reverse: - left, right = obj, expr - else: - left, right = expr, obj - - if result_type is None: - op, result_type = left.comparator._adapt_expression( - op, right.comparator) - - return BinaryExpression(left, right, op, type_=result_type) - - def _conjunction_operate(self, expr, op, other, **kw): - if op is operators.and_: - return and_(expr, other) - elif op is operators.or_: - return or_(expr, other) else: - raise NotImplementedError() - - def _scalar(self, expr, op, fn, **kw): - return fn(expr) - - def _in_impl(self, expr, op, seq_or_selectable, negate_op, **kw): - seq_or_selectable = _clause_element_as_expr(seq_or_selectable) - - if isinstance(seq_or_selectable, ScalarSelect): - return self._boolean_compare(expr, op, seq_or_selectable, - negate=negate_op) - elif isinstance(seq_or_selectable, SelectBase): - - # TODO: if we ever want to support (x, y, z) IN (select x, - # y, z from table), we would need a multi-column version of - # as_scalar() to produce a multi- column selectable that - # does not export itself as a FROM clause - - return self._boolean_compare( - expr, op, seq_or_selectable.as_scalar(), - negate=negate_op, **kw) - elif isinstance(seq_or_selectable, (Selectable, TextClause)): - return self._boolean_compare(expr, op, seq_or_selectable, - negate=negate_op, **kw) - elif isinstance(seq_or_selectable, ClauseElement): - raise exc.InvalidRequestError( - 'in_() accepts' - ' either a list of expressions ' - 'or a selectable: %r' % seq_or_selectable) - - # Handle non selectable arguments as sequences - args = [] - for o in seq_or_selectable: - if not _is_literal(o): - if not isinstance(o, operators.ColumnOperators): - raise exc.InvalidRequestError( - 'in_() accepts' - ' either a list of expressions ' - 'or a selectable: %r' % o) - elif o is None: - o = Null() + # all other None/True/False uses IS, IS NOT + if op in (operators.eq, operators.is_): + return BinaryExpression(expr, _const_expr(obj), + operators.is_, + negate=operators.isnot) + elif op in (operators.ne, operators.isnot): + return BinaryExpression(expr, _const_expr(obj), + operators.isnot, + negate=operators.is_) else: - o = expr._bind_param(op, o) - args.append(o) - if len(args) == 0: - - # Special case handling for empty IN's, behave like - # comparison against zero row selectable. We use != to - # build the contradiction as it handles NULL values - # appropriately, i.e. "not (x IN ())" should not return NULL - # values for x. - - util.warn('The IN-predicate on "%s" was invoked with an ' - 'empty sequence. This results in a ' - 'contradiction, which nonetheless can be ' - 'expensive to evaluate. Consider alternative ' - 'strategies for improved performance.' 
% expr) - if op is operators.in_op: - return expr != expr - else: - return expr == expr - - return self._boolean_compare(expr, op, - ClauseList(*args).self_group(against=op), - negate=negate_op) - - def _unsupported_impl(self, expr, op, *arg, **kw): - raise NotImplementedError("Operator '%s' is not supported on " - "this expression" % op.__name__) - - def _inv_impl(self, expr, op, **kw): - """See :meth:`.ColumnOperators.__inv__`.""" - if hasattr(expr, 'negation_clause'): - return expr.negation_clause + raise exc.ArgumentError( + "Only '=', '!=', 'is_()', 'isnot()' operators can " + "be used with None/True/False") + else: + obj = _check_literal(expr, op, obj) + + if reverse: + return BinaryExpression(obj, + expr, + op, + type_=result_type, + negate=negate, modifiers=kwargs) + else: + return BinaryExpression(expr, + obj, + op, + type_=result_type, + negate=negate, modifiers=kwargs) + + +def _binary_operate(expr, op, obj, reverse=False, result_type=None, + **kw): + obj = _check_literal(expr, op, obj) + + if reverse: + left, right = obj, expr + else: + left, right = expr, obj + + if result_type is None: + op, result_type = left.comparator._adapt_expression( + op, right.comparator) + + return BinaryExpression( + left, right, op, type_=result_type, modifiers=kw) + + +def _conjunction_operate(expr, op, other, **kw): + if op is operators.and_: + return and_(expr, other) + elif op is operators.or_: + return or_(expr, other) + else: + raise NotImplementedError() + + +def _scalar(expr, op, fn, **kw): + return fn(expr) + + +def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): + seq_or_selectable = _clause_element_as_expr(seq_or_selectable) + + if isinstance(seq_or_selectable, ScalarSelect): + return _boolean_compare(expr, op, seq_or_selectable, + negate=negate_op) + elif isinstance(seq_or_selectable, SelectBase): + + # TODO: if we ever want to support (x, y, z) IN (select x, + # y, z from table), we would need a multi-column version of + # as_scalar() to produce a multi- column selectable that + # does not export itself as a FROM clause + + return _boolean_compare( + expr, op, seq_or_selectable.as_scalar(), + negate=negate_op, **kw) + elif isinstance(seq_or_selectable, (Selectable, TextClause)): + return _boolean_compare(expr, op, seq_or_selectable, + negate=negate_op, **kw) + elif isinstance(seq_or_selectable, ClauseElement): + raise exc.InvalidRequestError( + 'in_() accepts' + ' either a list of expressions ' + 'or a selectable: %r' % seq_or_selectable) + + # Handle non selectable arguments as sequences + args = [] + for o in seq_or_selectable: + if not _is_literal(o): + if not isinstance(o, operators.ColumnOperators): + raise exc.InvalidRequestError( + 'in_() accepts' + ' either a list of expressions ' + 'or a selectable: %r' % o) + elif o is None: + o = Null() else: - return expr._negate() - - def _neg_impl(self, expr, op, **kw): - """See :meth:`.ColumnOperators.__neg__`.""" - return UnaryExpression(expr, operator=operators.neg) - - def _match_impl(self, expr, op, other, **kw): - """See :meth:`.ColumnOperators.match`.""" - return self._boolean_compare( - expr, operators.match_op, - self._check_literal( - expr, operators.match_op, other), - **kw) - - def _distinct_impl(self, expr, op, **kw): - """See :meth:`.ColumnOperators.distinct`.""" - return UnaryExpression(expr, operator=operators.distinct_op, - type_=expr.type) - - def _between_impl(self, expr, op, cleft, cright, **kw): - """See :meth:`.ColumnOperators.between`.""" - return BinaryExpression( - expr, - ClauseList( - 
self._check_literal(expr, operators.and_, cleft), - self._check_literal(expr, operators.and_, cright), - operator=operators.and_, - group=False, group_contents=False), - op, - negate=operators.notbetween_op - if op is operators.between_op - else operators.between_op, - modifiers=kw) - - def _collate_impl(self, expr, op, other, **kw): - return collate(expr, other) - - # a mapping of operators with the method they use, along with - # their negated operator for comparison operators - operators = { - "and_": (_conjunction_operate,), - "or_": (_conjunction_operate,), - "inv": (_inv_impl,), - "add": (_binary_operate,), - "mul": (_binary_operate,), - "sub": (_binary_operate,), - "div": (_binary_operate,), - "mod": (_binary_operate,), - "truediv": (_binary_operate,), - "custom_op": (_binary_operate,), - "concat_op": (_binary_operate,), - "lt": (_boolean_compare, operators.ge), - "le": (_boolean_compare, operators.gt), - "ne": (_boolean_compare, operators.eq), - "gt": (_boolean_compare, operators.le), - "ge": (_boolean_compare, operators.lt), - "eq": (_boolean_compare, operators.ne), - "like_op": (_boolean_compare, operators.notlike_op), - "ilike_op": (_boolean_compare, operators.notilike_op), - "notlike_op": (_boolean_compare, operators.like_op), - "notilike_op": (_boolean_compare, operators.ilike_op), - "contains_op": (_boolean_compare, operators.notcontains_op), - "startswith_op": (_boolean_compare, operators.notstartswith_op), - "endswith_op": (_boolean_compare, operators.notendswith_op), - "desc_op": (_scalar, UnaryExpression._create_desc), - "asc_op": (_scalar, UnaryExpression._create_asc), - "nullsfirst_op": (_scalar, UnaryExpression._create_nullsfirst), - "nullslast_op": (_scalar, UnaryExpression._create_nullslast), - "in_op": (_in_impl, operators.notin_op), - "notin_op": (_in_impl, operators.in_op), - "is_": (_boolean_compare, operators.is_), - "isnot": (_boolean_compare, operators.isnot), - "collate": (_collate_impl,), - "match_op": (_match_impl,), - "distinct_op": (_distinct_impl,), - "between_op": (_between_impl, ), - "notbetween_op": (_between_impl, ), - "neg": (_neg_impl,), - "getitem": (_unsupported_impl,), - "lshift": (_unsupported_impl,), - "rshift": (_unsupported_impl,), - } - - def _check_literal(self, expr, operator, other): - if isinstance(other, (ColumnElement, TextClause)): - if isinstance(other, BindParameter) and \ - other.type._isnull: - other = other._clone() - other.type = expr.type - return other - elif hasattr(other, '__clause_element__'): - other = other.__clause_element__() - elif isinstance(other, type_api.TypeEngine.Comparator): - other = other.expr - - if isinstance(other, (SelectBase, Alias)): - return other.as_scalar() - elif not isinstance(other, (ColumnElement, TextClause)): - return expr._bind_param(operator, other) + o = expr._bind_param(op, o) + args.append(o) + if len(args) == 0: + + # Special case handling for empty IN's, behave like + # comparison against zero row selectable. We use != to + # build the contradiction as it handles NULL values + # appropriately, i.e. "not (x IN ())" should not return NULL + # values for x. + + util.warn('The IN-predicate on "%s" was invoked with an ' + 'empty sequence. This results in a ' + 'contradiction, which nonetheless can be ' + 'expensive to evaluate. Consider alternative ' + 'strategies for improved performance.' 
% expr) + if op is operators.in_op: + return expr != expr else: - return other + return expr == expr + + return _boolean_compare(expr, op, + ClauseList(*args).self_group(against=op), + negate=negate_op) + + +def _unsupported_impl(expr, op, *arg, **kw): + raise NotImplementedError("Operator '%s' is not supported on " + "this expression" % op.__name__) + + +def _inv_impl(expr, op, **kw): + """See :meth:`.ColumnOperators.__inv__`.""" + if hasattr(expr, 'negation_clause'): + return expr.negation_clause + else: + return expr._negate() + + +def _neg_impl(expr, op, **kw): + """See :meth:`.ColumnOperators.__neg__`.""" + return UnaryExpression(expr, operator=operators.neg) + + +def _match_impl(expr, op, other, **kw): + """See :meth:`.ColumnOperators.match`.""" + + return _boolean_compare( + expr, operators.match_op, + _check_literal( + expr, operators.match_op, other), + result_type=type_api.MATCHTYPE, + negate=operators.notmatch_op + if op is operators.match_op else operators.match_op, + **kw + ) + + +def _distinct_impl(expr, op, **kw): + """See :meth:`.ColumnOperators.distinct`.""" + return UnaryExpression(expr, operator=operators.distinct_op, + type_=expr.type) + + +def _between_impl(expr, op, cleft, cright, **kw): + """See :meth:`.ColumnOperators.between`.""" + return BinaryExpression( + expr, + ClauseList( + _check_literal(expr, operators.and_, cleft), + _check_literal(expr, operators.and_, cright), + operator=operators.and_, + group=False, group_contents=False), + op, + negate=operators.notbetween_op + if op is operators.between_op + else operators.between_op, + modifiers=kw) + + +def _collate_impl(expr, op, other, **kw): + return collate(expr, other) + +# a mapping of operators with the method they use, along with +# their negated operator for comparison operators +operator_lookup = { + "and_": (_conjunction_operate,), + "or_": (_conjunction_operate,), + "inv": (_inv_impl,), + "add": (_binary_operate,), + "mul": (_binary_operate,), + "sub": (_binary_operate,), + "div": (_binary_operate,), + "mod": (_binary_operate,), + "truediv": (_binary_operate,), + "custom_op": (_binary_operate,), + "concat_op": (_binary_operate,), + "lt": (_boolean_compare, operators.ge), + "le": (_boolean_compare, operators.gt), + "ne": (_boolean_compare, operators.eq), + "gt": (_boolean_compare, operators.le), + "ge": (_boolean_compare, operators.lt), + "eq": (_boolean_compare, operators.ne), + "like_op": (_boolean_compare, operators.notlike_op), + "ilike_op": (_boolean_compare, operators.notilike_op), + "notlike_op": (_boolean_compare, operators.like_op), + "notilike_op": (_boolean_compare, operators.ilike_op), + "contains_op": (_boolean_compare, operators.notcontains_op), + "startswith_op": (_boolean_compare, operators.notstartswith_op), + "endswith_op": (_boolean_compare, operators.notendswith_op), + "desc_op": (_scalar, UnaryExpression._create_desc), + "asc_op": (_scalar, UnaryExpression._create_asc), + "nullsfirst_op": (_scalar, UnaryExpression._create_nullsfirst), + "nullslast_op": (_scalar, UnaryExpression._create_nullslast), + "in_op": (_in_impl, operators.notin_op), + "notin_op": (_in_impl, operators.in_op), + "is_": (_boolean_compare, operators.is_), + "isnot": (_boolean_compare, operators.isnot), + "collate": (_collate_impl,), + "match_op": (_match_impl,), + "notmatch_op": (_match_impl,), + "distinct_op": (_distinct_impl,), + "between_op": (_between_impl, ), + "notbetween_op": (_between_impl, ), + "neg": (_neg_impl,), + "getitem": (_unsupported_impl,), + "lshift": (_unsupported_impl,), + "rshift": 
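_in_impl, moved here essentially unchanged, still special-cases an empty sequence by rendering a self-contradiction (which treats NULL correctly) and emitting the performance warning shown above:

    import warnings

    from sqlalchemy import column, table

    t = table("t", column("x"))

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        expr = t.c.x.in_([])

    print(expr)               # t.x != t.x, a contradiction rather than IN ()
    print(caught[0].message)  # the "empty sequence" warning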
(_unsupported_impl,), +} + + +def _check_literal(expr, operator, other): + if isinstance(other, (ColumnElement, TextClause)): + if isinstance(other, BindParameter) and \ + other.type._isnull: + other = other._clone() + other.type = expr.type + return other + elif hasattr(other, '__clause_element__'): + other = other.__clause_element__() + elif isinstance(other, type_api.TypeEngine.Comparator): + other = other.expr + + if isinstance(other, (SelectBase, Alias)): + return other.as_scalar() + elif not isinstance(other, (ColumnElement, TextClause)): + return expr._bind_param(operator, other) + else: + return other + diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 9f2ce7ce3..6a4768fa1 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -1,5 +1,5 @@ # sql/dml.py -# Copyright (C) 2009-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2009-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -277,6 +277,12 @@ class ValuesBase(UpdateBase): deals with an arbitrary number of rows, so the :attr:`.ResultProxy.inserted_primary_key` accessor does not apply. + .. versionchanged:: 1.0.0 A multiple-VALUES INSERT now supports + columns with Python side default values and callables in the + same way as that of an "executemany" style of invocation; the + callable is invoked for each row. See :ref:`bug_3288` + for other details. + .. seealso:: :ref:`inserts_and_updates` - SQL Expression @@ -387,7 +393,7 @@ class ValuesBase(UpdateBase): :func:`.mapper`. :param cols: optional list of column key names or :class:`.Column` - objects. If omitted, all column expressions evaulated on the server + objects. If omitted, all column expressions evaluated on the server are added to the returning list. .. versionadded:: 0.9.0 diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 444273e67..ca8ec1f55 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1,5 +1,5 @@ # sql/elements.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -861,6 +861,9 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): expressions and function calls. """ + while self._is_clone_of is not None: + self = self._is_clone_of + return _anonymous_label( '%%(%d %s)s' % (id(self), getattr(self, 'name', 'anon')) ) @@ -1089,7 +1092,7 @@ class BindParameter(ColumnElement): """ if isinstance(key, ColumnClause): type_ = key.type - key = key.name + key = key.key if required is NO_ARG: required = (value is NO_ARG and callable_ is None) if value is NO_ARG: @@ -1276,7 +1279,7 @@ class TextClause(Executable, ClauseElement): E.g.:: - fom sqlalchemy import text + from sqlalchemy import text t = text("SELECT * FROM users") result = connection.execute(t) @@ -1617,10 +1620,10 @@ class Null(ColumnElement): return type_api.NULLTYPE @classmethod - def _singleton(cls): + def _instance(cls): """Return a constant :class:`.Null` construct.""" - return NULL + return Null() def compare(self, other): return isinstance(other, Null) @@ -1641,11 +1644,11 @@ class False_(ColumnElement): return type_api.BOOLEANTYPE def _negate(self): - return TRUE + return True_() @classmethod - def _singleton(cls): - """Return a constant :class:`.False_` construct. + def _instance(cls): + """Return a :class:`.False_` construct. 
E.g.:: @@ -1679,7 +1682,7 @@ class False_(ColumnElement): """ - return FALSE + return False_() def compare(self, other): return isinstance(other, False_) @@ -1700,17 +1703,17 @@ class True_(ColumnElement): return type_api.BOOLEANTYPE def _negate(self): - return FALSE + return False_() @classmethod def _ifnone(cls, other): if other is None: - return cls._singleton() + return cls._instance() else: return other @classmethod - def _singleton(cls): + def _instance(cls): """Return a constant :class:`.True_` construct. E.g.:: @@ -1745,15 +1748,11 @@ class True_(ColumnElement): """ - return TRUE + return True_() def compare(self, other): return isinstance(other, True_) -NULL = Null() -FALSE = False_() -TRUE = True_() - class ClauseList(ClauseElement): """Describe a list of clauses, separated by an operator. @@ -2147,7 +2146,7 @@ class Case(ColumnElement): result of the ``CASE`` construct if all expressions within :paramref:`.case.whens` evaluate to false. When omitted, most databases will produce a result of NULL if none of the "when" - expressions evaulate to true. + expressions evaluate to true. """ @@ -2764,7 +2763,7 @@ class BinaryExpression(ColumnElement): self.right, self.negate, negate=self.operator, - type_=type_api.BOOLEANTYPE, + type_=self.type, modifiers=self.modifiers) else: return super(BinaryExpression, self)._negate() @@ -2783,6 +2782,10 @@ class Grouping(ColumnElement): return self @property + def _key_label(self): + return self._label + + @property def _label(self): return getattr(self.element, '_label', None) or self.anon_label @@ -3037,10 +3040,12 @@ class Label(ColumnElement): if name: self.name = name + self._resolve_label = self.name else: self.name = _anonymous_label( '%%(%d %s)s' % (id(self), getattr(element, 'name', 'anon')) ) + self.key = self._label = self._key_label = self.name self._element = element self._type = type_ @@ -3091,7 +3096,7 @@ class Label(ColumnElement): self.element = clone(self.element, **kw) self.__dict__.pop('_allow_label_resolve', None) if anonymize_labels: - self.name = _anonymous_label( + self.name = self._resolve_label = _anonymous_label( '%%(%d %s)s' % ( id(self), getattr(self.element, 'name', 'anon')) ) @@ -3332,7 +3337,7 @@ class ColumnClause(Immutable, ColumnElement): return name def _bind_param(self, operator, obj): - return BindParameter(self.name, obj, + return BindParameter(self.key, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True) @@ -3384,7 +3389,7 @@ class ReleaseSavepointClause(_IdentifiedClause): __visit_name__ = 'release_savepoint' -class quoted_name(util.text_type): +class quoted_name(util.MemoizedSlots, util.text_type): """Represent a SQL identifier combined with quoting preferences. 
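With the module-level NULL / FALSE / TRUE constants gone, null(), true() and false() (now backed by _instance rather than _singleton) return fresh objects each time, so per-instance state such as anonymous labels cannot leak between statements; compare() still treats them as equivalent:

    from sqlalchemy import false, null, true

    a, b = null(), null()
    assert a is not b                # distinct instances as of this change
    assert a.compare(b)              # ...but structurally equivalent
    assert true().compare(~false())  # negation constructs a new True_()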
:class:`.quoted_name` is a Python unicode/str subclass which @@ -3428,6 +3433,8 @@ class quoted_name(util.text_type): """ + __slots__ = 'quote', 'lower', 'upper' + def __new__(cls, value, quote): if value is None: return None @@ -3447,15 +3454,13 @@ class quoted_name(util.text_type): def __reduce__(self): return quoted_name, (util.text_type(self), self.quote) - @util.memoized_instancemethod - def lower(self): + def _memoized_method_lower(self): if self.quote: return self else: return util.text_type(self).lower() - @util.memoized_instancemethod - def upper(self): + def _memoized_method_upper(self): if self.quote: return self else: @@ -3472,6 +3477,8 @@ class _truncated_label(quoted_name): """A unicode subclass used to identify symbolic " "names that may require truncation.""" + __slots__ = () + def __new__(cls, value, quote=None): quote = getattr(value, "quote", quote) # return super(_truncated_label, cls).__new__(cls, value, quote, True) @@ -3528,6 +3535,7 @@ class conv(_truncated_label): :ref:`constraint_naming_conventions` """ + __slots__ = () class _defer_name(_truncated_label): @@ -3535,6 +3543,8 @@ class _defer_name(_truncated_label): generation. """ + __slots__ = () + def __new__(cls, value): if value is None: return _NONE_NAME @@ -3549,6 +3559,7 @@ class _defer_name(_truncated_label): class _defer_none_name(_defer_name): """indicate a 'deferred' name that was ultimately the value None.""" + __slots__ = () _NONE_NAME = _defer_none_name("_unnamed_") @@ -3563,6 +3574,8 @@ class _anonymous_label(_truncated_label): """A unicode subclass used to identify anonymously generated names.""" + __slots__ = () + def __add__(self, other): return _anonymous_label( quoted_name( @@ -3729,7 +3742,8 @@ def _literal_as_text(element, warn=False): return _const_expr(element) else: raise exc.ArgumentError( - "SQL expression object or string expected." 
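The quoted_name changes above swap the memoized_instancemethod decorators for the util.MemoizedSlots protocol; observable behavior is unchanged. A minimal sketch of that behavior::

    from sqlalchemy.sql.elements import quoted_name

    # quote=True forces quoting and disables case normalization, so
    # .lower()/.upper() return the name unchanged (and, after this
    # change, the result is memoized via __slots__)
    name = quoted_name('UserTable', quote=True)
    print(name.lower())    # UserTable
    print(quoted_name('UserTable', quote=False).lower())  # usertable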
+ "SQL expression object or string expected, got object of type %r " + "instead" % type(element) ) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 2e10b7370..74b827d7e 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1,5 +1,5 @@ # sql/expression.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -47,7 +47,7 @@ from .base import ColumnCollection, Generative, Executable, \ from .selectable import Alias, Join, Select, Selectable, TableClause, \ CompoundSelect, CTE, FromClause, FromGrouping, SelectBase, \ alias, GenerativeSelect, \ - subquery, HasPrefixes, Exists, ScalarSelect, TextAsFrom + subquery, HasPrefixes, HasSuffixes, Exists, ScalarSelect, TextAsFrom from .dml import Insert, Update, Delete, UpdateBase, ValuesBase @@ -89,9 +89,9 @@ asc = public_factory(UnaryExpression._create_asc, ".expression.asc") desc = public_factory(UnaryExpression._create_desc, ".expression.desc") distinct = public_factory( UnaryExpression._create_distinct, ".expression.distinct") -true = public_factory(True_._singleton, ".expression.true") -false = public_factory(False_._singleton, ".expression.false") -null = public_factory(Null._singleton, ".expression.null") +true = public_factory(True_._instance, ".expression.true") +false = public_factory(False_._instance, ".expression.false") +null = public_factory(Null._instance, ".expression.null") join = public_factory(Join._create_join, ".expression.join") outerjoin = public_factory(Join._create_outerjoin, ".expression.outerjoin") insert = public_factory(Insert, ".expression.insert") diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 9280c7d60..538a2c549 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -1,5 +1,5 @@ # sql/functions.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/naming.py b/lib/sqlalchemy/sql/naming.py index 9e57418b0..bc13835ed 100644 --- a/lib/sqlalchemy/sql/naming.py +++ b/lib/sqlalchemy/sql/naming.py @@ -1,5 +1,5 @@ # sqlalchemy/naming.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -113,10 +113,12 @@ def _constraint_name_for_table(const, table): if isinstance(const.name, conv): return const.name - elif convention is not None and ( - const.name is None or not isinstance(const.name, conv) and - "constraint_name" in convention - ): + elif convention is not None and \ + not isinstance(const.name, conv) and \ + ( + const.name is None or + "constraint_name" in convention or + isinstance(const.name, _defer_name)): return conv( convention % ConventionDict(const, table, metadata.naming_convention) diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 945356328..51f162c98 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -1,5 +1,5 @@ # sql/operators.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # 
This module is part of SQLAlchemy and is released under @@ -38,6 +38,7 @@ class Operators(object): :class:`.ColumnOperators`. """ + __slots__ = () def __and__(self, other): """Implement the ``&`` operator. @@ -137,7 +138,7 @@ class Operators(object): .. versionadded:: 0.8 - added the 'precedence' argument. :param is_comparison: if True, the operator will be considered as a - "comparison" operator, that is which evaulates to a boolean + "comparison" operator, that is which evaluates to a boolean true/false value, like ``==``, ``>``, etc. This flag should be set so that ORM relationships can establish that the operator is a comparison operator when used in a custom join condition. @@ -267,6 +268,8 @@ class ColumnOperators(Operators): """ + __slots__ = () + timetuple = None """Hack, allows datetime objects to be compared on the LHS.""" @@ -529,8 +532,10 @@ class ColumnOperators(Operators): * Postgresql - renders ``x @@ to_tsquery(y)`` * MySQL - renders ``MATCH (x) AGAINST (y IN BOOLEAN MODE)`` * Oracle - renders ``CONTAINS(x, y)`` - * other backends may provide special implementations; - some backends such as SQLite have no support. + * other backends may provide special implementations. + * Backends without any special implementation will emit + the operator as "MATCH". This is compatible with SQlite, for + example. """ return self.operate(match_op, other, **kwargs) @@ -767,6 +772,10 @@ def match_op(a, b, **kw): return a.match(b, **kw) +def notmatch_op(a, b, **kw): + return a.notmatch(b, **kw) + + def comma_op(a, b): raise NotImplementedError() @@ -834,6 +843,7 @@ _PRECEDENCE = { concat_op: 6, match_op: 6, + notmatch_op: 6, ilike_op: 6, notilike_op: 6, diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index ef5d79a48..3aeba9804 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -1,5 +1,5 @@ # sql/schema.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -516,6 +516,19 @@ class Table(DialectKWArgs, SchemaItem, TableClause): """ return sorted(self.constraints, key=lambda c: c._creation_order) + @property + def foreign_key_constraints(self): + """:class:`.ForeignKeyConstraint` objects referred to by this + :class:`.Table`. + + This list is produced from the collection of :class:`.ForeignKey` + objects currently associated. + + .. versionadded:: 1.0.0 + + """ + return set(fkc.constraint for fkc in self.foreign_keys) + def _init_existing(self, *args, **kwargs): autoload_with = kwargs.pop('autoload_with', None) autoload = kwargs.pop('autoload', autoload_with is not None) @@ -728,7 +741,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause): checkfirst=checkfirst) def tometadata(self, metadata, schema=RETAIN_SCHEMA, - referred_schema_fn=None): + referred_schema_fn=None, name=None): """Return a copy of this :class:`.Table` associated with a different :class:`.MetaData`. @@ -785,13 +798,21 @@ class Table(DialectKWArgs, SchemaItem, TableClause): .. versionadded:: 0.9.2 - """ + :param name: optional string name indicating the target table name. + If not specified or None, the table name is retained. This allows + a :class:`.Table` to be copied to the same :class:`.MetaData` target + with a new name. + + .. 
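The operators.py hunk above also rewords ColumnOperators.match(): backends with no specialized implementation now emit the operator as MATCH. A sketch of the generic rendering, using an illustrative table::

    from sqlalchemy import column, table

    documents = table('documents', column('body'))
    expr = documents.c.body.match('sqlalchemy')
    print(expr)   # roughly: documents.body MATCH :body_1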
versionadded:: 1.0.0 + """ + if name is None: + name = self.name if schema is RETAIN_SCHEMA: schema = self.schema elif schema is None: schema = metadata.schema - key = _get_table_key(self.name, schema) + key = _get_table_key(name, schema) if key in metadata.tables: util.warn("Table '%s' already exists within the given " "MetaData - not copying." % self.description) @@ -801,7 +822,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause): for c in self.columns: args.append(c.copy(schema=schema)) table = Table( - self.name, metadata, schema=schema, + name, metadata, schema=schema, *args, **self.kwargs ) for c in self.constraints: @@ -816,7 +837,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause): table.append_constraint( c.copy(schema=fk_constraint_schema, target_table=table)) - else: + elif not c._type_bound: table.append_constraint( c.copy(schema=schema, target_table=table)) for index in self.indexes: @@ -1267,10 +1288,18 @@ class Column(SchemaItem, ColumnClause): "Index object external to the Table.") table.append_constraint(UniqueConstraint(self.key)) - fk_key = (table.key, self.key) - if fk_key in self.table.metadata._fk_memos: - for fk in self.table.metadata._fk_memos[fk_key]: - fk._set_remote_table(table) + self._setup_on_memoized_fks(lambda fk: fk._set_remote_table(table)) + + def _setup_on_memoized_fks(self, fn): + fk_keys = [ + ((self.table.key, self.key), False), + ((self.table.key, self.name), True), + ] + for fk_key, link_to_name in fk_keys: + if fk_key in self.table.metadata._fk_memos: + for fk in self.table.metadata._fk_memos[fk_key]: + if fk.link_to_name is link_to_name: + fn(fk) def _on_table_attach(self, fn): if self.table is not None: @@ -1287,7 +1316,7 @@ class Column(SchemaItem, ColumnClause): # Constraint objects plus non-constraint-bound ForeignKey objects args = \ - [c.copy(**kw) for c in self.constraints] + \ + [c.copy(**kw) for c in self.constraints if not c._type_bound] + \ [c.copy(**kw) for c in self.foreign_keys if not c.constraint] type_ = self.type @@ -1455,7 +1484,14 @@ class ForeignKey(DialectKWArgs, SchemaItem): :param use_alter: passed to the underlying :class:`.ForeignKeyConstraint` to indicate the constraint should be generated/dropped externally from the CREATE TABLE/ DROP TABLE - statement. See that classes' constructor for details. + statement. See :paramref:`.ForeignKeyConstraint.use_alter` + for further description. + + .. seealso:: + + :paramref:`.ForeignKeyConstraint.use_alter` + + :ref:`use_alter` :param match: Optional string. If set, emit MATCH <value> when issuing DDL for this constraint. Typical values include SIMPLE, PARTIAL @@ -1549,7 +1585,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): ) return self._schema_item_copy(fk) - def _get_colspec(self, schema=None): + def _get_colspec(self, schema=None, table_name=None): """Return a string based 'column specification' for this :class:`.ForeignKey`. 
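The new Table.tometadata(name=...) parameter documented above permits a copy into the *same* MetaData under a different table name; a minimal sketch::

    from sqlalchemy import Column, Integer, MetaData, Table

    metadata = MetaData()
    users = Table('users', metadata,
                  Column('id', Integer, primary_key=True))

    # new in 1.0.0: a same-MetaData copy works once a new name is given
    users_archive = users.tometadata(metadata, name='users_archive')
    assert users_archive.name == 'users_archive'
    assert 'users_archive' in metadata.tables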
@@ -1559,7 +1595,15 @@ class ForeignKey(DialectKWArgs, SchemaItem): """ if schema: _schema, tname, colname = self._column_tokens + if table_name is not None: + tname = table_name return "%s.%s.%s" % (schema, tname, colname) + elif table_name: + schema, tname, colname = self._column_tokens + if schema: + return "%s.%s.%s" % (schema, table_name, colname) + else: + return "%s.%s" % (table_name, colname) elif self._table_column is not None: return "%s.%s" % ( self._table_column.table.fullname, self._table_column.key) @@ -1704,11 +1748,11 @@ class ForeignKey(DialectKWArgs, SchemaItem): # super-edgy case, if other FKs point to our column, # they'd get the type propagated out also. if isinstance(self.parent.table, Table): - fk_key = (self.parent.table.key, self.parent.key) - if fk_key in self.parent.table.metadata._fk_memos: - for fk in self.parent.table.metadata._fk_memos[fk_key]: - if fk.parent.type._isnull: - fk.parent.type = column.type + + def set_type(fk): + if fk.parent.type._isnull: + fk.parent.type = column.type + self.parent._setup_on_memoized_fks(set_type) self.column = column @@ -1788,7 +1832,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): match=self.match, **self._unvalidated_dialect_kw ) - self.constraint._elements[self.parent] = self + self.constraint._append_element(column, self) self.constraint._set_parent_with_dispatch(table) table.foreign_keys.add(self) @@ -2238,7 +2282,7 @@ class Constraint(DialectKWArgs, SchemaItem): __visit_name__ = 'constraint' def __init__(self, name=None, deferrable=None, initially=None, - _create_rule=None, info=None, + _create_rule=None, info=None, _type_bound=False, **dialect_kw): """Create a SQL constraint. @@ -2288,6 +2332,7 @@ class Constraint(DialectKWArgs, SchemaItem): if info: self.info = info self._create_rule = _create_rule + self._type_bound = _type_bound util.set_creation_order(self) self._validate_dialect_kwargs(dialect_kw) @@ -2328,14 +2373,61 @@ def _to_schema_column_or_string(element): class ColumnCollectionMixin(object): - def __init__(self, *columns): + + columns = None + """A :class:`.ColumnCollection` of :class:`.Column` objects. + + This collection represents the columns which are referred to by + this object. 
+ + """ + + _allow_multiple_tables = False + + def __init__(self, *columns, **kw): + _autoattach = kw.pop('_autoattach', True) self.columns = ColumnCollection() self._pending_colargs = [_to_schema_column_or_string(c) for c in columns] - if self._pending_colargs and \ - isinstance(self._pending_colargs[0], Column) and \ - isinstance(self._pending_colargs[0].table, Table): - self._set_parent_with_dispatch(self._pending_colargs[0].table) + if _autoattach and self._pending_colargs: + self._check_attach() + + def _check_attach(self, evt=False): + col_objs = [ + c for c in self._pending_colargs + if isinstance(c, Column) + ] + cols_w_table = [ + c for c in col_objs if isinstance(c.table, Table) + ] + cols_wo_table = set(col_objs).difference(cols_w_table) + + if cols_wo_table: + assert not evt, "Should not reach here on event call" + + def _col_attached(column, table): + cols_wo_table.discard(column) + if not cols_wo_table: + self._check_attach(evt=True) + self._cols_wo_table = cols_wo_table + for col in cols_wo_table: + col._on_table_attach(_col_attached) + return + + columns = cols_w_table + + tables = set([c.table for c in columns]) + if len(tables) == 1: + self._set_parent_with_dispatch(tables.pop()) + elif len(tables) > 1 and not self._allow_multiple_tables: + table = columns[0].table + others = [c for c in columns[1:] if c.table is not table] + if others: + raise exc.ArgumentError( + "Column(s) %s are not part of table '%s'." % + (", ".join("'%s'" % c for c in others), + table.description) + ) def _set_parent(self, table): for col in self._pending_colargs: @@ -2367,8 +2459,9 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): arguments are propagated to the :class:`.Constraint` superclass. """ + _autoattach = kw.pop('_autoattach', True) Constraint.__init__(self, **kw) - ColumnCollectionMixin.__init__(self, *columns) + ColumnCollectionMixin.__init__(self, *columns, _autoattach=_autoattach) def _set_parent(self, table): Constraint._set_parent(self, table) @@ -2383,6 +2476,13 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): return self._schema_item_copy(c) def contains_column(self, col): + """Return True if this constraint contains the given column. + + Note that this object also contains an attribute ``.columns`` + which is a :class:`.ColumnCollection` of :class:`.Column` objects. + + """ + return self.columns.contains_column(col) def __iter__(self): @@ -2396,15 +2496,17 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): return len(self.columns._data) -class CheckConstraint(Constraint): +class CheckConstraint(ColumnCollectionConstraint): """A table- or column-level CHECK constraint. Can be included in the definition of a Table or Column. """ + _allow_multiple_tables = True + def __init__(self, sqltext, name=None, deferrable=None, initially=None, table=None, info=None, _create_rule=None, - _autoattach=True): + _autoattach=True, _type_bound=False): """Construct a CHECK constraint. 
:param sqltext: @@ -2433,18 +2535,19 @@ class CheckConstraint(Constraint): """ - super(CheckConstraint, self).\ - __init__(name, deferrable, initially, _create_rule, info=info) self.sqltext = _literal_as_text(sqltext, warn=False) + + columns = [] + visitors.traverse(self.sqltext, {}, {'column': columns.append}) + + super(CheckConstraint, self).\ + __init__( + name=name, deferrable=deferrable, + initially=initially, _create_rule=_create_rule, info=info, + _type_bound=_type_bound, _autoattach=_autoattach, + *columns) if table is not None: self._set_parent_with_dispatch(table) - elif _autoattach: - cols = _find_columns(self.sqltext) - tables = set([c.table for c in cols - if isinstance(c.table, Table)]) - if len(tables) == 1: - self._set_parent_with_dispatch( - tables.pop()) def __visit_name__(self): if isinstance(self.parent, Table): @@ -2469,11 +2572,12 @@ class CheckConstraint(Constraint): deferrable=self.deferrable, _create_rule=self._create_rule, table=target_table, - _autoattach=False) + _autoattach=False, + _type_bound=self._type_bound) return self._schema_item_copy(c) -class ForeignKeyConstraint(Constraint): +class ForeignKeyConstraint(ColumnCollectionConstraint): """A table-level FOREIGN KEY constraint. Defines a single column or composite FOREIGN KEY ... REFERENCES @@ -2525,11 +2629,23 @@ class ForeignKeyConstraint(Constraint): part of the CREATE TABLE definition. Instead, generate it via an ALTER TABLE statement issued after the full collection of tables have been created, and drop it via an ALTER TABLE statement before - the full collection of tables are dropped. This is shorthand for the - usage of :class:`.AddConstraint` and :class:`.DropConstraint` - applied as "after-create" and "before-drop" events on the MetaData - object. This is normally used to generate/drop constraints on - objects that are mutually dependent on each other. + the full collection of tables are dropped. + + The use of :paramref:`.ForeignKeyConstraint.use_alter` is + particularly geared towards the case where two or more tables + are established within a mutually-dependent foreign key constraint + relationship; however, the :meth:`.MetaData.create_all` and + :meth:`.MetaData.drop_all` methods will perform this resolution + automatically, so the flag is normally not needed. + + .. versionchanged:: 1.0.0 Automatic resolution of foreign key + cycles has been added, removing the need to use the + :paramref:`.ForeignKeyConstraint.use_alter` in typical use + cases. + + .. seealso:: + + :ref:`use_alter` :param match: Optional string. If set, emit MATCH <value> when issuing DDL for this constraint. Typical values include SIMPLE, PARTIAL @@ -2548,25 +2664,22 @@ class ForeignKeyConstraint(Constraint): .. 
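Per the revised use_alter documentation above, MetaData.create_all() and drop_all() now resolve mutually-dependent foreign key cycles automatically. A sketch of such a cycle, with no use_alter flag::

    from sqlalchemy import Column, ForeignKey, Integer, MetaData, Table

    metadata = MetaData()
    a = Table('a', metadata,
              Column('id', Integer, primary_key=True),
              Column('b_id', Integer, ForeignKey('b.id')))
    b = Table('b', metadata,
              Column('id', Integer, primary_key=True),
              Column('a_id', Integer, ForeignKey('a.id')))

    # metadata.create_all(engine) will emit the constraints participating
    # in the cycle via ALTER on backends that support it; previously this
    # required use_alter=True plus an explicit constraint name.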
versionadded:: 0.9.2 """ - super(ForeignKeyConstraint, self).\ - __init__(name, deferrable, initially, info=info, **dialect_kw) + Constraint.__init__( + self, name=name, deferrable=deferrable, initially=initially, + info=info, **dialect_kw) self.onupdate = onupdate self.ondelete = ondelete self.link_to_name = link_to_name - if self.name is None and use_alter: - raise exc.ArgumentError("Alterable Constraint requires a name") self.use_alter = use_alter self.match = match - self._elements = util.OrderedDict() - # standalone ForeignKeyConstraint - create # associated ForeignKey objects which will be applied to hosted # Column objects (in col.foreign_keys), either now or when attached # to the Table for string-specified names - for col, refcol in zip(columns, refcolumns): - self._elements[col] = ForeignKey( + self.elements = [ + ForeignKey( refcol, _constraint=self, name=self.name, @@ -2578,25 +2691,50 @@ class ForeignKeyConstraint(Constraint): deferrable=self.deferrable, initially=self.initially, **self.dialect_kwargs - ) + ) for refcol in refcolumns + ] + ColumnCollectionMixin.__init__(self, *columns) if table is not None: + if hasattr(self, "parent"): + assert table is self.parent self._set_parent_with_dispatch(table) - elif columns and \ - isinstance(columns[0], Column) and \ - columns[0].table is not None: - self._set_parent_with_dispatch(columns[0].table) + + def _append_element(self, column, fk): + self.columns.add(column) + self.elements.append(fk) + + @property + def _elements(self): + # legacy - provide a dictionary view of (column_key, fk) + return util.OrderedDict( + zip(self.column_keys, self.elements) + ) @property def _referred_schema(self): - for elem in self._elements.values(): + for elem in self.elements: return elem._referred_schema else: return None + @property + def referred_table(self): + """The :class:`.Table` object to which this + :class:`.ForeignKeyConstraint` references. + + This is a dynamically calculated attribute which may not be available + if the constraint and/or parent table is not yet associated with + a metadata collection that contains the referred table. + + .. versionadded:: 1.0.0 + + """ + return self.elements[0].column.table + def _validate_dest_table(self, table): table_keys = set([elem._table_key() - for elem in self._elements.values()]) + for elem in self.elements]) if None not in table_keys and len(table_keys) > 1: elem0, elem1 = sorted(table_keys)[0:2] raise exc.ArgumentError( @@ -2609,53 +2747,58 @@ class ForeignKeyConstraint(Constraint): )) @property - def _col_description(self): - return ", ".join(self._elements) + def column_keys(self): + """Return a list of string keys representing the local + columns in this :class:`.ForeignKeyConstraint`. - @property - def columns(self): - return list(self._elements) + This list is either the original string arguments sent + to the constructor of the :class:`.ForeignKeyConstraint`, + or if the constraint has been initialized with :class:`.Column` + objects, is the string .key of each element. + + .. 
versionadded:: 1.0.0 + + """ + if hasattr(self, "parent"): + return self.columns.keys() + else: + return [ + col.key if isinstance(col, ColumnElement) + else str(col) for col in self._pending_colargs + ] @property - def elements(self): - return list(self._elements.values()) + def _col_description(self): + return ", ".join(self.column_keys) def _set_parent(self, table): - super(ForeignKeyConstraint, self)._set_parent(table) - - self._validate_dest_table(table) + Constraint._set_parent(self, table) - for col, fk in self._elements.items(): - # string-specified column names now get - # resolved to Column objects - if isinstance(col, util.string_types): - try: - col = table.c[col] - except KeyError: - raise exc.ArgumentError( - "Can't create ForeignKeyConstraint " - "on table '%s': no column " - "named '%s' is present." % (table.description, col)) + try: + ColumnCollectionConstraint._set_parent(self, table) + except KeyError as ke: + raise exc.ArgumentError( + "Can't create ForeignKeyConstraint " + "on table '%s': no column " + "named '%s' is present." % (table.description, ke.args[0])) + for col, fk in zip(self.columns, self.elements): if not hasattr(fk, 'parent') or \ fk.parent is not col: fk._set_parent_with_dispatch(col) - if self.use_alter: - def supports_alter(ddl, event, schema_item, bind, **kw): - return table in set(kw['tables']) and \ - bind.dialect.supports_alter - - event.listen(table.metadata, "after_create", - ddl.AddConstraint(self, on=supports_alter)) - event.listen(table.metadata, "before_drop", - ddl.DropConstraint(self, on=supports_alter)) + self._validate_dest_table(table) - def copy(self, schema=None, **kw): + def copy(self, schema=None, target_table=None, **kw): fkc = ForeignKeyConstraint( - [x.parent.key for x in self._elements.values()], - [x._get_colspec(schema=schema) - for x in self._elements.values()], + [x.parent.key for x in self.elements], + [x._get_colspec( + schema=schema, + table_name=target_table.name + if target_table is not None + and x._table_key() == x.parent.table.key + else None) + for x in self.elements], name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, @@ -2666,8 +2809,8 @@ class ForeignKeyConstraint(Constraint): match=self.match ) for self_fk, other_fk in zip( - self._elements.values(), - fkc._elements.values()): + self.elements, + fkc.elements): self_fk._schema_item_copy(other_fk) return self._schema_item_copy(fkc) @@ -2968,12 +3111,6 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): ) ) self.table = table - for c in self.columns: - if c.table != self.table: - raise exc.ArgumentError( - "Column '%s' is not part of table '%s'." % - (c, self.table.description) - ) table.indexes.add(self) self.expressions = [ @@ -3288,12 +3425,30 @@ class MetaData(SchemaItem): order in which they can be created. To get the order in which the tables would be dropped, use the ``reversed()`` Python built-in. + .. warning:: + + The :attr:`.sorted_tables` accessor cannot by itself accommodate + automatic resolution of dependency cycles between tables, which + are usually caused by mutually dependent foreign key constraints. + To resolve these cycles, either the + :paramref:`.ForeignKeyConstraint.use_alter` parameter may be appled + to those constraints, or use the + :func:`.schema.sort_tables_and_constraints` function which will break + out foreign key constraints involved in cycles separately. + .. 
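The new ForeignKeyConstraint.referred_table and column_keys accessors above, together with the Table.foreign_key_constraints accessor from the earlier schema.py hunk, can be sketched as follows::

    from sqlalchemy import Column, ForeignKey, Integer, MetaData, Table

    metadata = MetaData()
    parent = Table('parent', metadata,
                   Column('id', Integer, primary_key=True))
    child = Table('child', metadata,
                  Column('id', Integer, primary_key=True),
                  Column('parent_id', Integer, ForeignKey('parent.id')))

    fkc, = child.foreign_key_constraints   # Table accessor, new in 1.0.0
    assert fkc.referred_table is parent    # new in 1.0.0
    assert fkc.column_keys == ['parent_id']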
seealso:: + :func:`.schema.sort_tables` + + :func:`.schema.sort_tables_and_constraints` + :attr:`.MetaData.tables` :meth:`.Inspector.get_table_names` + :meth:`.Inspector.get_sorted_table_and_fkc_names` + + """ return ddl.sort_tables(self.tables.values()) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 8198a6733..f848ef6db 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1,5 +1,5 @@ # sql/selectable.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -43,9 +43,10 @@ def _interpret_as_from(element): {"expr": util.ellipses_string(element)}) return TextClause(util.text_type(element)) - elif hasattr(insp, "selectable"): + try: return insp.selectable - raise exc.ArgumentError("FROM expression expected") + except AttributeError: + raise exc.ArgumentError("FROM expression expected") def _interpret_as_select(element): @@ -171,6 +172,79 @@ class Selectable(ClauseElement): return self +class HasPrefixes(object): + _prefixes = () + + @_generative + def prefix_with(self, *expr, **kw): + """Add one or more expressions following the statement keyword, i.e. + SELECT, INSERT, UPDATE, or DELETE. Generative. + + This is used to support backend-specific prefix keywords such as those + provided by MySQL. + + E.g.:: + + stmt = table.insert().prefix_with("LOW_PRIORITY", dialect="mysql") + + Multiple prefixes can be specified by multiple calls + to :meth:`.prefix_with`. + + :param \*expr: textual or :class:`.ClauseElement` construct which + will be rendered following the INSERT, UPDATE, or DELETE + keyword. + :param \**kw: A single keyword 'dialect' is accepted. This is an + optional string dialect name which will + limit rendering of this prefix to only that dialect. + + """ + dialect = kw.pop('dialect', None) + if kw: + raise exc.ArgumentError("Unsupported argument(s): %s" % + ",".join(kw)) + self._setup_prefixes(expr, dialect) + + def _setup_prefixes(self, prefixes, dialect=None): + self._prefixes = self._prefixes + tuple( + [(_literal_as_text(p, warn=False), dialect) for p in prefixes]) + + +class HasSuffixes(object): + _suffixes = () + + @_generative + def suffix_with(self, *expr, **kw): + """Add one or more expressions following the statement as a whole. + + This is used to support backend-specific suffix keywords on + certain constructs. + + E.g.:: + + stmt = select([col1, col2]).cte().suffix_with( + "cycle empno set y_cycle to 1 default 0", dialect="oracle") + + Multiple prefixes can be specified by multiple calls + to :meth:`.suffix_with`. + + :param \*expr: textual or :class:`.ClauseElement` construct which + will be rendered following the target clause. + :param \**kw: A single keyword 'dialect' is accepted. This is an + optional string dialect name which will + limit rendering of this suffix to only that dialect. + + """ + dialect = kw.pop('dialect', None) + if kw: + raise exc.ArgumentError("Unsupported argument(s): %s" % + ",".join(kw)) + self._setup_suffixes(expr, dialect) + + def _setup_suffixes(self, suffixes, dialect=None): + self._suffixes = self._suffixes + tuple( + [(_literal_as_text(p, warn=False), dialect) for p in suffixes]) + + class FromClause(Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement. @@ -874,7 +948,7 @@ class Join(FromClause): """return an alias of this :class:`.Join`. 
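The sorted_tables warning above points to schema.sort_tables_and_constraints for cycle-aware ordering. A hedged sketch, reusing the mutually-dependent pair of tables from the earlier example; the function yields (Table, [ForeignKeyConstraint]) pairs, with a final (None, [...]) entry for constraints that must be handled separately::

    from sqlalchemy import Column, ForeignKey, Integer, MetaData, Table
    from sqlalchemy.schema import sort_tables_and_constraints

    metadata = MetaData()
    a = Table('a', metadata,
              Column('id', Integer, primary_key=True),
              Column('b_id', Integer, ForeignKey('b.id')))
    b = Table('b', metadata,
              Column('id', Integer, primary_key=True),
              Column('a_id', Integer, ForeignKey('a.id')))

    for table, fkcs in sort_tables_and_constraints(metadata.tables.values()):
        if table is None:
            print('constraints requiring ALTER:', fkcs)
        else:
            print('CREATE %s with %d inline constraint(s)'
                  % (table.name, len(fkcs)))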
The default behavior here is to first produce a SELECT - construct from this :class:`.Join`, then to produce a + construct from this :class:`.Join`, then to produce an :class:`.Alias` from that. So given a join of the form:: j = table_a.join(table_b, table_a.c.id == table_b.c.a_id) @@ -1088,7 +1162,7 @@ class Alias(FromClause): return self.element.bind -class CTE(Alias): +class CTE(Generative, HasSuffixes, Alias): """Represent a Common Table Expression. The :class:`.CTE` object is obtained using the @@ -1104,10 +1178,13 @@ class CTE(Alias): name=None, recursive=False, _cte_alias=None, - _restates=frozenset()): + _restates=frozenset(), + _suffixes=None): self.recursive = recursive self._cte_alias = _cte_alias self._restates = _restates + if _suffixes: + self._suffixes = _suffixes super(CTE, self).__init__(selectable, name=name) def alias(self, name=None, flat=False): @@ -1116,6 +1193,7 @@ class CTE(Alias): name=name, recursive=self.recursive, _cte_alias=self, + _suffixes=self._suffixes ) def union(self, other): @@ -1123,7 +1201,8 @@ class CTE(Alias): self.original.union(other), name=self.name, recursive=self.recursive, - _restates=self._restates.union([self]) + _restates=self._restates.union([self]), + _suffixes=self._suffixes ) def union_all(self, other): @@ -1131,7 +1210,8 @@ class CTE(Alias): self.original.union_all(other), name=self.name, recursive=self.recursive, - _restates=self._restates.union([self]) + _restates=self._restates.union([self]), + _suffixes=self._suffixes ) @@ -2118,44 +2198,7 @@ class CompoundSelect(GenerativeSelect): bind = property(bind, _set_bind) -class HasPrefixes(object): - _prefixes = () - - @_generative - def prefix_with(self, *expr, **kw): - """Add one or more expressions following the statement keyword, i.e. - SELECT, INSERT, UPDATE, or DELETE. Generative. - - This is used to support backend-specific prefix keywords such as those - provided by MySQL. - - E.g.:: - - stmt = table.insert().prefix_with("LOW_PRIORITY", dialect="mysql") - - Multiple prefixes can be specified by multiple calls - to :meth:`.prefix_with`. - - :param \*expr: textual or :class:`.ClauseElement` construct which - will be rendered following the INSERT, UPDATE, or DELETE - keyword. - :param \**kw: A single keyword 'dialect' is accepted. This is an - optional string dialect name which will - limit rendering of this prefix to only that dialect. - - """ - dialect = kw.pop('dialect', None) - if kw: - raise exc.ArgumentError("Unsupported argument(s): %s" % - ",".join(kw)) - self._setup_prefixes(expr, dialect) - - def _setup_prefixes(self, prefixes, dialect=None): - self._prefixes = self._prefixes + tuple( - [(_literal_as_text(p, warn=False), dialect) for p in prefixes]) - - -class Select(HasPrefixes, GenerativeSelect): +class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """Represents a ``SELECT`` statement. """ @@ -2163,6 +2206,7 @@ class Select(HasPrefixes, GenerativeSelect): __visit_name__ = 'select' _prefixes = () + _suffixes = () _hints = util.immutabledict() _statement_hints = () _distinct = False @@ -2180,6 +2224,7 @@ class Select(HasPrefixes, GenerativeSelect): having=None, correlate=True, prefixes=None, + suffixes=None, **kwargs): """Construct a new :class:`.Select`. 
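The CTE and Select classes now mix in HasSuffixes, so the suffix_with() example from the docstring above composes as follows; the Oracle CYCLE clause is purely illustrative::

    from sqlalchemy import column, select, table

    emp = table('emp', column('empno'), column('mgr'))
    emps = select([emp.c.empno, emp.c.mgr]).cte('emps').suffix_with(
        'cycle empno set y_cycle to 1 default 0', dialect='oracle')
    stmt = select([emps.c.empno])
    # the suffix renders only when compiled against the Oracle dialect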
@@ -2425,6 +2470,9 @@ class Select(HasPrefixes, GenerativeSelect): if prefixes: self._setup_prefixes(prefixes) + if suffixes: + self._setup_suffixes(suffixes) + GenerativeSelect.__init__(self, **kwargs) @property @@ -2437,21 +2485,20 @@ class Select(HasPrefixes, GenerativeSelect): seen = set() translate = self._from_cloned - def add(items): - for item in items: - if item is self: - raise exc.InvalidRequestError( - "select() construct refers to itself as a FROM") - if translate and item in translate: - item = translate[item] - if not seen.intersection(item._cloned_set): - froms.append(item) - seen.update(item._cloned_set) - - add(_from_objects(*self._raw_columns)) - if self._whereclause is not None: - add(_from_objects(self._whereclause)) - add(self._from_obj) + for item in itertools.chain( + _from_objects(*self._raw_columns), + _from_objects(self._whereclause) + if self._whereclause is not None else (), + self._from_obj + ): + if item is self: + raise exc.InvalidRequestError( + "select() construct refers to itself as a FROM") + if translate and item in translate: + item = translate[item] + if not seen.intersection(item._cloned_set): + froms.append(item) + seen.update(item._cloned_set) return froms @@ -2633,7 +2680,8 @@ class Select(HasPrefixes, GenerativeSelect): only_froms = dict( (c.key, c) for c in _select_iterables(self.froms) if c._allow_label_resolve) - with_cols.update(only_froms) + for key, value in only_froms.items(): + with_cols.setdefault(key, value) return with_cols, only_froms diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 2729bc83e..7e2e601e2 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -1,5 +1,5 @@ # sql/sqltypes.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -14,7 +14,6 @@ import codecs from .type_api import TypeEngine, TypeDecorator, to_instance from .elements import quoted_name, type_coerce, _defer_name -from .default_comparator import _DefaultColumnComparator from .. import exc, util, processors from .base import _bind_or_error, SchemaEventTarget from . import operators @@ -894,7 +893,7 @@ class LargeBinary(_Binary): :param length: optional, a length for the column for use in DDL statements, for those BLOB types that accept a length - (i.e. MySQL). It does *not* produce a small BINARY/VARBINARY + (i.e. MySQL). It does *not* produce a *lengthed* BINARY/VARBINARY type - use the BINARY/VARBINARY types specifically for those. May be safely omitted if no ``CREATE TABLE`` will be issued. 
Certain databases may require a @@ -939,7 +938,7 @@ class SchemaType(SchemaEventTarget): """ def __init__(self, name=None, schema=None, metadata=None, - inherit_schema=False, quote=None): + inherit_schema=False, quote=None, _create_events=True): if name is not None: self.name = quoted_name(name, quote) else: @@ -947,8 +946,9 @@ class SchemaType(SchemaEventTarget): self.schema = schema self.metadata = metadata self.inherit_schema = inherit_schema + self._create_events = _create_events - if self.metadata: + if _create_events and self.metadata: event.listen( self.metadata, "before_create", @@ -967,6 +967,9 @@ class SchemaType(SchemaEventTarget): if self.inherit_schema: self.schema = table.schema + if not self._create_events: + return + event.listen( table, "before_create", @@ -993,19 +996,18 @@ class SchemaType(SchemaEventTarget): ) def copy(self, **kw): - return self.adapt(self.__class__) + return self.adapt(self.__class__, _create_events=True) def adapt(self, impltype, **kw): schema = kw.pop('schema', self.schema) + metadata = kw.pop('metadata', self.metadata) + _create_events = kw.pop('_create_events', False) - # don't associate with MetaData as the hosting type - # is already associated with it, avoid creating event - # listeners - metadata = kw.pop('metadata', None) return impltype(name=self.name, schema=schema, - metadata=metadata, inherit_schema=self.inherit_schema, + metadata=metadata, + _create_events=_create_events, **kw) @property @@ -1149,6 +1151,7 @@ class Enum(String, SchemaType): def __repr__(self): return util.generic_repr(self, + additional_kw=[('native_enum', True)], to_inspect=[Enum, SchemaType], ) @@ -1165,13 +1168,15 @@ class Enum(String, SchemaType): type_coerce(column, self).in_(self.enums), name=_defer_name(self.name), _create_rule=util.portable_instancemethod( - self._should_create_constraint) + self._should_create_constraint), + _type_bound=True ) assert e.table is table def adapt(self, impltype, **kw): schema = kw.pop('schema', self.schema) - metadata = kw.pop('metadata', None) + metadata = kw.pop('metadata', self.metadata) + _create_events = kw.pop('_create_events', False) if issubclass(impltype, Enum): return impltype(name=self.name, schema=schema, @@ -1179,9 +1184,11 @@ class Enum(String, SchemaType): convert_unicode=self.convert_unicode, native_enum=self.native_enum, inherit_schema=self.inherit_schema, + _create_events=_create_events, *self.enums, **kw) else: + # TODO: why would we be here? return super(Enum, self).adapt(impltype, **kw) @@ -1277,7 +1284,8 @@ class Boolean(TypeEngine, SchemaType): __visit_name__ = 'boolean' - def __init__(self, create_constraint=True, name=None): + def __init__( + self, create_constraint=True, name=None, _create_events=True): """Construct a Boolean. :param create_constraint: defaults to True. If the boolean @@ -1290,6 +1298,7 @@ class Boolean(TypeEngine, SchemaType): """ self.create_constraint = create_constraint self.name = name + self._create_events = _create_events def _should_create_constraint(self, compiler): return not compiler.dialect.supports_native_boolean @@ -1303,7 +1312,8 @@ class Boolean(TypeEngine, SchemaType): type_coerce(column, self).in_([0, 1]), name=_defer_name(self.name), _create_rule=util.portable_instancemethod( - self._should_create_constraint) + self._should_create_constraint), + _type_bound=True ) assert e.table is table @@ -1654,10 +1664,26 @@ class NullType(TypeEngine): comparator_factory = Comparator +class MatchType(Boolean): + """Refers to the return type of the MATCH operator. 
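The _type_bound flag added above marks constraints generated by types such as Boolean and Enum, so that Column.copy() and Table.tometadata() skip them and the copied type regenerates its own constraint. A sketch, assuming a non-native Enum::

    from sqlalchemy import CheckConstraint, Column, Enum, MetaData, Table

    m1 = MetaData()
    t1 = Table('t', m1, Column('flag', Enum('a', 'b', name='flag_enum',
                                            native_enum=False)))
    m2 = MetaData()
    t2 = t1.tometadata(m2)

    # exactly one CHECK constraint on the copy, generated by the copied
    # Enum type itself; previously the type-generated constraint could
    # end up duplicated
    checks = [c for c in t2.constraints if isinstance(c, CheckConstraint)]
    assert len(checks) == 1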
+ + As the :meth:`.ColumnOperators.match` is probably the most open-ended + operator in generic SQLAlchemy Core, we can't assume the return type + at SQL evaluation time, as MySQL returns a floating point, not a boolean, + and other backends might do something different. So this type + acts as a placeholder, currently subclassing :class:`.Boolean`. + The type allows dialects to inject result-processing functionality + if needed, and on MySQL will return floating-point values. + + .. versionadded:: 1.0.0 + + """ + NULLTYPE = NullType() BOOLEANTYPE = Boolean() STRINGTYPE = String() INTEGERTYPE = Integer() +MATCHTYPE = MatchType() _type_map = { int: Integer(), @@ -1685,21 +1711,7 @@ type_api.BOOLEANTYPE = BOOLEANTYPE type_api.STRINGTYPE = STRINGTYPE type_api.INTEGERTYPE = INTEGERTYPE type_api.NULLTYPE = NULLTYPE +type_api.MATCHTYPE = MATCHTYPE type_api._type_map = _type_map -# this one, there's all kinds of ways to play it, but at the EOD -# there's just a giant dependency cycle between the typing system and -# the expression element system, as you might expect. We can use -# importlaters or whatnot, but the typing system just necessarily has -# to have some kind of connection like this. right now we're injecting the -# _DefaultColumnComparator implementation into the TypeEngine.Comparator -# interface. Alternatively TypeEngine.Comparator could have an "impl" -# injected, though just injecting the base is simpler, error free, and more -# performant. - - -class Comparator(_DefaultColumnComparator): - BOOLEANTYPE = BOOLEANTYPE - -TypeEngine.Comparator.__bases__ = ( - Comparator, ) + TypeEngine.Comparator.__bases__ +TypeEngine.Comparator.BOOLEANTYPE = BOOLEANTYPE diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 77c6e1b1e..4660850bd 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -1,5 +1,5 @@ # sql/types_api.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -12,13 +12,14 @@ from .. import exc, util from . import operators -from .visitors import Visitable +from .visitors import Visitable, VisitableType # these are back-assigned by sqltypes. BOOLEANTYPE = None INTEGERTYPE = None NULLTYPE = None STRINGTYPE = None +MATCHTYPE = None class TypeEngine(Visitable): @@ -45,9 +46,51 @@ class TypeEngine(Visitable): """ + __slots__ = 'expr', 'type' + + default_comparator = None def __init__(self, expr): self.expr = expr + self.type = expr.type + + @util.dependencies('sqlalchemy.sql.default_comparator') + def operate(self, default_comparator, op, *other, **kwargs): + o = default_comparator.operator_lookup[op.__name__] + return o[0](self.expr, op, *(other + o[1:]), **kwargs) + + @util.dependencies('sqlalchemy.sql.default_comparator') + def reverse_operate(self, default_comparator, op, other, **kwargs): + o = default_comparator.operator_lookup[op.__name__] + return o[0](self.expr, op, other, + reverse=True, *o[1:], **kwargs) + + def _adapt_expression(self, op, other_comparator): + """evaluate the return type of <self> <op> <othertype>, + and apply any adaptations to the given operator. + + This method determines the type of a resulting binary expression + given two source types and an operator. 
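With MATCHTYPE wired into type_api above, a match() expression is now typed as MatchType rather than a plain Boolean; a short sketch::

    from sqlalchemy import column
    from sqlalchemy.sql.sqltypes import MatchType

    expr = column('body').match('keyword')
    assert isinstance(expr.type, MatchType)
    # per the docstring above, the MySQL dialect returns floating-point
    # values for this type rather than a boolean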
For example, two + :class:`.Column` objects, both of the type :class:`.Integer`, will + produce a :class:`.BinaryExpression` that also has the type + :class:`.Integer` when compared via the addition (``+``) operator. + However, using the addition operator with an :class:`.Integer` + and a :class:`.Date` object will produce a :class:`.Date`, assuming + "days delta" behavior by the database (in reality, most databases + other than Postgresql don't accept this particular operation). + + The method returns a tuple of the form <operator>, <type>. + The resulting operator and type will be those applied to the + resulting :class:`.BinaryExpression` as the final operator and the + right-hand side of the expression. + + Note that only a subset of operators make usage of + :meth:`._adapt_expression`, + including math operators and user-defined operators, but not + boolean comparison or special SQL keywords like MATCH or BETWEEN. + + """ + return op, other_comparator.type def __reduce__(self): return _reconstitute_comparator, (self.expr, ) @@ -252,7 +295,7 @@ class TypeEngine(Visitable): The construction of :meth:`.TypeEngine.with_variant` is always from the "fallback" type to that which is dialect specific. The returned type is an instance of :class:`.Variant`, which - itself provides a :meth:`~sqlalchemy.types.Variant.with_variant` + itself provides a :meth:`.Variant.with_variant` that can be called repeatedly. :param type_: a :class:`.TypeEngine` that will be selected @@ -417,7 +460,11 @@ class TypeEngine(Visitable): return util.generic_repr(self) -class UserDefinedType(TypeEngine): +class VisitableCheckKWArg(util.EnsureKWArgType, VisitableType): + pass + + +class UserDefinedType(util.with_metaclass(VisitableCheckKWArg, TypeEngine)): """Base for user defined types. This should be the base of new types. Note that @@ -430,7 +477,7 @@ class UserDefinedType(TypeEngine): def __init__(self, precision = 8): self.precision = precision - def get_col_spec(self): + def get_col_spec(self, **kw): return "MYTYPE(%s)" % self.precision def bind_processor(self, dialect): @@ -450,10 +497,26 @@ class UserDefinedType(TypeEngine): Column('data', MyType(16)) ) + The ``get_col_spec()`` method will in most cases receive a keyword + argument ``type_expression`` which refers to the owning expression + of the type as being compiled, such as a :class:`.Column` or + :func:`.cast` construct. This keyword is only sent if the method + accepts keyword arguments (e.g. ``**kw``) in its argument signature; + introspection is used to check for this in order to support legacy + forms of this function. + + .. versionadded:: 1.0.0 the owning expression is passed to + the ``get_col_spec()`` method via the keyword argument + ``type_expression``, if it receives ``**kw`` in its signature. 
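The _adapt_expression() contract documented above determines the result type of a binary expression; a two-line sketch of the Integer-plus-Integer case it describes::

    from sqlalchemy import Integer, column

    expr = column('x', Integer) + column('y', Integer)
    assert isinstance(expr.type, Integer)   # Integer + Integer -> Integer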
+ """ __visit_name__ = "user_defined" + ensure_kwarg = 'get_col_spec' + class Comparator(TypeEngine.Comparator): + __slots__ = () + def _adapt_expression(self, op, other_comparator): if hasattr(self.type, 'adapt_operator'): util.warn_deprecated( @@ -617,6 +680,7 @@ class TypeDecorator(TypeEngine): """ class Comparator(TypeEngine.Comparator): + __slots__ = () def operate(self, op, *other, **kwargs): kwargs['_python_is_types'] = self.expr.type.coerce_to_is_types @@ -630,9 +694,13 @@ class TypeDecorator(TypeEngine): @property def comparator_factory(self): - return type("TDComparator", - (TypeDecorator.Comparator, self.impl.comparator_factory), - {}) + if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__: + return self.impl.comparator_factory + else: + return type("TDComparator", + (TypeDecorator.Comparator, + self.impl.comparator_factory), + {}) def _gen_dialect_impl(self, dialect): """ diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index fbbe15da3..bec5b5824 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -1,5 +1,5 @@ # sql/util.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index bb525744a..0540ac5d3 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -1,5 +1,5 @@ # sql/visitors.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -51,6 +51,7 @@ class VisitableType(type): Classes having no __visit_name__ attribute will remain unaffected. """ + def __init__(cls, clsname, bases, clsdict): if clsname != 'Visitable' and \ hasattr(cls, '__visit_name__'): @@ -212,12 +213,19 @@ def iterate(obj, opts): traversal is configured to be breadth-first. """ + # fasttrack for atomic elements like columns + children = obj.get_children(**opts) + if not children: + return [obj] + + traversal = deque() stack = deque([obj]) while stack: t = stack.popleft() - yield t + traversal.append(t) for c in t.get_children(**opts): stack.append(c) + return iter(traversal) def iterate_depthfirst(obj, opts): @@ -226,6 +234,11 @@ def iterate_depthfirst(obj, opts): traversal is configured to be depth-first. 
""" + # fasttrack for atomic elements like columns + children = obj.get_children(**opts) + if not children: + return [obj] + stack = deque([obj]) traversal = deque() while stack: diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 1f37b4b45..bf83e9673 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -1,5 +1,5 @@ # testing/__init__.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -23,7 +23,8 @@ from .assertions import emits_warning, emits_warning_on, uses_deprecated, \ assert_raises_message, AssertsCompiledSQL, ComparesTables, \ AssertsExecutionResults, expect_deprecated, expect_warnings -from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict +from .util import run_as_contextmanager, rowset, fail, \ + provide_metadata, adict, force_drop_names crashes = skip diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index bf7c27a89..e5249c296 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -1,5 +1,5 @@ # testing/assertions.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -22,7 +22,7 @@ import contextlib from . import mock -def expect_warnings(*messages): +def expect_warnings(*messages, **kw): """Context manager which expects one or more warnings. With no arguments, squelches all SAWarnings emitted via @@ -30,17 +30,21 @@ def expect_warnings(*messages): pass string expressions that will match selected warnings via regex; all non-matching warnings are sent through. + The expect version **asserts** that the warnings were in fact seen. + Note that the test suite sets SAWarning warnings to raise exceptions. """ - return _expect_warnings(sa_exc.SAWarning, messages) + return _expect_warnings(sa_exc.SAWarning, messages, **kw) @contextlib.contextmanager -def expect_warnings_on(db, *messages): +def expect_warnings_on(db, *messages, **kw): """Context manager which expects one or more warnings on specific dialects. + The expect version **asserts** that the warnings were in fact seen. + """ spec = db_spec(db) @@ -49,23 +53,28 @@ def expect_warnings_on(db, *messages): elif not _is_excluded(*db): yield else: - with expect_warnings(*messages): + with expect_warnings(*messages, **kw): yield def emits_warning(*messages): - """Decorator form of expect_warnings().""" + """Decorator form of expect_warnings(). + + Note that emits_warning does **not** assert that the warnings + were in fact seen. + + """ @decorator def decorate(fn, *args, **kw): - with expect_warnings(*messages): + with expect_warnings(assert_=False, *messages): return fn(*args, **kw) return decorate -def expect_deprecated(*messages): - return _expect_warnings(sa_exc.SADeprecationWarning, messages) +def expect_deprecated(*messages, **kw): + return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw) def emits_warning_on(db, *messages): @@ -74,6 +83,10 @@ def emits_warning_on(db, *messages): With no arguments, squelches all SAWarning failures. Or pass one or more strings; these will be matched to the root of the warning description by warnings.filterwarnings(). 
+ + Note that emits_warning_on does **not** assert that the warnings + were in fact seen. + """ @decorator def decorate(fn, *args, **kw): @@ -93,19 +106,28 @@ def uses_deprecated(*messages): As a special case, you may pass a function name prefixed with // and it will be re-written as needed to match the standard warning verbiage emitted by the sqlalchemy.util.deprecated decorator. + + Note that uses_deprecated does **not** assert that the warnings + were in fact seen. + """ @decorator def decorate(fn, *args, **kw): - with expect_deprecated(*messages): + with expect_deprecated(*messages, assert_=False): return fn(*args, **kw) return decorate @contextlib.contextmanager -def _expect_warnings(exc_cls, messages): +def _expect_warnings(exc_cls, messages, regex=True, assert_=True): + + if regex: + filters = [re.compile(msg, re.I) for msg in messages] + else: + filters = messages - filters = [re.compile(msg, re.I) for msg in messages] + seen = set(filters) real_warn = warnings.warn @@ -117,7 +139,9 @@ def _expect_warnings(exc_cls, messages): return for filter_ in filters: - if filter_.match(msg): + if (regex and filter_.match(msg)) or \ + (not regex and filter_ == msg): + seen.discard(filter_) break else: real_warn(msg, exception, *arg, **kw) @@ -125,6 +149,10 @@ def _expect_warnings(exc_cls, messages): with mock.patch("warnings.warn", our_warn): yield + if assert_: + assert not seen, "Warnings were not seen: %s" % \ + ", ".join("%r" % (s.pattern if regex else s) for s in seen) + def global_cleanup_assertions(): """Check things that have to be finalized at the end of a test suite. @@ -229,6 +257,7 @@ class AssertsCompiledSQL(object): def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None, checkpositional=None, + check_prefetch=None, use_default_dialect=False, allow_dialect_select=False, literal_binds=False): @@ -289,6 +318,8 @@ class AssertsCompiledSQL(object): if checkpositional is not None: p = c.construct_params(params) eq_(tuple([p[x] for x in c.positiontup]), checkpositional) + if check_prefetch is not None: + eq_(c.prefetch, check_prefetch) class ComparesTables(object): @@ -405,29 +436,27 @@ class AssertsExecutionResults(object): cls.__name__, repr(expected_item))) return True + def sql_execution_asserter(self, db=None): + if db is None: + from . 
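The strengthened _expect_warnings() above now asserts that every expected warning was actually emitted, unless assert_=False is passed as the decorator forms do. A hedged sketch of this internal testing helper::

    import warnings

    from sqlalchemy import exc as sa_exc
    from sqlalchemy.testing import expect_warnings

    with expect_warnings('unsupported frobnication'):
        # matched SAWarnings are squelched and marked as seen
        warnings.warn('unsupported frobnication mode', sa_exc.SAWarning)

    # exiting the block with a pattern still unseen now raises
    # "AssertionError: Warnings were not seen: ..."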
import db as db + + return assertsql.assert_engine(db) + def assert_sql_execution(self, db, callable_, *rules): - assertsql.asserter.add_rules(rules) - try: + with self.sql_execution_asserter(db) as asserter: callable_() - assertsql.asserter.statement_complete() - finally: - assertsql.asserter.clear_rules() + asserter.assert_(*rules) - def assert_sql(self, db, callable_, list_, with_sequences=None): - if (with_sequences is not None and - config.db.dialect.supports_sequences): - rules = with_sequences - else: - rules = list_ + def assert_sql(self, db, callable_, rules): newrules = [] for rule in rules: if isinstance(rule, dict): newrule = assertsql.AllOf(*[ - assertsql.ExactSQL(k, v) for k, v in rule.items() + assertsql.CompiledSQL(k, v) for k, v in rule.items() ]) else: - newrule = assertsql.ExactSQL(*rule) + newrule = assertsql.CompiledSQL(*rule) newrules.append(newrule) self.assert_sql_execution(db, callable_, *newrules) diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index bcc999fe3..a596d9743 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -1,5 +1,5 @@ # testing/assertsql.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -8,84 +8,141 @@ from ..engine.default import DefaultDialect from .. import util import re +import collections +import contextlib +from .. import event +from sqlalchemy.schema import _DDLCompiles +from sqlalchemy.engine.util import _distill_params class AssertRule(object): - def process_execute(self, clauseelement, *multiparams, **params): - pass + is_consumed = False + errormessage = None + consume_statement = True - def process_cursor_execute(self, statement, parameters, context, - executemany): + def process_statement(self, execute_observed): pass - def is_consumed(self): - """Return True if this rule has been consumed, False if not. - - Should raise an AssertionError if this rule's condition has - definitely failed. - - """ - - raise NotImplementedError() + def no_more_statements(self): + assert False, 'All statements are complete, but pending '\ + 'assertion rules remain' - def rule_passed(self): - """Return True if the last test of this rule passed, False if - failed, None if no test was applied.""" - raise NotImplementedError() - - def consume_final(self): - """Return True if this rule has been consumed. - - Should raise an AssertionError if this rule's condition has not - been consumed or has failed. 
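The rewritten assert_sql_execution() above drives the new rule objects through assertsql.assert_engine(); a hedged sketch of using that internal API directly::

    from sqlalchemy import (
        Column, Integer, MetaData, String, Table, create_engine)
    from sqlalchemy.testing.assertsql import CompiledSQL, assert_engine

    engine = create_engine('sqlite://')
    metadata = MetaData()
    users = Table('users', metadata,
                  Column('id', Integer, primary_key=True),
                  Column('name', String(30)))
    metadata.create_all(engine)

    with assert_engine(engine) as asserter:
        engine.execute(users.insert(), {'name': 'u1'})
    asserter.assert_(
        CompiledSQL('INSERT INTO users (name) VALUES (:name)',
                    [{'name': 'u1'}]))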
+class SQLMatchRule(AssertRule): + pass - """ - if self._result is None: - assert False, 'Rule has not been consumed' - return self.is_consumed() +class CursorSQL(SQLMatchRule): + consume_statement = False + def __init__(self, statement, params=None): + self.statement = statement + self.params = params -class SQLMatchRule(AssertRule): - def __init__(self): - self._result = None - self._errmsg = "" + def process_statement(self, execute_observed): + stmt = execute_observed.statements[0] + if self.statement != stmt.statement or ( + self.params is not None and self.params != stmt.parameters): + self.errormessage = \ + "Testing for exact SQL %s parameters %s received %s %s" % ( + self.statement, self.params, + stmt.statement, stmt.parameters + ) + else: + execute_observed.statements.pop(0) + self.is_consumed = True + if not execute_observed.statements: + self.consume_statement = True - def rule_passed(self): - return self._result - def is_consumed(self): - if self._result is None: - return False +class CompiledSQL(SQLMatchRule): - assert self._result, self._errmsg + def __init__(self, statement, params=None): + self.statement = statement + self.params = params - return True + def _compare_sql(self, execute_observed, received_statement): + stmt = re.sub(r'[\n\t]', '', self.statement) + return received_statement == stmt + def _compile_dialect(self, execute_observed): + return DefaultDialect() -class ExactSQL(SQLMatchRule): + def _received_statement(self, execute_observed): + """reconstruct the statement and params in terms + of a target dialect, which for CompiledSQL is just DefaultDialect.""" - def __init__(self, sql, params=None): - SQLMatchRule.__init__(self) - self.sql = sql - self.params = params + context = execute_observed.context + compare_dialect = self._compile_dialect(execute_observed) + if isinstance(context.compiled.statement, _DDLCompiles): + compiled = \ + context.compiled.statement.compile(dialect=compare_dialect) + else: + compiled = ( + context.compiled.statement.compile( + dialect=compare_dialect, + column_keys=context.compiled.column_keys, + inline=context.compiled.inline) + ) + _received_statement = re.sub(r'[\n\t]', '', str(compiled)) + parameters = execute_observed.parameters - def process_cursor_execute(self, statement, parameters, context, - executemany): - if not context: - return - _received_statement = \ - _process_engine_statement(context.unicode_statement, - context) - _received_parameters = context.compiled_parameters + if not parameters: + _received_parameters = [compiled.construct_params()] + else: + _received_parameters = [ + compiled.construct_params(m) for m in parameters] + + return _received_statement, _received_parameters + + def process_statement(self, execute_observed): + context = execute_observed.context + + _received_statement, _received_parameters = \ + self._received_statement(execute_observed) + params = self._all_params(context) + + equivalent = self._compare_sql(execute_observed, _received_statement) + + if equivalent: + if params is not None: + all_params = list(params) + all_received = list(_received_parameters) + while all_params and all_received: + param = dict(all_params.pop(0)) + + for idx, received in enumerate(list(all_received)): + # do a positive compare only + for param_key in param: + # a key in param did not match current + # 'received' + if param_key not in received or \ + received[param_key] != param[param_key]: + break + else: + # all keys in param matched 'received'; + # onto next param + del all_received[idx] + break + else: + # 
param did not match any entry + # in all_received + equivalent = False + break + if all_params or all_received: + equivalent = False - # TODO: remove this step once all unit tests are migrated, as - # ExactSQL should really be *exact* SQL + if equivalent: + self.is_consumed = True + self.errormessage = None + else: + self.errormessage = self._failure_message(params) % { + 'received_statement': _received_statement, + 'received_parameters': _received_parameters + } - sql = _process_assertion_statement(self.sql, context) - equivalent = _received_statement == sql + def _all_params(self, context): if self.params: if util.callable(self.params): params = self.params(context) @@ -93,127 +150,77 @@ class ExactSQL(SQLMatchRule): params = self.params if not isinstance(params, list): params = [params] - equivalent = equivalent and params \ - == context.compiled_parameters + return params else: - params = {} - self._result = equivalent - if not self._result: - self._errmsg = ( - 'Testing for exact statement %r exact params %r, ' - 'received %r with params %r' % - (sql, params, _received_statement, _received_parameters)) - + return None + + def _failure_message(self, expected_params): + return ( + 'Testing for compiled statement %r partial params %r, ' + 'received %%(received_statement)r with params ' + '%%(received_parameters)r' % ( + self.statement, expected_params + ) + ) -class RegexSQL(SQLMatchRule): +class RegexSQL(CompiledSQL): def __init__(self, regex, params=None): SQLMatchRule.__init__(self) self.regex = re.compile(regex) self.orig_regex = regex self.params = params - def process_cursor_execute(self, statement, parameters, context, - executemany): - if not context: - return - _received_statement = \ - _process_engine_statement(context.unicode_statement, - context) - _received_parameters = context.compiled_parameters - equivalent = bool(self.regex.match(_received_statement)) - if self.params: - if util.callable(self.params): - params = self.params(context) - else: - params = self.params - if not isinstance(params, list): - params = [params] - - # do a positive compare only - - for param, received in zip(params, _received_parameters): - for k, v in param.items(): - if k not in received or received[k] != v: - equivalent = False - break - else: - params = {} - self._result = equivalent - if not self._result: - self._errmsg = \ - 'Testing for regex %r partial params %r, received %r '\ - 'with params %r' % (self.orig_regex, params, - _received_statement, - _received_parameters) - + def _failure_message(self, expected_params): + return ( + 'Testing for compiled statement ~%r partial params %r, ' + 'received %%(received_statement)r with params ' + '%%(received_parameters)r' % ( + self.orig_regex, expected_params + ) + ) -class CompiledSQL(SQLMatchRule): + def _compare_sql(self, execute_observed, received_statement): + return bool(self.regex.match(received_statement)) - def __init__(self, statement, params=None): - SQLMatchRule.__init__(self) - self.statement = statement - self.params = params - def process_cursor_execute(self, statement, parameters, context, - executemany): - if not context: - return - from sqlalchemy.schema import _DDLCompiles - _received_parameters = list(context.compiled_parameters) - - # recompile from the context, using the default dialect +class DialectSQL(CompiledSQL): + def _compile_dialect(self, execute_observed): + return execute_observed.context.dialect - if isinstance(context.compiled.statement, _DDLCompiles): - compiled = \ - 
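RegexSQL is now a thin CompiledSQL subclass that swaps out only the string comparison, inheriting the positive partial-parameter check; a sketch with assumed values:

    from sqlalchemy.testing.assertsql import RegexSQL

    # the pattern is applied with .match() against the DefaultDialect
    # recompile of the observed statement
    RegexSQL(r"SELECT .* FROM t WHERE", [{"x_1": 5}])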
context.compiled.statement.compile(dialect=DefaultDialect()) + def _received_statement(self, execute_observed): + received_stmt, received_params = super(DialectSQL, self).\ + _received_statement(execute_observed) + for real_stmt in execute_observed.statements: + if real_stmt.statement == received_stmt: + break else: - compiled = ( - context.compiled.statement.compile( - dialect=DefaultDialect(), - column_keys=context.compiled.column_keys) - ) - _received_statement = re.sub(r'[\n\t]', '', str(compiled)) - equivalent = self.statement == _received_statement - if self.params: - if util.callable(self.params): - params = self.params(context) - else: - params = self.params - if not isinstance(params, list): - params = [params] - else: - params = list(params) - all_params = list(params) - all_received = list(_received_parameters) - while params: - param = dict(params.pop(0)) - for k, v in context.compiled.params.items(): - param.setdefault(k, v) - if param not in _received_parameters: - equivalent = False - break - else: - _received_parameters.remove(param) - if _received_parameters: - equivalent = False + raise AssertionError( + "Can't locate compiled statement %r in list of " + "statements actually invoked" % received_stmt) + return received_stmt, execute_observed.context.compiled_parameters + + def _compare_sql(self, execute_observed, received_statement): + stmt = re.sub(r'[\n\t]', '', self.statement) + + # convert our comparison statement to have the + # paramstyle of the received + paramstyle = execute_observed.context.dialect.paramstyle + if paramstyle == 'pyformat': + stmt = re.sub( + r':([\w_]+)', r"%(\1)s", stmt) else: - params = {} - all_params = {} - all_received = [] - self._result = equivalent - if not self._result: - print('Testing for compiled statement %r partial params ' - '%r, received %r with params %r' % - (self.statement, all_params, - _received_statement, all_received)) - self._errmsg = ( - 'Testing for compiled statement %r partial params %r, ' - 'received %r with params %r' % - (self.statement, all_params, - _received_statement, all_received)) - - # print self._errmsg + # positional params + repl = None + if paramstyle == 'qmark': + repl = "?" 
+ elif paramstyle == 'format': + repl = r"%s" + elif paramstyle == 'numeric': + repl = None + stmt = re.sub(r':([\w_]+)', repl, stmt) + + return received_statement == stmt class CountStatements(AssertRule): @@ -222,21 +229,13 @@ class CountStatements(AssertRule): self.count = count self._statement_count = 0 - def process_execute(self, clauseelement, *multiparams, **params): + def process_statement(self, execute_observed): self._statement_count += 1 - def process_cursor_execute(self, statement, parameters, context, - executemany): - pass - - def is_consumed(self): - return False - - def consume_final(self): - assert self.count == self._statement_count, \ - 'desired statement count %d does not match %d' \ - % (self.count, self._statement_count) - return True + def no_more_statements(self): + if self.count != self._statement_count: + assert False, 'desired statement count %d does not match %d' \ + % (self.count, self._statement_count) class AllOf(AssertRule): @@ -244,116 +243,113 @@ class AllOf(AssertRule): def __init__(self, *rules): self.rules = set(rules) - def process_execute(self, clauseelement, *multiparams, **params): - for rule in self.rules: - rule.process_execute(clauseelement, *multiparams, **params) - - def process_cursor_execute(self, statement, parameters, context, - executemany): - for rule in self.rules: - rule.process_cursor_execute(statement, parameters, context, - executemany) - - def is_consumed(self): - if not self.rules: - return True + def process_statement(self, execute_observed): for rule in list(self.rules): - if rule.rule_passed(): # a rule passed, move on - self.rules.remove(rule) - return len(self.rules) == 0 - return False + rule.errormessage = None + rule.process_statement(execute_observed) + if rule.is_consumed: + self.rules.discard(rule) + if not self.rules: + self.is_consumed = True + break + elif not rule.errormessage: + # rule is not done yet + self.errormessage = None + break + else: + self.errormessage = list(self.rules)[0].errormessage - def rule_passed(self): - return self.is_consumed() - def consume_final(self): - return len(self.rules) == 0 +class Or(AllOf): + def process_statement(self, execute_observed): + for rule in self.rules: + rule.process_statement(execute_observed) + if rule.is_consumed: + self.is_consumed = True + break + else: + self.errormessage = list(self.rules)[0].errormessage -class Or(AllOf): - def __init__(self, *rules): - self.rules = set(rules) - self._consume_final = False - def is_consumed(self): - if not self.rules: - return True - for rule in list(self.rules): - if rule.rule_passed(): # a rule passed - self._consume_final = True - return True - return False +class SQLExecuteObserved(object): + def __init__(self, context, clauseelement, multiparams, params): + self.context = context + self.clauseelement = clauseelement + self.parameters = _distill_params(multiparams, params) + self.statements = [] - def consume_final(self): - assert self._consume_final, "Unsatisified rules remain" +class SQLCursorExecuteObserved( + collections.namedtuple( + "SQLCursorExecuteObserved", + ["statement", "parameters", "context", "executemany"]) +): + pass -def _process_engine_statement(query, context): - if util.jython: - # oracle+zxjdbc passes a PyStatement when returning into +class SQLAsserter(object): + def __init__(self): + self.accumulated = [] - query = str(query) - if context.engine.name == 'mssql' \ - and query.endswith('; select scope_identity()'): - query = query[:-25] - query = re.sub(r'\n', '', query) - return query + def _close(self): 
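With per-rule is_consumed/errormessage state, the composite rules above reduce to simple loops: AllOf consumes its children in any order, while Or is satisfied by the first child that matches. An illustrative composition, statements assumed:

    from sqlalchemy.testing.assertsql import AllOf, Or, CompiledSQL

    # two UPDATEs that may arrive in either order
    AllOf(
        CompiledSQL("UPDATE a SET x=:x WHERE a.id = :a_id"),
        CompiledSQL("UPDATE b SET y=:y WHERE b.id = :b_id"),
    )

    # accept either rendering of the DELETE
    Or(
        CompiledSQL("DELETE FROM a WHERE a.id = :id"),
        CompiledSQL("DELETE FROM a WHERE a.id IN (:id_1)"),
    )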
+ self._final = self.accumulated + del self.accumulated + def assert_(self, *rules): + rules = list(rules) + observed = list(self._final) -def _process_assertion_statement(query, context): - paramstyle = context.dialect.paramstyle - if paramstyle == 'named': - pass - elif paramstyle == 'pyformat': - query = re.sub(r':([\w_]+)', r"%(\1)s", query) - else: - # positional params - repl = None - if paramstyle == 'qmark': - repl = "?" - elif paramstyle == 'format': - repl = r"%s" - elif paramstyle == 'numeric': - repl = None - query = re.sub(r':([\w_]+)', repl, query) + while observed and rules: + rule = rules[0] + rule.process_statement(observed[0]) + if rule.is_consumed: + rules.pop(0) + elif rule.errormessage: + assert False, rule.errormessage - return query + if rule.consume_statement: + observed.pop(0) + if not observed and rules: + rules[0].no_more_statements() + elif not rules and observed: + assert False, "Additional SQL statements remain" -class SQLAssert(object): - rules = None +@contextlib.contextmanager +def assert_engine(engine): + asserter = SQLAsserter() - def add_rules(self, rules): - self.rules = list(rules) + orig = [] - def statement_complete(self): - for rule in self.rules: - if not rule.consume_final(): - assert False, \ - 'All statements are complete, but pending '\ - 'assertion rules remain' - - def clear_rules(self): - del self.rules - - def execute(self, conn, clauseelement, multiparams, params, result): - if self.rules is not None: - if not self.rules: - assert False, \ - 'All rules have been exhausted, but further '\ - 'statements remain' - rule = self.rules[0] - rule.process_execute(clauseelement, *multiparams, **params) - if rule.is_consumed(): - self.rules.pop(0) - - def cursor_execute(self, conn, cursor, statement, parameters, - context, executemany): - if self.rules: - rule = self.rules[0] - rule.process_cursor_execute(statement, parameters, context, - executemany) + @event.listens_for(engine, "before_execute") + def connection_execute(conn, clauseelement, multiparams, params): + # grab the original statement + params before any cursor + # execution + orig[:] = clauseelement, multiparams, params -asserter = SQLAssert() + @event.listens_for(engine, "after_cursor_execute") + def cursor_execute(conn, cursor, statement, parameters, + context, executemany): + if not context: + return + # then grab real cursor statements and associate them all + # around a single context + if asserter.accumulated and \ + asserter.accumulated[-1].context is context: + obs = asserter.accumulated[-1] + else: + obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2]) + asserter.accumulated.append(obs) + obs.statements.append( + SQLCursorExecuteObserved( + statement, parameters, context, executemany) + ) + + try: + yield asserter + finally: + event.remove(engine, "after_cursor_execute", cursor_execute) + event.remove(engine, "before_execute", connection_execute) + asserter._close() diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 6832eab74..d429c9f4e 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -1,5 +1,5 @@ # testing/config.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index 1284f9c2a..3a8303546 100644 --- a/lib/sqlalchemy/testing/engines.py +++ 
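assert_engine is the new entry point tying this together: before_execute captures the original construct, after_cursor_execute groups cursor-level statements by their shared ExecutionContext, and _close() freezes the recording when the block exits. A standalone sketch, assuming an in-memory SQLite engine and a one-column table:

    from sqlalchemy import create_engine, MetaData, Table, Column, Integer
    from sqlalchemy.testing.assertsql import assert_engine, CompiledSQL

    e = create_engine("sqlite://")
    m = MetaData()
    t = Table("t", m, Column("x", Integer))
    m.create_all(e)

    with assert_engine(e) as asserter:
        e.execute(t.select()).fetchall()

    # rules are applied to the frozen recording after the block exits
    asserter.assert_(CompiledSQL("SELECT t.x FROM t"))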
b/lib/sqlalchemy/testing/engines.py @@ -1,5 +1,5 @@ # testing/engines.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -204,7 +204,6 @@ def testing_engine(url=None, options=None): """Produce an engine configured by --options with optional overrides.""" from sqlalchemy import create_engine - from .assertsql import asserter if not options: use_reaper = True @@ -216,11 +215,12 @@ def testing_engine(url=None, options=None): options = config.db_opts engine = create_engine(url, **options) + engine._has_events = True # enable event blocks, helps with + # profiling + if isinstance(engine.pool, pool.QueuePool): engine.pool._timeout = 0 engine.pool._max_overflow = 0 - event.listen(engine, 'after_execute', asserter.execute) - event.listen(engine, 'after_cursor_execute', asserter.cursor_execute) if use_reaper: event.listen(engine.pool, 'connect', testing_reaper.connect) event.listen(engine.pool, 'checkout', testing_reaper.checkout) @@ -280,10 +280,10 @@ class DBAPIProxyCursor(object): """ - def __init__(self, engine, conn): + def __init__(self, engine, conn, *args, **kwargs): self.engine = engine self.connection = conn - self.cursor = conn.cursor() + self.cursor = conn.cursor(*args, **kwargs) def execute(self, stmt, parameters=None, **kw): if parameters: @@ -311,8 +311,8 @@ class DBAPIProxyConnection(object): self.engine = engine self.cursor_cls = cursor_cls - def cursor(self): - return self.cursor_cls(self.engine, self.conn) + def cursor(self, *args, **kwargs): + return self.cursor_cls(self.engine, self.conn, *args, **kwargs) def close(self): self.conn.close() diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py index 3e42955e6..65178ea5b 100644 --- a/lib/sqlalchemy/testing/entities.py +++ b/lib/sqlalchemy/testing/entities.py @@ -1,5 +1,5 @@ # testing/entities.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index f94724608..6aa4bf142 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -1,5 +1,5 @@ # testing/exclusions.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -425,7 +425,7 @@ def skip(db, reason=None): def only_on(dbs, reason=None): return only_if( - OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)]) + OrPredicate([Predicate.as_predicate(db) for db in util.to_list(dbs)]) ) diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index d86049da7..7b421952f 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -1,5 +1,5 @@ # testing/fixtures.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -192,9 +192,8 @@ class TablesTest(TestBase): def sql_count_(self, count, fn): self.assert_sql_count(self.bind, fn, count) - def sql_eq_(self, callable_, statements, with_sequences=None): - 
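Since only_on() now routes through Predicate.as_predicate() rather than constructing SpecPredicate directly, plain dialect spec strings keep working while other predicate forms become possible. Hypothetical suite usage; as_predicate's accepted forms beyond spec strings are an assumption here:

    from sqlalchemy.testing import exclusions

    @exclusions.only_on(["postgresql", "mysql"])
    def test_backend_specific_behavior():
        ...  # runs only against the named backends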
self.assert_sql(self.bind, - callable_, statements, with_sequences) + def sql_eq_(self, callable_, statements): + self.assert_sql(self.bind, callable_, statements) @classmethod def _load_fixtures(cls): diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py index c6a4d4360..be83693cc 100644 --- a/lib/sqlalchemy/testing/mock.py +++ b/lib/sqlalchemy/testing/mock.py @@ -1,5 +1,5 @@ # testing/mock.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py index 5a903aae7..7b696ad67 100644 --- a/lib/sqlalchemy/testing/pickleable.py +++ b/lib/sqlalchemy/testing/pickleable.py @@ -1,5 +1,5 @@ # testing/pickleable.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/plugin/noseplugin.py b/lib/sqlalchemy/testing/plugin/noseplugin.py index 538087770..1ae6e28f5 100644 --- a/lib/sqlalchemy/testing/plugin/noseplugin.py +++ b/lib/sqlalchemy/testing/plugin/noseplugin.py @@ -1,5 +1,5 @@ # plugin/noseplugin.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 6696427dc..14cf1eb31 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -1,5 +1,5 @@ # plugin/plugin_base.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -93,7 +93,10 @@ def setup_options(make_option): help="Exclude tests with tag <tag>") make_option("--write-profiles", action="store_true", dest="write_profiles", default=False, - help="Write/update profiling data.") + help="Write/update failing profiling data.") + make_option("--force-write-profiles", action="store_true", + dest="force_write_profiles", default=False, + help="Unconditionally write/update profiling data.") def configure_follower(follower_ident): @@ -291,7 +294,7 @@ def _setup_requirements(argument): @post def _prep_testing_database(options, file_config): - from sqlalchemy.testing import config + from sqlalchemy.testing import config, util from sqlalchemy.testing.exclusions import against from sqlalchemy import schema, inspect @@ -322,19 +325,10 @@ def _prep_testing_database(options, file_config): schema="test_schema") )) - for tname in reversed(inspector.get_table_names( - order_by="foreign_key")): - e.execute(schema.DropTable( - schema.Table(tname, schema.MetaData()) - )) + util.drop_all_tables(e, inspector) if config.requirements.schemas.enabled_for_config(cfg): - for tname in reversed(inspector.get_table_names( - order_by="foreign_key", schema="test_schema")): - e.execute(schema.DropTable( - schema.Table(tname, schema.MetaData(), - schema="test_schema") - )) + util.drop_all_tables(e, inspector, schema=cfg.test_schema) if against(cfg, "postgresql"): from sqlalchemy.dialects import postgresql diff --git 
a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 4bbc8ed9a..fbab4966c 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -84,7 +84,8 @@ def pytest_collection_modifyitems(session, config, items): rebuilt_items = collections.defaultdict(list) items[:] = [ item for item in - items if isinstance(item.parent, pytest.Instance)] + items if isinstance(item.parent, pytest.Instance) + and not item.parent.parent.name.startswith("_")] test_classes = set(item.parent for item in items) for test_class in test_classes: for sub_cls in plugin_base.generate_sub_tests( diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index fcb888f86..65fe165cd 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -1,5 +1,5 @@ # testing/profiling.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -42,7 +42,11 @@ class ProfileStatsFile(object): """ def __init__(self, filename): - self.write = ( + self.force_write = ( + config.options is not None and + config.options.force_write_profiles + ) + self.write = self.force_write or ( config.options is not None and config.options.write_profiles ) @@ -115,7 +119,11 @@ class ProfileStatsFile(object): per_fn = self.data[test_key] per_platform = per_fn[self.platform_key] counts = per_platform['counts'] - counts[-1] = callcount + current_count = per_platform['current_count'] + if current_count < len(counts): + counts[current_count - 1] = callcount + else: + counts[-1] = callcount if self.write: self._write() @@ -218,6 +226,7 @@ def count_functions(variance=0.05): callcount = stats.total_calls expected = _profile_stats.result(callcount) + if expected is None: expected_count = None else: @@ -235,16 +244,17 @@ def count_functions(variance=0.05): deviance = int(callcount * variance) failed = abs(callcount - expected_count) > deviance - if failed: + if failed or _profile_stats.force_write: if _profile_stats.write: _profile_stats.replace(callcount) else: raise AssertionError( "Adjusted function call count %s not within %s%% " - "of expected %s. Rerun with --write-profiles to " + "of expected %s, platform %s. Rerun with " + "--write-profiles to " "regenerate this callcount." 
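The tolerance check above in concrete numbers; note that deviance is computed from the observed callcount, not the expected one:

    callcount, expected_count, variance = 1042, 1000, 0.05
    deviance = int(callcount * variance)                 # 52
    failed = abs(callcount - expected_count) > deviance  # 42 > 52, so False
    # --write-profiles rewrites the stored count only when failed is True;
    # --force-write-profiles rewrites it unconditionally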
% ( callcount, (variance * 100), - expected_count)) + expected_count, _profile_stats.platform_key)) diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index da3e3128a..32465c47d 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1,5 +1,5 @@ # testing/requirements.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -323,6 +323,11 @@ class SuiteRequirements(Requirements): return exclusions.closed() @property + def temporary_tables(self): + """target database supports temporary tables""" + return exclusions.open() + + @property def temporary_views(self): """target database supports temporary views""" return exclusions.closed() diff --git a/lib/sqlalchemy/testing/runner.py b/lib/sqlalchemy/testing/runner.py index 23d7a0a91..92a03061e 100644 --- a/lib/sqlalchemy/testing/runner.py +++ b/lib/sqlalchemy/testing/runner.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # testing/runner.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py index 9561b1f1e..93b52ad58 100644 --- a/lib/sqlalchemy/testing/schema.py +++ b/lib/sqlalchemy/testing/schema.py @@ -1,5 +1,5 @@ # testing/schema.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py index 38519dfb9..70e8a6b17 100644 --- a/lib/sqlalchemy/testing/suite/test_insert.py +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -109,7 +109,8 @@ class InsertBehaviorTest(fixtures.TablesTest): self.tables.autoinc_pk.insert(), data="some data" ) - assert r.closed + assert r._soft_closed + assert not r.closed assert r.is_insert assert not r.returns_rows @@ -119,7 +120,8 @@ class InsertBehaviorTest(fixtures.TablesTest): self.tables.autoinc_pk.insert(), data="some data" ) - assert r.closed + assert r._soft_closed + assert not r.closed assert r.is_insert assert not r.returns_rows @@ -128,7 +130,8 @@ class InsertBehaviorTest(fixtures.TablesTest): r = config.db.execute( self.tables.autoinc_pk.insert(), ) - assert r.closed + assert r._soft_closed + assert not r.closed r = config.db.execute( self.tables.autoinc_pk.select(). 
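The test_insert changes track a 1.0 behavior change: a result that returns no rows is now only "soft closed", releasing the DBAPI cursor while leaving the ResultProxy itself inspectable. A sketch mirroring the new asserts, with an assumed in-memory table:

    from sqlalchemy import create_engine, MetaData, Table, Column, Integer

    e = create_engine("sqlite://")
    m = MetaData()
    t = Table("t", m, Column("x", Integer))
    m.create_all(e)

    r = e.execute(t.insert(), {"x": 1})
    assert r._soft_closed and not r.closed  # cursor released, proxy usable
    assert r.is_insert and not r.returns_rows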
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 08b858b47..3edbdeb8c 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -128,6 +128,10 @@ class ComponentReflectionTest(fixtures.TablesTest): DDL("create temporary view user_tmp_v as " "select * from user_tmp") ) + event.listen( + user_tmp, "before_drop", + DDL("drop view user_tmp_v") + ) @classmethod def define_index(cls, metadata, users): @@ -511,6 +515,8 @@ class ComponentReflectionTest(fixtures.TablesTest): def test_get_temp_table_indexes(self): insp = inspect(self.metadata.bind) indexes = insp.get_indexes('user_tmp') + for ind in indexes: + ind.pop('dialect_options', None) eq_( # TODO: we need to add better filtering for indexes/uq constraints # that are doubled up diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index 68dadd0a9..eaf3f03c2 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -86,6 +86,15 @@ class OrderByLabelTest(fixtures.TablesTest): [(7, ), (5, ), (3, )] ) + def test_group_by_composed(self): + table = self.tables.some_table + expr = (table.c.x + table.c.y).label('lx') + stmt = select([func.count(table.c.id), expr]).group_by(expr).order_by(expr) + self._assert_result( + stmt, + [(1, 3), (1, 5), (1, 7)] + ) + class LimitOffsetTest(fixtures.TablesTest): __backend__ = True diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index 7b3f721a6..6d6fa094e 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -1,5 +1,5 @@ # testing/util.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -147,6 +147,10 @@ def run_as_contextmanager(ctx, fn, *arg, **kw): simulating the behavior of 'with' to support older Python versions. + This is not necessary anymore as we have placed 2.6 + as minimum Python version, however some tests are still using + this structure. + """ obj = ctx.__enter__() @@ -194,6 +198,25 @@ def provide_metadata(fn, *args, **kw): self.metadata = prev_meta +def force_drop_names(*names): + """Force the given table names to be dropped after test complete, + isolating for foreign key cycles + + """ + from . import config + from sqlalchemy import inspect + + @decorator + def go(fn, *args, **kw): + + try: + return fn(*args, **kw) + finally: + drop_all_tables( + config.db, inspect(config.db), include_names=names) + return go + + class adict(dict): """Dict keys available as attributes. 
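force_drop_names as a test decorator; the table names below are hypothetical. Cleanup happens in the finally block via drop_all_tables with include_names, so leftovers held in place by foreign key cycles still get dropped:

    from sqlalchemy.testing.util import force_drop_names

    @force_drop_names("node", "edge")
    def test_fk_cycle():
        ...  # create "node"/"edge" with mutual FKs; teardown is forced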
Shadows.""" @@ -207,3 +230,39 @@ class adict(dict): return tuple([self[key] for key in keys]) get_all = __call__ + + +def drop_all_tables(engine, inspector, schema=None, include_names=None): + from sqlalchemy import Column, Table, Integer, MetaData, \ + ForeignKeyConstraint + from sqlalchemy.schema import DropTable, DropConstraint + + if include_names is not None: + include_names = set(include_names) + + with engine.connect() as conn: + for tname, fkcs in reversed( + inspector.get_sorted_table_and_fkc_names(schema=schema)): + if tname: + if include_names is not None and tname not in include_names: + continue + conn.execute(DropTable( + Table(tname, MetaData()) + )) + elif fkcs: + if not engine.dialect.supports_alter: + continue + for tname, fkc in fkcs: + if include_names is not None and \ + tname not in include_names: + continue + tb = Table( + tname, MetaData(), + Column('x', Integer), + Column('y', Integer), + schema=schema + ) + conn.execute(DropConstraint( + ForeignKeyConstraint( + [tb.c.x], [tb.c.y], name=fkc) + )) diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py index 47f1e1404..19b632d34 100644 --- a/lib/sqlalchemy/testing/warnings.py +++ b/lib/sqlalchemy/testing/warnings.py @@ -1,5 +1,5 @@ # testing/warnings.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -9,7 +9,7 @@ from __future__ import absolute_import import warnings from .. import exc as sa_exc -import re +from . import assertions def setup_filters(): @@ -22,19 +22,13 @@ def setup_filters(): def assert_warnings(fn, warning_msgs, regex=False): - """Assert that each of the given warnings are emitted by fn.""" - - from .assertions import eq_ - - with warnings.catch_warnings(record=True) as log: - # ensure that nothing is going into __warningregistry__ - warnings.filterwarnings("always") - - result = fn() - for warning in log: - popwarn = warning_msgs.pop(0) - if regex: - assert re.match(popwarn, str(warning.message)) - else: - eq_(popwarn, str(warning.message)) - return result + """Assert that each of the given warnings are emitted by fn. + + Deprecated. Please use assertions.expect_warnings(). 
+ + """ + + with assertions._expect_warnings( + sa_exc.SAWarning, warning_msgs, regex=regex): + return fn() + diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index b49e389ac..9ab92e90b 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -1,5 +1,5 @@ # types.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -51,6 +51,7 @@ from .sql.sqltypes import ( Integer, Interval, LargeBinary, + MatchType, NCHAR, NVARCHAR, NullType, diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index dfed5b90a..d777d2e06 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -1,5 +1,5 @@ # util/__init__.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -36,7 +36,7 @@ from .langhelpers import iterate_attributes, class_hierarchy, \ generic_repr, counter, PluginLoader, hybridproperty, hybridmethod, \ safe_reraise,\ get_callable_argspec, only_once, attrsetter, ellipses_string, \ - warn_limited + warn_limited, map_bits, MemoizedSlots, EnsureKWArgType from .deprecations import warn_deprecated, warn_pending_deprecation, \ deprecated, pending_deprecation, inject_docstring_text diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index a1fbc0fa0..4fb12d71b 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -1,5 +1,5 @@ # util/_collections.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -10,9 +10,10 @@ from __future__ import absolute_import import weakref import operator -from .compat import threading, itertools_filterfalse +from .compat import threading, itertools_filterfalse, string_types from . 
import py2k import types +import collections EMPTY_SET = frozenset() @@ -126,20 +127,6 @@ class _LW(AbstractKeyedTuple): return d -def lightweight_named_tuple(name, fields): - - tp_cls = type(name, (_LW,), {}) - for idx, field in enumerate(fields): - if field is None: - continue - setattr(tp_cls, field, property(operator.itemgetter(idx))) - - tp_cls._real_fields = fields - tp_cls._fields = tuple([f for f in fields if f is not None]) - - return tp_cls - - class ImmutableContainer(object): def _immutable(self, *arg, **kw): raise TypeError("%s object is immutable" % self.__class__.__name__) @@ -164,8 +151,13 @@ class immutabledict(ImmutableContainer, dict): return immutabledict, (dict(self), ) def union(self, d): - if not self: - return immutabledict(d) + if not d: + return self + elif not self: + if isinstance(d, immutabledict): + return d + else: + return immutabledict(d) else: d2 = immutabledict(self) dict.update(d2, d) @@ -178,8 +170,10 @@ class immutabledict(ImmutableContainer, dict): class Properties(object): """Provide a __getattr__/__setattr__ interface over a dict.""" + __slots__ = '_data', + def __init__(self, data): - self.__dict__['_data'] = data + object.__setattr__(self, '_data', data) def __len__(self): return len(self._data) @@ -199,8 +193,8 @@ class Properties(object): def __delitem__(self, key): del self._data[key] - def __setattr__(self, key, object): - self._data[key] = object + def __setattr__(self, key, obj): + self._data[key] = obj def __getstate__(self): return {'_data': self.__dict__['_data']} @@ -251,6 +245,8 @@ class OrderedProperties(Properties): """Provide a __getattr__/__setattr__ interface with an OrderedDict as backing store.""" + __slots__ = () + def __init__(self): Properties.__init__(self, OrderedDict()) @@ -258,10 +254,17 @@ class OrderedProperties(Properties): class ImmutableProperties(ImmutableContainer, Properties): """Provide immutable dict/object attribute to an underlying dictionary.""" + __slots__ = () + class OrderedDict(dict): """A dict that returns keys/values/items in the order they were added.""" + __slots__ = '_list', + + def __reduce__(self): + return OrderedDict, (self.items(),) + def __init__(self, ____sequence=None, **kwargs): self._list = [] if ____sequence is None: @@ -355,7 +358,10 @@ class OrderedSet(set): set.__init__(self) self._list = [] if d is not None: - self.update(d) + self._list = unique_list(d) + set.update(self, self._list) + else: + self._list = [] def add(self, element): if element not in self: @@ -730,6 +736,12 @@ ordered_column_set = OrderedSet populate_column_dict = PopulateDict +_getters = PopulateDict(operator.itemgetter) + +_property_getters = PopulateDict( + lambda idx: property(operator.itemgetter(idx))) + + def unique_list(seq, hashfunc=None): seen = {} if not hashfunc: @@ -779,10 +791,12 @@ def coerce_generator_arg(arg): def to_list(x, default=None): if x is None: return default - if not isinstance(x, (list, tuple)): + if not isinstance(x, collections.Iterable) or isinstance(x, string_types): return [x] - else: + elif isinstance(x, list): return x + else: + return list(x) def to_set(x): @@ -830,17 +844,30 @@ class LRUCache(dict): """Dictionary with 'squishy' removal of least recently used items. + Note that either get() or [] should be used here, but + generally its not safe to do an "in" check first as the dictionary + can change subsequent to that call. 
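The rewritten to_list above accepts any non-string iterable rather than only list and tuple; behavior sketch:

    from sqlalchemy.util import to_list

    to_list("abc")        # ['abc']   strings are kept whole
    to_list(("a", "b"))   # ['a', 'b']
    to_list({"a", "b"})   # a list, in arbitrary set order
    to_list(None, [])     # [] (the supplied default)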
+ """ def __init__(self, capacity=100, threshold=.5): self.capacity = capacity self.threshold = threshold self._counter = 0 + self._mutex = threading.Lock() def _inc_counter(self): self._counter += 1 return self._counter + def get(self, key, default=None): + item = dict.get(self, key, default) + if item is not default: + item[2] = self._inc_counter() + return item[1] + else: + return default + def __getitem__(self, key): item = dict.__getitem__(self, key) item[2] = self._inc_counter() @@ -866,18 +893,45 @@ class LRUCache(dict): self._manage_size() def _manage_size(self): - while len(self) > self.capacity + self.capacity * self.threshold: - by_counter = sorted(dict.values(self), - key=operator.itemgetter(2), - reverse=True) - for item in by_counter[self.capacity:]: - try: - del self[item[0]] - except KeyError: - # if we couldn't find a key, most - # likely some other thread broke in - # on us. loop around and try again - break + if not self._mutex.acquire(False): + return + try: + while len(self) > self.capacity + self.capacity * self.threshold: + by_counter = sorted(dict.values(self), + key=operator.itemgetter(2), + reverse=True) + for item in by_counter[self.capacity:]: + try: + del self[item[0]] + except KeyError: + # deleted elsewhere; skip + continue + finally: + self._mutex.release() + + +_lw_tuples = LRUCache(100) + + +def lightweight_named_tuple(name, fields): + hash_ = (name, ) + tuple(fields) + tp_cls = _lw_tuples.get(hash_) + if tp_cls: + return tp_cls + + tp_cls = type( + name, (_LW,), + dict([ + (field, _property_getters[idx]) + for idx, field in enumerate(fields) if field is not None + ] + [('__slots__', ())]) + ) + + tp_cls._real_fields = fields + tp_cls._fields = tuple([f for f in fields if f is not None]) + + _lw_tuples[hash_] = tp_cls + return tp_cls class ScopedRegistry(object): diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 972fda667..5b6f691f1 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -1,5 +1,5 @@ # util/compat.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py index 124f304fc..4c7ea47e3 100644 --- a/lib/sqlalchemy/util/deprecations.py +++ b/lib/sqlalchemy/util/deprecations.py @@ -1,5 +1,5 @@ # util/deprecations.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 5c17bea88..3d7bfad0a 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1,5 +1,5 @@ # util/langhelpers.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -92,6 +92,15 @@ def _unique_symbols(used, *bases): raise NameError("exhausted namespace for symbol base %s" % base) +def map_bits(fn, n): + """Call the given function given each nonzero bit from n.""" + + while n: + b = n & (~n + 1) + yield fn(b) + n ^= b + + def decorator(target): """A signature-matching decorator factory.""" @@ -513,6 +522,15 @@ class 
portable_instancemethod(object): """ + __slots__ = 'target', 'name', '__weakref__' + + def __getstate__(self): + return {'target': self.target, 'name': self.name} + + def __setstate__(self, state): + self.target = state['target'] + self.name = state['name'] + def __init__(self, meth): self.target = meth.__self__ self.name = meth.__name__ @@ -791,6 +809,40 @@ class group_expirable_memoized_property(object): return memoized_instancemethod(fn) +class MemoizedSlots(object): + """Apply memoized items to an object using a __getattr__ scheme. + + This allows the functionality of memoized_property and + memoized_instancemethod to be available to a class using __slots__. + + """ + + def _fallback_getattr(self, key): + raise AttributeError(key) + + def __getattr__(self, key): + if key.startswith('_memoized'): + raise AttributeError(key) + elif hasattr(self, '_memoized_attr_%s' % key): + value = getattr(self, '_memoized_attr_%s' % key)() + setattr(self, key, value) + return value + elif hasattr(self, '_memoized_method_%s' % key): + fn = getattr(self, '_memoized_method_%s' % key) + + def oneshot(*args, **kw): + result = fn(*args, **kw) + memo = lambda *a, **kw: result + memo.__name__ = fn.__name__ + memo.__doc__ = fn.__doc__ + setattr(self, key, memo) + return result + oneshot.__doc__ = fn.__doc__ + return oneshot + else: + return self._fallback_getattr(key) + + def dependency_for(modulename): def decorate(obj): # TODO: would be nice to improve on this import silliness, @@ -936,7 +988,7 @@ def asbool(obj): def bool_or_str(*text): - """Return a callable that will evaulate a string as + """Return a callable that will evaluate a string as boolean, or one of a set of "alternate" string values. """ @@ -969,7 +1021,7 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True): kw[key] = type_(kw[key]) -def constructor_copy(obj, cls, **kw): +def constructor_copy(obj, cls, *args, **kw): """Instantiate cls using the __dict__ of obj as constructor arguments. Uses inspect to match the named arguments of ``cls``. @@ -978,7 +1030,7 @@ def constructor_copy(obj, cls, **kw): names = get_cls_kwargs(cls) kw.update((k, obj.__dict__[k]) for k in names if k in obj.__dict__) - return cls(**kw) + return cls(*args, **kw) def counter(): @@ -1205,9 +1257,12 @@ def warn_exception(func, *args, **kwargs): def ellipses_string(value, len_=25): - if len(value) > len_: - return "%s..." % value[0:len_] - else: + try: + if len(value) > len_: + return "%s..." % value[0:len_] + else: + return value + except TypeError: return value @@ -1296,6 +1351,7 @@ def chop_traceback(tb, exclude_prefix=_UNITTEST_RE, exclude_suffix=_SQLA_RE): NoneType = type(None) + def attrsetter(attrname): code = \ "def set(obj, value):"\ @@ -1303,3 +1359,29 @@ def attrsetter(attrname): env = locals().copy() exec(code, env) return env['set'] + + +class EnsureKWArgType(type): + """Apply translation of functions to accept **kw arguments if they + don't already. 
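MemoizedSlots in use; the class below is hypothetical. On first access __getattr__ finds _memoized_attr_length, computes the value, and stores it with setattr, so the computed name needs a slot (or a __dict__) to land in:

    from sqlalchemy.util import MemoizedSlots

    class Point(MemoizedSlots):
        __slots__ = 'x', 'y', 'length'

        def __init__(self, x, y):
            self.x, self.y = x, y

        def _memoized_attr_length(self):
            return (self.x ** 2 + self.y ** 2) ** 0.5

    p = Point(3, 4)
    p.length   # computed once: 5.0
    p.length   # 5.0 again, now read straight from the slot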
+ + """ + def __init__(cls, clsname, bases, clsdict): + fn_reg = cls.ensure_kwarg + if fn_reg: + for key in clsdict: + m = re.match(fn_reg, key) + if m: + fn = clsdict[key] + spec = inspect.getargspec(fn) + if not spec.keywords: + clsdict[key] = wrapped = cls._wrap_w_kw(fn) + setattr(cls, key, wrapped) + super(EnsureKWArgType, cls).__init__(clsname, bases, clsdict) + + def _wrap_w_kw(self, fn): + + def wrap(*arg, **kw): + return fn(*arg) + return update_wrapper(wrap, fn) + diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py index 796c6a33e..29e00a434 100644 --- a/lib/sqlalchemy/util/queue.py +++ b/lib/sqlalchemy/util/queue.py @@ -1,5 +1,5 @@ # util/queue.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py index 2bfcccc63..80735c4df 100644 --- a/lib/sqlalchemy/util/topological.py +++ b/lib/sqlalchemy/util/topological.py @@ -1,5 +1,5 @@ # util/topological.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under |