diff options
-rw-r--r-- | CHANGES | 15 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 7 | ||||
-rw-r--r-- | lib/sqlalchemy/exc.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/query.py | 24 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 39 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 16 | ||||
-rw-r--r-- | test/sql/labels.py | 9 |
7 files changed, 85 insertions, 27 deletions
@@ -101,11 +101,22 @@ CHANGES a transaction is in progress [ticket:976]. This flag is always True with a "transactional" (in 0.5 a non-"autocommit") Session. - + +- schema + - create_all(), drop_all(), create(), drop() all raise + an error if the table name or schema name contains + more characters than that dialect's configured + character limit. Some DB's can handle too-long + table names during usage, and SQLA can handle this + as well. But various reflection/ + checkfirst-during-create scenarios fail since we are + looking for the name within the DB's catalog tables. + [ticket:571] + - postgres - Repaired server_side_cursors to properly detect text() clauses. - + - mysql - Added 'CALL' to the list of SQL keywords which return result rows. diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index b8578151b..dcbf8c76f 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -15,7 +15,7 @@ as the base class for their own corresponding classes. import re, random from sqlalchemy.engine import base from sqlalchemy.sql import compiler, expression - +from sqlalchemy import exc AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', re.I | re.UNICODE) @@ -70,7 +70,10 @@ class DefaultDialect(base.Dialect): typeobj = typeobj() return typeobj - + def validate_identifier(self, ident): + if len(ident) > self.max_identifier_length: + raise exc.IdentifierError("Identifier '%s' exceeds maximum length of %d characters" % (ident, self.max_identifier_length)) + def oid_column_name(self, column): return None diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index 71b46ca11..e0eb7d88c 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -32,6 +32,8 @@ class CircularDependencyError(SQLAlchemyError): class CompileError(SQLAlchemyError): """Raised when an error occurs during SQL compilation""" +class IdentifierError(SQLAlchemyError): + """Raised when a schema name is beyond the max character limit""" # Moved to orm.exc; compatability definition installed by orm import until 0.6 ConcurrentModificationError = None diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index c5cd0640d..51c57ca30 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1391,8 +1391,7 @@ class Query(object): statement.append_from(from_clause) if context.order_by: - local_adapter = sql_util.ClauseAdapter(inner) - statement.append_order_by(*local_adapter.copy_and_process(context.order_by)) + statement.append_order_by(*context.adapter.copy_and_process(context.order_by)) statement.append_order_by(*context.eager_order_by) else: @@ -1580,7 +1579,14 @@ class _MapperEntity(_QueryEntity): for value in self.mapper._iterate_polymorphic_properties(self._with_polymorphic): if query._only_load_props and value.key not in query._only_load_props: continue - value.setup(context, self, (self.path_entity,), adapter, only_load_props=query._only_load_props, column_collection=context.primary_columns) + value.setup( + context, + self, + (self.path_entity,), + adapter, + only_load_props=query._only_load_props, + column_collection=context.primary_columns + ) def __str__(self): return str(self.mapper) @@ -1610,7 +1616,11 @@ class _ColumnEntity(_QueryEntity): self.column = column self.entity_name = None self.froms = util.Set() - self.entities = util.OrderedSet([elem._annotations['parententity'] for elem in visitors.iterate(column, {}) if 'parententity' in elem._annotations]) + self.entities = util.OrderedSet([ + elem._annotations['parententity'] for elem in visitors.iterate(column, {}) + if 'parententity' in elem._annotations + ]) + if self.entities: self.entity_zero = list(self.entities)[0] else: @@ -1620,11 +1630,11 @@ class _ColumnEntity(_QueryEntity): self.selectable = from_obj self.froms.add(from_obj) - def __resolve_expr_against_query_aliases(self, query, expr, context): + def _resolve_expr_against_query_aliases(self, query, expr, context): return query._adapt_clause(expr, False, True) def row_processor(self, query, context, custom_rows): - column = self.__resolve_expr_against_query_aliases(query, self.column, context) + column = self._resolve_expr_against_query_aliases(query, self.column, context) if context.adapter: column = context.adapter.columns[column] @@ -1635,7 +1645,7 @@ class _ColumnEntity(_QueryEntity): return (proc, getattr(column, 'name', None)) def setup_context(self, query, context): - column = self.__resolve_expr_against_query_aliases(query, self.column, context) + column = self._resolve_expr_against_query_aliases(query, self.column, context) context.froms += list(self.froms) context.primary_columns.append(column) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 66e9ccd97..fcb56865b 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -217,7 +217,11 @@ class LoadDeferredColumns(object): self.keys = keys def __getstate__(self): - return {'state':self.state, 'key':self.key, 'keys':self.keys} + return { + 'state':self.state, + 'key':self.key, + 'keys':self.keys + } def __setstate__(self, state): self.state = state['state'] @@ -330,7 +334,7 @@ NoLoader.logger = log.class_logger(NoLoader) class LazyLoader(AbstractRelationLoader): def init(self): super(LazyLoader, self).init() - (self.__lazywhere, self.__bind_to_col, self._equated_columns) = self.__create_lazy_clause(self.parent_property) + (self.__lazywhere, self.__bind_to_col, self._equated_columns) = self._create_lazy_clause(self.parent_property) self.logger.info("%s lazy loading clause %s" % (self, self.__lazywhere)) @@ -352,7 +356,7 @@ class LazyLoader(AbstractRelationLoader): if not reverse_direction: (criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns) else: - (criterion, bind_to_col, rev) = LazyLoader.__create_lazy_clause(self.parent_property, reverse_direction=reverse_direction) + (criterion, bind_to_col, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction) def visit_bindparam(bindparam): mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent @@ -371,7 +375,7 @@ class LazyLoader(AbstractRelationLoader): if not reverse_direction: (criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns) else: - (criterion, bind_to_col, rev) = LazyLoader.__create_lazy_clause(self.parent_property, reverse_direction=reverse_direction) + (criterion, bind_to_col, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction) def visit_binary(binary): mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent @@ -434,7 +438,7 @@ class LazyLoader(AbstractRelationLoader): return (new_execute, None) - def __create_lazy_clause(cls, prop, reverse_direction=False): + def _create_lazy_clause(cls, prop, reverse_direction=False): binds = {} lookup = {} equated_columns = {} @@ -474,7 +478,7 @@ class LazyLoader(AbstractRelationLoader): bind_to_col = dict([(binds[col].key, col) for col in binds]) return (lazywhere, bind_to_col, equated_columns) - __create_lazy_clause = classmethod(__create_lazy_clause) + _create_lazy_clause = classmethod(_create_lazy_clause) LazyLoader.logger = log.class_logger(LazyLoader) @@ -488,7 +492,12 @@ class LoadLazyAttribute(object): self.path = path def __getstate__(self): - return {'state':self.state, 'key':self.key, 'options':self.options, 'path':serialize_path(self.path)} + return { + 'state':self.state, + 'key':self.key, + 'options':self.options, + 'path':serialize_path(self.path) + } def __setstate__(self, state): self.state = state['state'] @@ -510,7 +519,11 @@ class LoadLazyAttribute(object): session = sessionlib._state_session(state) if session is None: - raise sa_exc.UnboundExecutionError("Parent instance %s is not bound to a Session; lazy load operation of attribute '%s' cannot proceed" % (mapperutil.state_str(state), self.key)) + raise sa_exc.UnboundExecutionError( + "Parent instance %s is not bound to a Session; " + "lazy load operation of attribute '%s' cannot proceed" % + (mapperutil.state_str(state), self.key) + ) q = session.query(prop.mapper).autoflush(False)._adapt_all_clauses() @@ -547,7 +560,6 @@ class LoadLazyAttribute(object): return result[0] else: return None - class EagerLoader(AbstractRelationLoader): """Loads related objects inline with a parent query.""" @@ -576,8 +588,7 @@ class EagerLoader(AbstractRelationLoader): context.attributes[("eager_row_processor", path)] = clauses = adapter else: - - clauses = self.__create_eager_join(context, entity, path, adapter, parentmapper) + clauses = self._create_eager_join(context, entity, path, adapter, parentmapper) if not clauses: return @@ -586,7 +597,7 @@ class EagerLoader(AbstractRelationLoader): for value in self.mapper._iterate_polymorphic_properties(): value.setup(context, entity, path + (self.mapper.base_mapper,), clauses, parentmapper=self.mapper, column_collection=context.secondary_columns) - def __create_eager_join(self, context, entity, path, adapter, parentmapper): + def _create_eager_join(self, context, entity, path, adapter, parentmapper): # check for join_depth or basic recursion, # if the current path was not explicitly stated as # a desired "loaderstrategy" (i.e. via query.options()) @@ -662,7 +673,7 @@ class EagerLoader(AbstractRelationLoader): return clauses - def __create_eager_adapter(self, context, row, adapter, path): + def _create_eager_adapter(self, context, row, adapter, path): if ("eager_row_processor", path) in context.attributes: decorator = context.attributes[("eager_row_processor", path)] else: @@ -682,7 +693,7 @@ class EagerLoader(AbstractRelationLoader): def create_row_processor(self, context, path, mapper, row, adapter): path = path + (self.key,) - eager_adapter = self.__create_eager_adapter(context, row, adapter, path) + eager_adapter = self._create_eager_adapter(context, row, adapter, path) if eager_adapter is not False: key = self.key diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8c8374b9a..b57fd3b18 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -753,8 +753,14 @@ class SchemaGenerator(DDLBase): def get_column_specification(self, column, first_pk=False): raise NotImplementedError() + def _can_create(self, table): + self.dialect.validate_identifier(table.name) + if table.schema: + self.dialect.validate_identifier(table.schema) + return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema) + def visit_metadata(self, metadata): - collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))] + collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if self._can_create(t)] for table in collection: self.traverse_single(table) if self.dialect.supports_alter: @@ -910,13 +916,19 @@ class SchemaDropper(DDLBase): self.dialect = dialect def visit_metadata(self, metadata): - collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or self.dialect.has_table(self.connection, t.name, schema=t.schema))] + collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if self._can_drop(t)] if self.dialect.supports_alter: for alterable in self.find_alterables(collection): self.drop_foreignkey(alterable) for table in collection: self.traverse_single(table) + def _can_drop(self, table): + self.dialect.validate_identifier(table.name) + if table.schema: + self.dialect.validate_identifier(table.schema) + return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema) + def visit_index(self, index): self.append("\nDROP INDEX " + self.preparer.format_index(index)) self.execute() diff --git a/test/sql/labels.py b/test/sql/labels.py index 78b31adc4..3e025e5e7 100644 --- a/test/sql/labels.py +++ b/test/sql/labels.py @@ -1,5 +1,6 @@ import testenv; testenv.configure_for_tests() from sqlalchemy import * +from sqlalchemy import exc as exceptions from testlib import * from sqlalchemy.engine import default @@ -38,6 +39,14 @@ class LongLabelsTest(TestBase, AssertsCompiledSQL): metadata.drop_all() testing.db.dialect.max_identifier_length = maxlen + def test_too_long_name_disallowed(self): + m = MetaData(testing.db) + t1 = Table("this_name_is_too_long_for_what_were_doing_in_this_test", m, Column('foo', Integer)) + self.assertRaises(exceptions.IdentifierError, m.create_all) + self.assertRaises(exceptions.IdentifierError, m.drop_all) + self.assertRaises(exceptions.IdentifierError, t1.create) + self.assertRaises(exceptions.IdentifierError, t1.drop) + def test_result(self): table1.insert().execute(**{"this_is_the_primarykey_column":1, "this_is_the_data_column":"data1"}) table1.insert().execute(**{"this_is_the_primarykey_column":2, "this_is_the_data_column":"data2"}) |