diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-08-20 21:50:59 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-08-20 21:50:59 +0000 |
commit | 109d5359617bf8c9a8acc6498935a22f9f6949ef (patch) | |
tree | be3bbeb9543852da0e1c9d27ca67876b2a25cf80 | |
parent | 531faf0e187d756bda92a937a77accd86b813339 (diff) | |
download | sqlalchemy-109d5359617bf8c9a8acc6498935a22f9f6949ef.tar.gz |
- method call removal
-rw-r--r-- | lib/sqlalchemy/databases/access.py | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/firebird.py | 14 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/informix.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 33 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 18 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/sqlite.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 46 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 59 | ||||
-rw-r--r-- | lib/sqlalchemy/schema.py | 41 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 98 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 40 | ||||
-rw-r--r-- | lib/sqlalchemy/util.py | 6 | ||||
-rw-r--r-- | test/ext/activemapper.py | 2 | ||||
-rw-r--r-- | test/orm/unitofwork.py | 4 | ||||
-rw-r--r-- | test/sql/labels.py | 4 | ||||
-rw-r--r-- | test/sql/rowcount.py | 6 |
18 files changed, 159 insertions, 259 deletions
diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py index f901ebf53..4994e3309 100644 --- a/lib/sqlalchemy/databases/access.py +++ b/lib/sqlalchemy/databases/access.py @@ -6,7 +6,7 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php import random -from sqlalchemy import sql, schema, ansisql, types, exceptions, pool +from sqlalchemy import sql, schema, types, exceptions, pool from sqlalchemy.sql import compiler import sqlalchemy.engine.default as default @@ -159,7 +159,7 @@ class AccessExecutionContext(default.DefaultExecutionContext): const, daoEngine = None, None -class AccessDialect(ansisql.ANSIDialect): +class AccessDialect(compiler.DefaultDialect): colspecs = { types.Unicode : AcUnicode, types.Integer : AcInteger, @@ -176,6 +176,9 @@ class AccessDialect(ansisql.ANSIDialect): types.TIMESTAMP: AcTimeStamp, } + supports_sane_rowcount = False + + def type_descriptor(self, typeobj): newobj = types.adapt_type(typeobj, self.colspecs) return newobj @@ -211,9 +214,6 @@ class AccessDialect(ansisql.ANSIDialect): def create_execution_context(self, *args, **kwargs): return AccessExecutionContext(self, *args, **kwargs) - def supports_sane_rowcount(self): - return False - def last_inserted_ids(self): return self.context.last_inserted_ids @@ -416,7 +416,7 @@ class AccessSchemaDropper(compiler.SchemaDropper): self.append("\nDROP INDEX [%s].[%s]" % (index.table.name, index.name)) self.execute() -class AccessDefaultRunner(ansisql.ANSIDefaultRunner): +class AccessDefaultRunner(compiler.DefaultRunner): pass class AccessIdentifierPreparer(compiler.IdentifierPreparer): diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 9cccb53e8..2a9bbb5bd 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -101,6 +101,9 @@ class FBExecutionContext(default.DefaultExecutionContext): class FBDialect(default.DefaultDialect): + supports_sane_rowcount = False + max_identifier_length = 31 + def __init__(self, type_conv=200, concurrency_level=1, **kwargs): default.DefaultDialect.__init__(self, **kwargs) @@ -133,12 +136,6 @@ class FBDialect(default.DefaultDialect): def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - def supports_sane_rowcount(self): - return False - - def max_identifier_length(self): - return 31 - def table_names(self, connection, schema): s = "SELECT R.RDB$RELATION_NAME FROM RDB$RELATIONS R" return [row[0] for row in connection.execute(s)] @@ -408,12 +405,11 @@ RESERVED_WORDS = util.Set( class FBIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + def __init__(self, dialect): super(FBIdentifierPreparer,self).__init__(dialect, omit_schema=True) - def _reserved_words(self): - return RESERVED_WORDS - dialect = FBDialect dialect.statement_compiler = FBCompiler diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py index 45dfb0370..67d31387d 100644 --- a/lib/sqlalchemy/databases/informix.py +++ b/lib/sqlalchemy/databases/informix.py @@ -205,6 +205,8 @@ class InfoExecutionContext(default.DefaultExecutionContext): return informix_cursor( self.connection.connection ) class InfoDialect(default.DefaultDialect): + # for informix 7.31 + max_identifier_length = 18 def __init__(self, use_ansi=True,**kwargs): self.use_ansi = use_ansi @@ -216,10 +218,6 @@ class InfoDialect(default.DefaultDialect): return informixdb dbapi = classmethod(dbapi) - def max_identifier_length( self ): - # for informix 7.31 - return 18 - def is_disconnect(self, e): if isinstance(e, self.dbapi.OperationalError): return 'closed the connection' in str(e) or 'connection not open' in str(e) diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 1985d2112..03b276d4a 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -472,10 +472,6 @@ class MSSQLDialect(default.DefaultDialect): def last_inserted_ids(self): return self.context.last_inserted_ids - # this is only implemented in the dbapi-specific subclasses - def supports_sane_rowcount(self): - raise NotImplementedError() - def get_default_schema_name(self, connection): return self.schema_name @@ -665,6 +661,8 @@ class MSSQLDialect(default.DefaultDialect): table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm)) class MSSQLDialect_pymssql(MSSQLDialect): + supports_sane_rowcount = False + def import_dbapi(cls): import pymssql as module # pymmsql doesn't have a Binary method. we use string @@ -683,12 +681,6 @@ class MSSQLDialect_pymssql(MSSQLDialect): super(MSSQLDialect_pymssql, self).__init__(**params) self.use_scope_identity = True - def supports_sane_rowcount(self): - return False - - def max_identifier_length(self): - return 30 - def do_rollback(self, connection): # pymssql throws an error on repeated rollbacks. Ignore it. # TODO: this is normal behavior for most DBs. are we sure we want to ignore it ? @@ -746,6 +738,9 @@ class MSSQLDialect_pymssql(MSSQLDialect): ## r.fetch_array() class MSSQLDialect_pyodbc(MSSQLDialect): + supports_sane_rowcount = False + # PyODBC unicode is broken on UCS-4 builds + supports_unicode_statements = sys.maxunicode == 65535 def __init__(self, **params): super(MSSQLDialect_pyodbc, self).__init__(**params) @@ -771,14 +766,6 @@ class MSSQLDialect_pyodbc(MSSQLDialect): ischema_names['smalldatetime'] = MSDate_pyodbc ischema_names['datetime'] = MSDateTime_pyodbc - def supports_sane_rowcount(self): - return False - - def supports_unicode_statements(self): - """indicate whether the DBAPI can receive SQL statements as Python unicode strings""" - # PyODBC unicode is broken on UCS-4 builds - return sys.maxunicode == 65535 - def make_connect_string(self, keys): if 'dsn' in keys: connectors = ['dsn=%s' % keys['dsn']] @@ -818,6 +805,9 @@ class MSSQLDialect_pyodbc(MSSQLDialect): context._last_inserted_ids = [int(row[0])] class MSSQLDialect_adodbapi(MSSQLDialect): + supports_sane_rowcount = True + supports_unicode_statements = True + def import_dbapi(cls): import adodbapi as module return module @@ -831,13 +821,6 @@ class MSSQLDialect_adodbapi(MSSQLDialect): ischema_names['nvarchar'] = AdoMSNVarchar ischema_names['datetime'] = MSDateTime_adodbapi - def supports_sane_rowcount(self): - return True - - def supports_unicode_statements(self): - """indicate whether the DBAPI can receive SQL statements as Python unicode strings""" - return True - def make_connect_string(self, keys): connectors = ["Provider=SQLOLEDB"] if 'port' in keys: diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 41c6ec70f..6dc0d6057 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -1332,6 +1332,12 @@ class MySQLExecutionContext(default.DefaultExecutionContext): class MySQLDialect(default.DefaultDialect): """Details of the MySQL dialect. Not used directly in application code.""" + supports_alter = True + supports_unicode_statements = False + # identifiers are 64, however aliases can be 255... + max_identifier_length = 255 + supports_sane_rowcount = True + def __init__(self, use_ansiquotes=False, **kwargs): self.use_ansiquotes = use_ansiquotes kwargs.setdefault('default_paramstyle', 'format') @@ -1390,13 +1396,6 @@ class MySQLDialect(default.DefaultDialect): def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - # identifiers are 64, however aliases can be 255... - def max_identifier_length(self): - return 255; - - def supports_sane_rowcount(self): - return True - def compiler(self, statement, bindparams, **kwargs): return MySQLCompiler(statement, bindparams, dialect=self, **kwargs) @@ -2369,13 +2368,12 @@ MySQLSchemaReflector.logger = logging.class_logger(MySQLSchemaReflector) class _MySQLIdentifierPreparer(compiler.IdentifierPreparer): """MySQL-specific schema identifier configuration.""" + + reserved_words = RESERVED_WORDS def __init__(self, dialect, **kw): super(_MySQLIdentifierPreparer, self).__init__(dialect, **kw) - def _reserved_words(self): - return RESERVED_WORDS - def _fold_identifier_case(self, value): # TODO: determine MySQL's case folding rules # diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 580850818..9b3ffbf23 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -232,6 +232,11 @@ class OracleExecutionContext(default.DefaultExecutionContext): return base.ResultProxy(self) class OracleDialect(default.DefaultDialect): + supports_alter = True + supports_unicode_statements = False + max_identifier_length = 30 + supports_sane_rowcount = True + def __init__(self, use_ansi=True, auto_setinputsizes=True, auto_convert_lobs=True, threaded=True, allow_twophase=True, **kwargs): default.DefaultDialect.__init__(self, default_paramstyle='named', **kwargs) self.use_ansi = use_ansi @@ -291,13 +296,6 @@ class OracleDialect(default.DefaultDialect): def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - def supports_unicode_statements(self): - """indicate whether the DB-API can receive SQL statements as Python unicode strings""" - return False - - def max_identifier_length(self): - return 30 - def oid_column_name(self, column): if not isinstance(column.table, (sql.TableClause, sql.Select)): return None diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 29d84ad4d..2a4d230cd 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -223,6 +223,11 @@ class PGExecutionContext(default.DefaultExecutionContext): super(PGExecutionContext, self).post_exec() class PGDialect(default.DefaultDialect): + supports_alter = True + supports_unicode_statements = False + max_identifier_length = 63 + supports_sane_rowcount = True + def __init__(self, use_oids=False, server_side_cursors=False, **kwargs): default.DefaultDialect.__init__(self, default_paramstyle='pyformat', **kwargs) self.use_oids = use_oids @@ -241,13 +246,9 @@ class PGDialect(default.DefaultDialect): opts.update(url.query) return ([], opts) - def create_execution_context(self, *args, **kwargs): return PGExecutionContext(self, *args, **kwargs) - def max_identifier_length(self): - return 63 - def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index c2aced4d0..8618bfc3e 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -174,6 +174,8 @@ class SQLiteExecutionContext(default.DefaultExecutionContext): return SELECT_REGEXP.match(self.statement) class SQLiteDialect(default.DefaultDialect): + supports_alter = False + supports_unicode_statements = True def __init__(self, **kwargs): default.DefaultDialect.__init__(self, default_paramstyle='qmark', **kwargs) @@ -199,9 +201,6 @@ class SQLiteDialect(default.DefaultDialect): def server_version_info(self, connection): return self.dbapi.sqlite_version_info - def supports_alter(self): - return False - def create_connect_args(self, url): filename = url.database or ':memory:' @@ -220,9 +219,6 @@ class SQLiteDialect(default.DefaultDialect): def create_execution_context(self, **kwargs): return SQLiteExecutionContext(self, **kwargs) - def supports_unicode_statements(self): - return True - def last_inserted_ids(self): return self.context.last_inserted_ids diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 2e75d358c..ef875a638 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -62,6 +62,19 @@ class Dialect(object): preparer a [sqlalchemy.sql.compiler#IdentifierPreparer] class used to quote identifiers. + + supports_alter + ``True`` if the database supports ``ALTER TABLE``. + + max_identifier_length + The maximum length of identifier names. + + supports_unicode_statements + Indicate whether the DB-API can receive SQL statements as Python unicode strings + + supports_sane_rowcount + Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements. + """ def create_connect_args(self, url): @@ -119,31 +132,6 @@ class Dialect(object): raise NotImplementedError() - def supports_alter(self): - """Return ``True`` if the database supports ``ALTER TABLE``.""" - raise NotImplementedError() - - def max_identifier_length(self): - """Return the maximum length of identifier names. - - Returns ``None`` if no limit. - """ - - return None - - def supports_unicode_statements(self): - """Indicate whether the DB-API can receive SQL statements as Python unicode strings""" - - raise NotImplementedError() - - def supports_sane_rowcount(self): - """Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements. - - This was needed for MySQL which had non-standard behavior of rowcount, - but this issue has since been resolved. - """ - - raise NotImplementedError() def server_version_info(self, connection): @@ -521,9 +509,6 @@ class Connectable(object): def execute(self, object, *multiparams, **params): raise NotImplementedError() - engine = util.NotImplProperty("The Engine which this Connectable is associated with.") - dialect = util.NotImplProperty("Dialect which this Connectable is associated with.") - class Connection(Connectable): """Provides high-level functionality for a wrapped DB-API connection. @@ -1020,14 +1005,13 @@ class Engine(Connectable): def __init__(self, pool, dialect, url, echo=None): self.pool = pool self.url = url - self._dialect=dialect + self.dialect=dialect self.echo = echo + self.engine = self self.logger = logging.instance_logger(self) self._should_log = logging.is_info_enabled(self.logger) name = property(lambda s:sys.modules[s.dialect.__module__].descriptor()['name'], doc="String name of the [sqlalchemy.engine#Dialect] in use by this ``Engine``.") - engine = property(lambda s:s) - dialect = property(lambda s:s._dialect, doc="the [sqlalchemy.engine#Dialect] in use by this engine.") echo = logging.echo_property() def __repr__(self): diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index ec8d8d5a7..50fac430b 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -26,6 +26,10 @@ class DefaultDialect(base.Dialect): statement_compiler = compiler.DefaultCompiler preparer = compiler.IdentifierPreparer defaultrunner = base.DefaultRunner + supports_alter = True + supports_unicode_statements = False + max_identifier_length = 9999 + supports_sane_rowcount = True def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', paramstyle=None, dbapi=None, **kwargs): self.convert_unicode = convert_unicode @@ -33,7 +37,13 @@ class DefaultDialect(base.Dialect): self.positional = False self._ischema = None self.dbapi = dbapi - self._figure_paramstyle(paramstyle=paramstyle, default=default_paramstyle) + if paramstyle is not None: + self.paramstyle = paramstyle + elif self.dbapi is not None: + self.paramstyle = self.dbapi.paramstyle + else: + self.paramstyle = default_paramstyle + self.positional = self.paramstyle in ('qmark', 'format', 'numeric') self.identifier_preparer = self.preparer(self) def dbapi_type_map(self): @@ -56,23 +66,10 @@ class DefaultDialect(base.Dialect): typeobj = typeobj() return typeobj - def supports_unicode_statements(self): - """True if DB-API can receive SQL statements as Python Unicode.""" - return False - - def max_identifier_length(self): - # TODO: probably raise this and fill out db modules better - return 9999 - - def supports_alter(self): - return True def oid_column_name(self, column): return None - def supports_sane_rowcount(self): - return True - def do_begin(self, connection): """Implementations might want to put logic here for turning autocommit on/off, etc. @@ -120,32 +117,6 @@ class DefaultDialect(base.Dialect): def is_disconnect(self, e): return False - def _set_paramstyle(self, style): - self._paramstyle = style - self._figure_paramstyle(style) - - paramstyle = property(lambda s:s._paramstyle, _set_paramstyle) - - def _figure_paramstyle(self, paramstyle=None, default='named'): - if paramstyle is not None: - self._paramstyle = paramstyle - elif self.dbapi is not None: - self._paramstyle = self.dbapi.paramstyle - else: - self._paramstyle = default - - if self._paramstyle == 'named': - self.positional=False - elif self._paramstyle == 'pyformat': - self.positional=False - elif self._paramstyle == 'qmark' or self._paramstyle == 'format' or self._paramstyle == 'numeric': - # for positional, use pyformat internally, ANSICompiler will convert - # to appropriate character upon compilation - self.positional = True - else: - raise exceptions.InvalidRequestError( - "Unsupported paramstyle '%s'" % self._paramstyle) - def _get_ischema(self): if self._ischema is None: import sqlalchemy.databases.information_schema as ischema @@ -185,7 +156,7 @@ class DefaultExecutionContext(base.ExecutionContext): else: self.statement = None - if self.statement is not None and not dialect.supports_unicode_statements(): + if self.statement is not None and not dialect.supports_unicode_statements: self.statement = self.statement.encode(self.dialect.encoding) self.cursor = self.create_cursor() @@ -200,7 +171,7 @@ class DefaultExecutionContext(base.ExecutionContext): def __encode_param_keys(self, params): """apply string encoding to the keys of dictionary-based bind parameters""" - if self.dialect.positional or self.dialect.supports_unicode_statements(): + if self.dialect.positional or self.dialect.supports_unicode_statements: return params else: def proc(d): @@ -215,7 +186,7 @@ class DefaultExecutionContext(base.ExecutionContext): return proc(params) def __convert_compiled_params(self, parameters): - encode = not self.dialect.supports_unicode_statements() + encode = not self.dialect.supports_unicode_statements # the bind params are a CompiledParams object. but all the # DB-API's hate that object (or similar). so convert it to a # clean dictionary/list/tuple of dictionary/tuple of list @@ -274,7 +245,7 @@ class DefaultExecutionContext(base.ExecutionContext): return self.cursor.rowcount def supports_sane_rowcount(self): - return self.dialect.supports_sane_rowcount() + return self.dialect.supports_sane_rowcount def last_inserted_ids(self): return self._last_inserted_ids diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 99ca2389b..b6f345be2 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -58,27 +58,21 @@ class SchemaItem(object): def __repr__(self): return "%s()" % self.__class__.__name__ - def _derived_metadata(self): - """Return the the MetaData to which this item is bound.""" - - return None - def _get_bind(self, raiseerr=False): """Return the engine or None if no engine.""" if raiseerr: - m = self._derived_metadata() + m = self.metadata e = m and m.bind or None if e is None: raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.") else: return e else: - m = self._derived_metadata() + m = self.metadata return m and m.bind or None - metadata = property(lambda s:s._derived_metadata()) bind = property(lambda s:s._get_bind()) def _get_table_key(name, schema): @@ -228,7 +222,7 @@ class Table(SchemaItem, expression.TableClause): """ super(Table, self).__init__(name) - self._metadata = metadata + self.metadata = metadata self.schema = kwargs.pop('schema', None) self.indexes = util.Set() self.constraints = util.Set() @@ -263,9 +257,6 @@ class Table(SchemaItem, expression.TableClause): self.constraints.add(pk) primary_key = property(lambda s:s._primary_key, _set_primary_key) - def _derived_metadata(self): - return self._metadata - def __repr__(self): return "Table(%s)" % ', '.join( [repr(self.name)] + [repr(self.metadata)] + @@ -286,11 +277,11 @@ class Table(SchemaItem, expression.TableClause): constraint._set_parent(self) def _get_parent(self): - return self._metadata + return self.metadata def _set_parent(self, metadata): metadata.tables[_get_table_key(self.name, self.schema)] = self - self._metadata = metadata + self.metadata = metadata def get_children(self, column_collections=True, schema_visitor=False, **kwargs): if not schema_visitor: @@ -476,9 +467,6 @@ class Column(SchemaItem, expression._ColumnClause): else: return self.encodedname - def _derived_metadata(self): - return self.table.metadata - def _get_bind(self): return self.table.bind @@ -515,6 +503,7 @@ class Column(SchemaItem, expression._ColumnClause): return self.table def _set_parent(self, table): + self.metadata = table.metadata if getattr(self, 'table', None) is not None: raise exceptions.ArgumentError("this Column already has a table!") if not self._is_oid: @@ -699,20 +688,14 @@ class DefaultGenerator(SchemaItem): def __init__(self, for_update=False, metadata=None): self.for_update = for_update - self._metadata = util.assert_arg_type(metadata, (MetaData, type(None)), 'metadata') - - def _derived_metadata(self): - try: - return self.column.table.metadata - except AttributeError: - return self._metadata + self.metadata = util.assert_arg_type(metadata, (MetaData, type(None)), 'metadata') def _get_parent(self): return getattr(self, 'column', None) def _set_parent(self, column): self.column = column - self._metadata = self.column.table.metadata + self.metadata = self.column.table.metadata if self.for_update: self.column.onupdate = self else: @@ -957,9 +940,6 @@ class Index(SchemaItem): self.unique = kwargs.pop('unique', False) self._init_items(*columns) - def _derived_metadata(self): - return self.table.metadata - def _init_items(self, *args): for column in args: self.append_column(column) @@ -969,6 +949,7 @@ class Index(SchemaItem): def _set_parent(self, table): self.table = table + self.metadata = table.metadata table.indexes.add(self) def append_column(self, column): @@ -1053,6 +1034,7 @@ class MetaData(SchemaItem): self.tables = {} self.bind = bind + self.metadata = self if reflect: if not bind: raise exceptions.ArgumentError( @@ -1239,9 +1221,6 @@ class MetaData(SchemaItem): bind = self._get_bind(raiseerr=True) bind.drop(self, checkfirst=checkfirst, tables=tables) - def _derived_metadata(self): - return self - def _get_bind(self, raiseerr=False): if not self.is_bound(): if raiseerr: diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 59eb3cdb3..59964178c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -421,9 +421,9 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): anonname = ANONYMOUS_LABEL.sub(self._process_anon, name) - if len(anonname) > self.dialect.max_identifier_length(): + if len(anonname) > self.dialect.max_identifier_length: counter = self.generated_ids.get(ident_class, 1) - truncname = name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(counter)[2:] + truncname = name[0:self.dialect.max_identifier_length - 6] + "_" + hex(counter)[2:] self.generated_ids[ident_class] = counter + 1 else: truncname = anonname @@ -515,7 +515,6 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): l = co.label(labelname) inner_columns.add(self.process(l)) else: - self.traverse(co) inner_columns.add(self.process(co)) else: l = self.label_select_column(select, co) @@ -620,20 +619,16 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): # for inserts, this includes Python-side defaults, columns with sequences for dialects # that support sequences, and primary key columns for dialects that explicitly insert # pre-generated primary key values - required_cols = util.Set() - class DefaultVisitor(schema.SchemaVisitor): - def visit_column(s, cd): - if c.primary_key and self.uses_sequences_for_inserts(): - required_cols.add(c) - def visit_column_default(s, cd): - required_cols.add(c) - def visit_sequence(s, seq): - if self.uses_sequences_for_inserts(): - required_cols.add(c) - vis = DefaultVisitor() - for c in insert_stmt.table.c: - if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): - vis.traverse(c) + required_cols = [ + c for c in insert_stmt.table.c + if \ + isinstance(c, schema.SchemaItem) and \ + (self.parameters is None or self.parameters.get(c.key, None) is None) and \ + ( + ((c.primary_key or isinstance(c.default, schema.Sequence)) and self.uses_sequences_for_inserts()) or + isinstance(c.default, schema.ColumnDefault) + ) + ] self.isinsert = True colparams = self._get_colparams(insert_stmt, required_cols) @@ -646,14 +641,12 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): # search for columns who will be required to have an explicit bound value. # for updates, this includes Python-side "onupdate" defaults. - required_cols = util.Set() - class OnUpdateVisitor(schema.SchemaVisitor): - def visit_column_onupdate(s, cd): - required_cols.add(c) - vis = OnUpdateVisitor() - for c in update_stmt.table.c: - if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): - vis.traverse(c) + required_cols = [c for c in update_stmt.table.c + if + isinstance(c, schema.SchemaItem) and \ + (self.parameters is None or self.parameters.get(c.key, None) is None) and + isinstance(c.onupdate, schema.ColumnDefault) + ] self.isupdate = True colparams = self._get_colparams(update_stmt, required_cols) @@ -681,11 +674,6 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): self.binds[col.key] = bindparam return self.bindparam_string(self._truncate_bindparam(bindparam)) - def create_clause_param(col, value): - self.traverse(value) - self.inline_params.add(col) - return self.process(value) - self.inline_params = util.Set() def to_col(key): @@ -704,25 +692,28 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): if self.parameters is None: parameters = {} else: - parameters = dict([(to_col(k), v) for k, v in self.parameters.iteritems()]) + parameters = dict([(getattr(k, 'key', k), v) for k, v in self.parameters.iteritems()]) if stmt.parameters is not None: for k, v in stmt.parameters.iteritems(): - parameters.setdefault(to_col(k), v) + parameters.setdefault(getattr(k, 'key', k), v) for col in required_cols: - parameters.setdefault(col, None) + parameters.setdefault(col.key, None) # create a list of column assignment clauses as tuples values = [] for c in stmt.table.columns: - if c in parameters: - value = parameters[c] - if sql._is_literal(value): - value = create_bind_param(c, value) - else: - value = create_clause_param(c, value) - values.append((c, value)) + if c.key in parameters: + value = parameters[c.key] + else: + continue + if sql._is_literal(value): + value = create_bind_param(c, value) + else: + self.inline_params.add(c) + value = self.process(value) + values.append((c, value)) return values @@ -778,7 +769,7 @@ class SchemaGenerator(DDLBase): collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))] for table in collection: self.traverse_single(table) - if self.dialect.supports_alter(): + if self.dialect.supports_alter: for alterable in self.find_alterables(collection): self.add_foreignkey(alterable) @@ -853,7 +844,7 @@ class SchemaGenerator(DDLBase): self.append("(%s)" % ', '.join([self.preparer.format_column(c) for c in constraint])) def visit_foreign_key_constraint(self, constraint): - if constraint.use_alter and self.dialect.supports_alter(): + if constraint.use_alter and self.dialect.supports_alter: return self.append(", \n\t ") self.define_foreign_key(constraint) @@ -909,7 +900,7 @@ class SchemaDropper(DDLBase): def visit_metadata(self, metadata): collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or self.dialect.has_table(self.connection, t.name, schema=t.schema))] - if self.dialect.supports_alter(): + if self.dialect.supports_alter: for alterable in self.find_alterables(collection): self.drop_foreignkey(alterable) for table in collection: @@ -936,6 +927,12 @@ class SchemaDropper(DDLBase): class IdentifierPreparer(object): """Handle quoting and case-folding of identifiers based on options.""" + reserved_words = RESERVED_WORDS + + legal_characters = LEGAL_CHARACTERS + + illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS + def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False): """Construct a new ``IdentifierPreparer`` object. @@ -995,21 +992,12 @@ class IdentifierPreparer(object): # some tests would need to be rewritten if this is done. #return value.upper() - def _reserved_words(self): - return RESERVED_WORDS - - def _legal_characters(self): - return LEGAL_CHARACTERS - - def _illegal_initial_characters(self): - return ILLEGAL_INITIAL_CHARACTERS - def _requires_quotes(self, value): """Return True if the given identifier requires quoting.""" return \ - value in self._reserved_words() \ - or (value[0] in self._illegal_initial_characters()) \ - or bool(len([x for x in unicode(value) if x not in self._legal_characters()])) \ + value in self.reserved_words \ + or (value[0] in self.illegal_initial_characters) \ + or bool(len([x for x in unicode(value) if x not in self.legal_characters])) \ or (value.lower() != value) def __generic_obj_format(self, obj, ident): diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 64fc9b3b4..ea87a8c4f 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1007,7 +1007,7 @@ class ClauseElement(object): compiler = dialect.statement_compiler(dialect, self, parameters=parameters) compiler.compile() return compiler - + def __str__(self): return unicode(self.compile()).encode('ascii', 'backslashreplace') @@ -1618,13 +1618,6 @@ class FromClause(Selectable): else: raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), self.name)) - def _get_exported_attribute(self, name): - try: - return getattr(self, name) - except AttributeError: - self._export_columns() - return getattr(self, name) - def _clone_from_clause(self): # delete all the "generated" collections of columns for a # newly cloned FromClause, so that they will be re-derived @@ -1635,11 +1628,20 @@ class FromClause(Selectable): if hasattr(self, attr): delattr(self, attr) - columns = property(lambda s:s._get_exported_attribute('_columns')) - c = property(lambda s:s._get_exported_attribute('_columns')) - primary_key = property(lambda s:s._get_exported_attribute('_primary_key')) - foreign_keys = property(lambda s:s._get_exported_attribute('_foreign_keys')) - original_columns = property(lambda s:s._get_exported_attribute('_orig_cols'), doc=\ + def _expr_attr_func(name): + def attr(self): + try: + return getattr(self, name) + except AttributeError: + self._export_columns() + return getattr(self, name) + return attr + + columns = property(_expr_attr_func('_columns')) + c = property(_expr_attr_func('_columns')) + primary_key = property(_expr_attr_func('_primary_key')) + foreign_keys = property(_expr_attr_func('_foreign_keys')) + original_columns = property(_expr_attr_func('_orig_cols'), doc=\ """A dictionary mapping an original Table-bound column to a proxied column in this FromClause. """) @@ -1659,7 +1661,6 @@ class FromClause(Selectable): """ if hasattr(self, '_columns') and columns is None: - # TODO: put a mutex here ? this is a key place for threading probs return self._columns = ColumnCollection() self._primary_key = ColumnSet() @@ -1753,9 +1754,11 @@ class _BindParamClause(ClauseElement, _CompareMixin): self.shortname = shortname or key self.unique = unique self.isoutparam = isoutparam - type_ = sqltypes.to_instance(type_) - if isinstance(type_, sqltypes.NullType) and type(value) in _BindParamClause.type_map: - self.type = sqltypes.to_instance(_BindParamClause.type_map[type(value)]) + + if type_ is None: + self.type = self.type_map.get(type(value), sqltypes.NullType)() + elif isinstance(type_, type): + self.type = type_() else: self.type = type_ @@ -1764,7 +1767,8 @@ class _BindParamClause(ClauseElement, _CompareMixin): str : sqltypes.String, unicode : sqltypes.Unicode, int : sqltypes.Integer, - float : sqltypes.Numeric + float : sqltypes.Numeric, + type(None):sqltypes.NullType } def _get_from_objects(self, **modifiers): diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index ba6458f2a..d31be6a36 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -268,7 +268,11 @@ class OrderedProperties(object): def __setattr__(self, key, object): self._data[key] = object - _data = property(lambda s:s.__dict__['_data']) + def __getstate__(self): + return self._data + + def __setstate__(self, value): + self.__dict__['_data'] = value def __getattr__(self, key): try: diff --git a/test/ext/activemapper.py b/test/ext/activemapper.py index e28c72cd7..7e266030c 100644 --- a/test/ext/activemapper.py +++ b/test/ext/activemapper.py @@ -175,7 +175,7 @@ class testcase(PersistTest): objectstore.context.current = s1 objectstore.flush() # Only dialects with a sane rowcount can detect the ConcurrentModificationError - if testbase.db.dialect.supports_sane_rowcount(): + if testbase.db.dialect.supports_sane_rowcount: assert False except exceptions.ConcurrentModificationError: pass diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index fd7af0421..d689f1703 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -78,7 +78,7 @@ class VersioningTest(ORMTest): success = True # Only dialects with a sane rowcount can detect the ConcurrentModificationError - if testbase.db.dialect.supports_sane_rowcount(): + if testbase.db.dialect.supports_sane_rowcount: assert success s.close() @@ -96,7 +96,7 @@ class VersioningTest(ORMTest): except exceptions.ConcurrentModificationError, e: #print e success = True - if testbase.db.dialect.supports_sane_rowcount(): + if testbase.db.dialect.supports_sane_rowcount: assert success @engines.close_open_connections diff --git a/test/sql/labels.py b/test/sql/labels.py index dee76428d..6588c4da4 100644 --- a/test/sql/labels.py +++ b/test/sql/labels.py @@ -27,7 +27,7 @@ class LongLabelsTest(PersistTest): metadata.create_all() maxlen = testbase.db.dialect.max_identifier_length - testbase.db.dialect.max_identifier_length = lambda: 29 + testbase.db.dialect.max_identifier_length = 29 def tearDown(self): table1.delete().execute() @@ -89,7 +89,7 @@ class LongLabelsTest(PersistTest): """test that a primary key column compiled as the 'oid' column gets proper length truncation""" from sqlalchemy.databases import postgres dialect = postgres.PGDialect() - dialect.max_identifier_length = lambda: 30 + dialect.max_identifier_length = 30 tt = table1.select(use_labels=True).alias('foo') x = select([tt], use_labels=True, order_by=tt.oid_column).compile(dialect=dialect) #print x diff --git a/test/sql/rowcount.py b/test/sql/rowcount.py index cf9ba30d9..095f79200 100644 --- a/test/sql/rowcount.py +++ b/test/sql/rowcount.py @@ -47,21 +47,21 @@ class FoundRowsTest(AssertMixin): # WHERE matches 3, 3 rows changed department = employees_table.c.department r = employees_table.update(department=='C').execute(department='Z') - if testbase.db.dialect.supports_sane_rowcount(): + if testbase.db.dialect.supports_sane_rowcount: assert r.rowcount == 3 def test_update_rowcount2(self): # WHERE matches 3, 0 rows changed department = employees_table.c.department r = employees_table.update(department=='C').execute(department='C') - if testbase.db.dialect.supports_sane_rowcount(): + if testbase.db.dialect.supports_sane_rowcount: assert r.rowcount == 3 def test_delete_rowcount(self): # WHERE matches 3, 3 rows deleted department = employees_table.c.department r = employees_table.delete(department=='C').execute() - if testbase.db.dialect.supports_sane_rowcount(): + if testbase.db.dialect.supports_sane_rowcount: assert r.rowcount == 3 if __name__ == '__main__': |