diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-10-14 21:58:04 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-10-14 21:58:04 +0000 |
commit | 8340006dd7ed34cf32bbb7f856397d1c7f13d295 (patch) | |
tree | 3429fe31b379b2ccc10e6653e33d4d6d23fd5ae4 /lib/sqlalchemy | |
parent | 5bb47440e03bb6ac0d3bd92eab4a6d69304ff556 (diff) | |
download | sqlalchemy-8340006dd7ed34cf32bbb7f856397d1c7f13d295.tar.gz |
- a fair amount of cleanup to the schema package, removal of ambiguous
methods, methods that are no longer needed. slightly more constrained
usage, greater emphasis on explicitness.
- table_iterator signature fixup, includes fix for [ticket:288]
- the "primary_key" attribute of Table and other selectables becomes
a setlike ColumnCollection object; is no longer ordered or numerically
indexed. a comparison clause between two pks that are derived from the
same underlying tables (i.e. such as two Alias objects) can be generated
via table1.primary_key==table2.primary_key
- append_item() methods removed from Table and Column; preferably
construct Table/Column/related objects inline, but if needed use
append_column(), append_foreign_key(), append_constraint(), etc.
- table.create() no longer returns the Table object, instead has no
return value. the usual case is that tables are created via metadata,
which is preferable since it will handle table dependencies.
- added UniqueConstraint (goes at Table level), CheckConstraint
(goes at Table or Column level) fixes [ticket:217]
- index=False/unique=True on Column now creates a UniqueConstraint,
index=True/unique=False creates a plain Index,
index=True/unique=True on Column creates a unique Index. 'index'
and 'unique' keyword arguments to column are now boolean only; for
explicit names and groupings of indexes or unique constraints, use the
UniqueConstraint/Index constructs explicitly.
- relationship of Metadata/Table/SchemaGenerator/Dropper has been
improved so that the schemavisitor receives the metadata object
for greater control over groupings of creates/drops.
- added "use_alter" argument to ForeignKey, ForeignKeyConstraint,
but it doesn't do anything yet. will utilize new generator/dropper
behavior to implement.
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 79 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/firebird.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/information_schema.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 10 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/sqlite.py | 7 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/query.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/schema.py | 417 | ||||
-rw-r--r-- | lib/sqlalchemy/sql.py | 35 | ||||
-rw-r--r-- | lib/sqlalchemy/util.py | 2 |
13 files changed, 277 insertions, 309 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 2b0d7d17e..208b2f603 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -7,7 +7,7 @@ """defines ANSI SQL operations. Contains default implementations for the abstract objects in the sql module.""" -from sqlalchemy import schema, sql, engine, util +from sqlalchemy import schema, sql, engine, util, sql_util import sqlalchemy.engine.default as default import string, re, sets, weakref @@ -28,9 +28,6 @@ RESERVED_WORDS = util.Set(['all', 'analyse', 'analyze', 'and', 'any', 'array', ' LEGAL_CHARACTERS = util.Set(string.ascii_lowercase + string.ascii_uppercase + string.digits + '_$') ILLEGAL_INITIAL_CHARACTERS = util.Set(string.digits + '$') -def create_engine(): - return engine.ComposedSQLEngine(None, ANSIDialect()) - class ANSIDialect(default.DefaultDialect): def __init__(self, cache_identifiers=True, **kwargs): super(ANSIDialect,self).__init__(**kwargs) @@ -174,7 +171,7 @@ class ANSICompiler(sql.Compiled): if n is not None: self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False), n) elif len(column.table.primary_key) != 0: - self.strings[column] = self.preparer.format_column_with_table(column.table.primary_key[0]) + self.strings[column] = self.preparer.format_column_with_table(list(column.table.primary_key)[0]) else: self.strings[column] = None else: @@ -611,22 +608,30 @@ class ANSICompiler(sql.Compiled): class ANSISchemaGenerator(engine.SchemaIterator): - def __init__(self, engine, proxy, connection=None, checkfirst=False, **params): - super(ANSISchemaGenerator, self).__init__(engine, proxy, **params) + def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs): + super(ANSISchemaGenerator, self).__init__(engine, proxy, **kwargs) self.checkfirst = checkfirst + self.tables = tables and util.Set(tables) or None self.connection = connection self.preparer = self.engine.dialect.preparer() - + self.dialect = 
self.engine.dialect + def get_column_specification(self, column, first_pk=False): raise NotImplementedError() - - def visit_table(self, table): - # the single whitespace before the "(" is significant - # as its MySQL's method of indicating a table name and not a reserved word. - # feel free to localize this logic to the mysql module - if self.checkfirst and self.engine.dialect.has_table(self.connection, table.name): - return + + def visit_metadata(self, metadata): + for table in metadata.table_iterator(reverse=False, tables=self.tables): + if self.checkfirst and self.dialect.has_table(self.connection, table.name): + continue + table.accept_schema_visitor(self, traverse=False) + def visit_table(self, table): + for column in table.columns: + if column.default is not None: + column.default.accept_schema_visitor(self, traverse=False) + #if column.onupdate is not None: + # column.onupdate.accept_schema_visitor(visitor, traverse=False) + self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (") separator = "\n" @@ -639,15 +644,17 @@ class ANSISchemaGenerator(engine.SchemaIterator): self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)) if column.primary_key: first_pk = True - + for constraint in column.constraints: + constraint.accept_schema_visitor(self, traverse=False) + for constraint in table.constraints: - constraint.accept_schema_visitor(self) + constraint.accept_schema_visitor(self, traverse=False) self.append("\n)%s\n\n" % self.post_create_table(table)) - self.execute() + self.execute() if hasattr(table, 'indexes'): for index in table.indexes: - self.visit_index(index) + index.accept_schema_visitor(self, traverse=False) def post_create_table(self, table): return '' @@ -662,10 +669,17 @@ class ANSISchemaGenerator(engine.SchemaIterator): return None def _compile(self, tocompile, parameters): + """compile the given string/parameters using this SchemaGenerator's dialect.""" compiler = 
self.engine.dialect.compiler(tocompile, parameters) compiler.compile() return compiler + def visit_check_constraint(self, constraint): + self.append(", \n\t") + if constraint.name is not None: + self.append("CONSTRAINT %s " % constraint.name) + self.append(" CHECK (%s)" % constraint.sqltext) + def visit_primary_key_constraint(self, constraint): if len(constraint) == 0: return @@ -688,6 +702,13 @@ class ANSISchemaGenerator(engine.SchemaIterator): if constraint.onupdate is not None: self.append(" ON UPDATE %s" % constraint.onupdate) + def visit_unique_constraint(self, constraint): + self.append(", \n\t") + if constraint.name is not None: + self.append("CONSTRAINT %s " % constraint.name) + self.append(" UNIQUE ") + self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', '))) + def visit_column(self, column): pass @@ -701,21 +722,29 @@ class ANSISchemaGenerator(engine.SchemaIterator): self.execute() class ANSISchemaDropper(engine.SchemaIterator): - def __init__(self, engine, proxy, connection=None, checkfirst=False, **params): - super(ANSISchemaDropper, self).__init__(engine, proxy, **params) + def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs): + super(ANSISchemaDropper, self).__init__(engine, proxy, **kwargs) self.checkfirst = checkfirst + self.tables = tables self.connection = connection self.preparer = self.engine.dialect.preparer() + self.dialect = self.engine.dialect + + def visit_metadata(self, metadata): + for table in metadata.table_iterator(reverse=True, tables=self.tables): + if self.checkfirst and not self.dialect.has_table(self.connection, table.name): + continue + table.accept_schema_visitor(self, traverse=False) def visit_index(self, index): self.append("\nDROP INDEX " + index.name) self.execute() def visit_table(self, table): - # NOTE: indexes on the table will be automatically dropped, so - # no need to drop them individually - if self.checkfirst and not 
self.engine.dialect.has_table(self.connection, table.name): - return + for column in table.columns: + if column.default is not None: + column.default.accept_schema_visitor(self, traverse=False) + self.append("\nDROP TABLE " + self.preparer.format_table(table)) self.execute() diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index fa090a89e..f38a24b1f 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -253,7 +253,7 @@ class FireBirdDialect(ansisql.ANSIDialect): # is it a primary key? kw['primary_key'] = name in pkfields - table.append_item(schema.Column(*args, **kw)) + table.append_column(schema.Column(*args, **kw)) row = c.fetchone() # get the foreign keys @@ -276,7 +276,7 @@ class FireBirdDialect(ansisql.ANSIDialect): fk[1].append(refspec) for name,value in fks.iteritems(): - table.append_item(schema.ForeignKeyConstraint(value[0], value[1], name=name)) + table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name)) def last_inserted_ids(self): diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py index 291637e9e..5a7369ccd 100644 --- a/lib/sqlalchemy/databases/information_schema.py +++ b/lib/sqlalchemy/databases/information_schema.py @@ -144,7 +144,7 @@ def reflecttable(connection, table, ischema_names): colargs= [] if default is not None: colargs.append(PassiveDefault(sql.text(default))) - table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs)) + table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) if not found_table: raise exceptions.NoSuchTableError(table.name) @@ -175,7 +175,7 @@ def reflecttable(connection, table, ischema_names): ) #print "type %s on column %s to remote %s.%s.%s" % (type, constrained_column, referred_schema, referred_table, referred_column) if type=='PRIMARY KEY': - table.c[constrained_column]._set_primary_key() + 
table.primary_key.add(table.c[constrained_column]) elif type=='FOREIGN KEY': try: fk = fks[constraint_name] @@ -196,5 +196,5 @@ def reflecttable(connection, table, ischema_names): fk[1].append(refspec) for name, value in fks.iteritems(): - table.append_item(ForeignKeyConstraint(value[0], value[1], name=name)) + table.append_constraint(ForeignKeyConstraint(value[0], value[1], name=name)) diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 3d65abf0c..d23c41730 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -446,7 +446,7 @@ class MSSQLDialect(ansisql.ANSIDialect): if default is not None: colargs.append(schema.PassiveDefault(sql.text(default))) - table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs)) + table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) if not found_table: raise exceptions.NoSuchTableError(table.name) @@ -478,7 +478,7 @@ class MSSQLDialect(ansisql.ANSIDialect): c = connection.execute(s) for row in c: if 'PRIMARY' in row[TC.c.constraint_type.name]: - table.c[row[0]]._set_primary_key() + table.primary_key.add(table.c[row[0]]) # Foreign key constraints @@ -498,13 +498,13 @@ class MSSQLDialect(ansisql.ANSIDialect): scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r if rfknm != fknm: if fknm: - table.append_item(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm)) + table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm)) fknm, scols, rcols = (rfknm, [], []) if (not scol in scols): scols.append(scol) if (not (rschema, rtbl, rcol) in rcols): rcols.append((rschema, rtbl, rcol)) if fknm and scols: - table.append_item(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm)) + table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm)) diff --git 
a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 4443814c5..2fa7e9227 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -353,7 +353,7 @@ class MySQLDialect(ansisql.ANSIDialect): colargs= [] if default: colargs.append(schema.PassiveDefault(sql.text(default))) - table.append_item(schema.Column(name, coltype, *colargs, + table.append_column(schema.Column(name, coltype, *colargs, **dict(primary_key=primary_key, nullable=nullable, ))) @@ -397,7 +397,7 @@ class MySQLDialect(ansisql.ANSIDialect): refcols = [match.group('reftable') + "." + x for x in re.findall(r'`(.+?)`', match.group('refcols'))] schema.Table(match.group('reftable'), table.metadata, autoload=True, autoload_with=connection) constraint = schema.ForeignKeyConstraint(columns, refcols, name=match.group('name')) - table.append_item(constraint) + table.append_constraint(constraint) return tabletype diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index db82e3dea..b9aa09695 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -256,7 +256,7 @@ class OracleDialect(ansisql.ANSIDialect): if (name.upper() == name): name = name.lower() - table.append_item (schema.Column(name, coltype, nullable=nullable, *colargs)) + table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) c = connection.execute(constraintSQL, {'table_name' : table.name.upper(), 'owner' : owner}) @@ -268,7 +268,7 @@ class OracleDialect(ansisql.ANSIDialect): #print "ROW:" , row (cons_name, cons_type, local_column, remote_table, remote_column) = row if cons_type == 'P': - table.c[local_column]._set_primary_key() + table.primary_key.add(table.c[local_column]) elif cons_type == 'R': try: fk = fks[cons_name] @@ -283,7 +283,7 @@ class OracleDialect(ansisql.ANSIDialect): fk[1].append(refspec) for name, value in fks.iteritems(): - table.append_item(schema.ForeignKeyConstraint(value[0], value[1], 
name=name)) + table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name)) def do_executemany(self, c, statement, parameters, context=None): rowcount = 0 diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index a28a22cd6..dad2d3bff 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -370,7 +370,7 @@ class PGDialect(ansisql.ANSIDialect): colargs= [] if default is not None: colargs.append(PassiveDefault(sql.text(default))) - table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs)) + table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) if not found_table: @@ -392,7 +392,7 @@ class PGDialect(ansisql.ANSIDialect): if row is None: break pk = row[0] - table.c[pk]._set_primary_key() + table.primary_key.add(table.c[pk]) # Foreign keys FK_SQL = """ @@ -443,7 +443,7 @@ class PGDialect(ansisql.ANSIDialect): for column in referred_columns: refspec.append(".".join([referred_table, column])) - table.append_item(ForeignKeyConstraint(constrained_columns, refspec, row['conname'])) + table.append_constraint(ForeignKeyConstraint(constrained_columns, refspec, row['conname'])) class PGCompiler(ansisql.ANSICompiler): @@ -502,13 +502,13 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): return colspec def visit_sequence(self, sequence): - if not sequence.optional and not self.engine.dialect.has_sequence(self.connection, sequence.name): + if not sequence.optional and (not self.dialect.has_sequence(self.connection, sequence.name)): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() class PGSchemaDropper(ansisql.ANSISchemaDropper): def visit_sequence(self, sequence): - if not sequence.optional and self.engine.dialect.has_sequence(self.connection, sequence.name): + if not sequence.optional and (self.dialect.has_sequence(self.connection, sequence.name)): self.append("DROP SEQUENCE %s" % 
sequence.name) self.execute() diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 80d5a7d2a..90cd66dd3 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -199,7 +199,7 @@ class SQLiteDialect(ansisql.ANSIDialect): colargs= [] if has_default: colargs.append(PassiveDefault('?')) - table.append_item(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs)) + table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs)) if not found_table: raise exceptions.NoSuchTableError(table.name) @@ -228,7 +228,7 @@ class SQLiteDialect(ansisql.ANSIDialect): if refspec not in fk[1]: fk[1].append(refspec) for name, value in fks.iteritems(): - table.append_item(schema.ForeignKeyConstraint(value[0], value[1])) + table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1])) # check for UNIQUE indexes c = connection.execute("PRAGMA index_list(" + table.name + ")", {}) unique_indexes = [] @@ -250,8 +250,7 @@ class SQLiteDialect(ansisql.ANSIDialect): col = table.columns[row[2]] # unique index that includes the pk is considered a multiple primary key for col in cols: - column = table.columns[col] - table.columns[col]._set_primary_key() + table.primary_key.add(table.columns[col]) class SQLiteCompiler(ansisql.ANSICompiler): def visit_cast(self, cast): diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 6d0cf2eb3..4ba5e1115 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -421,7 +421,7 @@ class ComposedSQLEngine(sql.Engine, Connectable): else: conn = connection try: - element.accept_schema_visitor(visitorcallable(self, conn.proxy, connection=conn, **kwargs)) + element.accept_schema_visitor(visitorcallable(self, conn.proxy, connection=conn, **kwargs), traverse=False) finally: if connection is None: conn.close() diff --git a/lib/sqlalchemy/orm/query.py 
b/lib/sqlalchemy/orm/query.py index 5afd3e1b6..462c5e799 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -366,10 +366,8 @@ class Query(object): if not distinct and order_by: s2.order_by(*util.to_list(order_by)) s3 = s2.alias('tbl_row_count') - crit = [] - for i in range(0, len(self.table.primary_key)): - crit.append(s3.primary_key[i] == self.table.primary_key[i]) - statement = sql.select([], sql.and_(*crit), from_obj=[self.table], use_labels=True, for_update=for_update) + crit = s3.primary_key==self.table.primary_key + statement = sql.select([], crit, from_obj=[self.table], use_labels=True, for_update=for_update) # now for the order by, convert the columns to their corresponding columns # in the "rowcount" query, and tack that new order by onto the "rowcount" query if order_by: diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 05753e424..5728d7c37 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -19,7 +19,7 @@ import sqlalchemy import copy, re, string __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', 'ForeignKeyConstraint', - 'PrimaryKeyConstraint', + 'PrimaryKeyConstraint', 'CheckConstraint', 'UniqueConstraint', 'MetaData', 'BoundMetaData', 'DynamicMetaData', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault'] class SchemaItem(object): @@ -99,36 +99,33 @@ def _get_table_key(name, schema): class TableSingleton(type): """a metaclass used by the Table object to provide singleton behavior.""" def __call__(self, name, metadata, *args, **kwargs): + if isinstance(metadata, sql.Engine): + # backwards compatibility - get a BoundSchema associated with the engine + engine = metadata + if not hasattr(engine, '_legacy_metadata'): + engine._legacy_metadata = BoundMetaData(engine) + metadata = engine._legacy_metadata + elif metadata is not None and not isinstance(metadata, MetaData): + # they left MetaData out, so assume its another SchemaItem, add it to *args + args = 
list(args) + args.insert(0, metadata) + metadata = None + + if metadata is None: + metadata = default_metadata + + name = str(name) # in case of incoming unicode + schema = kwargs.get('schema', None) + autoload = kwargs.pop('autoload', False) + autoload_with = kwargs.pop('autoload_with', False) + mustexist = kwargs.pop('mustexist', False) + useexisting = kwargs.pop('useexisting', False) + key = _get_table_key(name, schema) try: - if isinstance(metadata, sql.Engine): - # backwards compatibility - get a BoundSchema associated with the engine - engine = metadata - if not hasattr(engine, '_legacy_metadata'): - engine._legacy_metadata = BoundMetaData(engine) - metadata = engine._legacy_metadata - elif metadata is not None and not isinstance(metadata, MetaData): - # they left MetaData out, so assume its another SchemaItem, add it to *args - args = list(args) - args.insert(0, metadata) - metadata = None - - if metadata is None: - metadata = default_metadata - - name = str(name) # in case of incoming unicode - schema = kwargs.get('schema', None) - autoload = kwargs.pop('autoload', False) - autoload_with = kwargs.pop('autoload_with', False) - redefine = kwargs.pop('redefine', False) - mustexist = kwargs.pop('mustexist', False) - useexisting = kwargs.pop('useexisting', False) - key = _get_table_key(name, schema) table = metadata.tables[key] if len(args): - if redefine: - table._reload_values(*args) - elif not useexisting: - raise exceptions.ArgumentError("Table '%s.%s' is already defined. specify 'redefine=True' to remap columns, or 'useexisting=True' to use the existing table" % (schema, name)) + if not useexisting: + raise exceptions.ArgumentError("Table '%s.%s' is already defined for this MetaData instance." 
% (schema, name)) return table except KeyError: if mustexist: @@ -145,7 +142,7 @@ class TableSingleton(type): else: metadata.get_engine().reflecttable(table) except exceptions.NoSuchTableError: - table.deregister() + del metadata.tables[key] raise # initialize all the column, etc. objects. done after # reflection to allow user-overrides @@ -210,8 +207,8 @@ class Table(SchemaItem, sql.TableClause): super(Table, self).__init__(name) self._metadata = metadata self.schema = kwargs.pop('schema', None) - self.indexes = util.OrderedProperties() - self.constraints = [] + self.indexes = util.Set() + self.constraints = util.Set() self.primary_key = PrimaryKeyConstraint() self.quote = kwargs.get('quote', False) self.quote_schema = kwargs.get('quote_schema', False) @@ -237,7 +234,7 @@ class Table(SchemaItem, sql.TableClause): if getattr(self, '_primary_key', None) in self.constraints: self.constraints.remove(self._primary_key) self._primary_key = pk - self.constraints.append(pk) + self.constraints.add(pk) primary_key = property(lambda s:s._primary_key, _set_primary_key) def _derived_metadata(self): @@ -251,93 +248,45 @@ class Table(SchemaItem, sql.TableClause): def __str__(self): return _get_table_key(self.name, self.schema) - - def _reload_values(self, *args): - """clear out the columns and other properties of this Table, and reload them from the - given argument list. 
This is used with the "redefine" keyword argument sent to the - metaclass constructor.""" - self._clear() - - self._init_items(*args) - def append_item(self, item): - """appends a Column item or other schema item to this Table.""" - self._init_items(item) - def append_column(self, column): - if not column.hidden: - self._columns[column.key] = column - if column.primary_key: - self.primary_key.append(column) - column.table = self + """append a Column to this Table.""" + column._set_parent(self) + def append_constraint(self, constraint): + """append a Constraint to this Table.""" + constraint._set_parent(self) - def append_index(self, index): - self.indexes[index.name] = index - def _get_parent(self): return self._metadata def _set_parent(self, metadata): metadata.tables[_get_table_key(self.name, self.schema)] = self self._metadata = metadata - def accept_schema_visitor(self, visitor): - """traverses the given visitor across the Column objects inside this Table, - then calls the visit_table method on the visitor.""" - for c in self.columns: - c.accept_schema_visitor(visitor) + def accept_schema_visitor(self, visitor, traverse=True): + if traverse: + for c in self.columns: + c.accept_schema_visitor(visitor, True) return visitor.visit_table(self) - def append_index_column(self, column, index=None, unique=None): - """Add an index or a column to an existing index of the same name. 
- """ - if index is not None and unique is not None: - raise ValueError("index and unique may not both be specified") - if index: - if index is True: - name = 'ix_%s' % column._label - else: - name = index - elif unique: - if unique is True: - name = 'ux_%s' % column._label - else: - name = unique - # find this index in self.indexes - # add this column to it if found - # otherwise create new - try: - index = self.indexes[name] - index.append_column(column) - except KeyError: - index = Index(name, column, unique=unique) - return index - - def deregister(self): - """remove this table from it's owning metadata. - - this does not issue a SQL DROP statement.""" - key = _get_table_key(self.name, self.schema) - del self.metadata.tables[key] - - def exists(self, engine=None): - if engine is None: - engine = self.get_engine() + def exists(self, connectable=None): + """return True if this table exists.""" + if connectable is None: + connectable = self.get_engine() def do(conn): e = conn.engine return e.dialect.has_table(conn, self.name) - return engine.run_callable(do) + return connectable.run_callable(do) def create(self, connectable=None, checkfirst=False): - if connectable is not None: - connectable.create(self, checkfirst=checkfirst) - else: - self.get_engine().create(self, checkfirst=checkfirst) - return self + """issue a CREATE statement for this table. + + see also metadata.create_all().""" + self.metadata.create_all(connectable=connectable, checkfirst=checkfirst, tables=[self]) def drop(self, connectable=None, checkfirst=False): - if connectable is not None: - connectable.drop(self, checkfirst=checkfirst) - else: - self.get_engine().drop(self, checkfirst=checkfirst) + """issue a DROP statement for this table. 
+ + see also metadata.drop_all().""" + self.metadata.drop_all(connectable=connectable, checkfirst=checkfirst, tables=[self]) def tometadata(self, metadata, schema=None): """return a copy of this Table associated with a different MetaData.""" try: @@ -389,17 +338,16 @@ class Column(SchemaItem, sql.ColumnClause): table's list of columns. Used for the "oid" column, which generally isnt in column lists. - index=None : True or index name. Indicates that this column is - indexed. Pass true to autogenerate the index name. Pass a string to - specify the index name. Multiple columns that specify the same index - name will all be included in the index, in the order of their - creation. + index=False : Indicates that this column is + indexed. The name of the index is autogenerated. + to specify indexes with explicit names or indexes that contain multiple + columns, use the Index construct instead. - unique=None : True or index name. Indicates that this column is - indexed in a unique index . Pass true to autogenerate the index - name. Pass a string to specify the index name. Multiple columns that - specify the same index name will all be included in the index, in the - order of their creation. + unique=False : Indicates that this column + contains a unique constraint, or if index=True as well, indicates + that the Index should be created with the unique flag. + To specify multiple columns in the constraint/index or to specify an + explicit name, use the UniqueConstraint or Index constructs instead. autoincrement=True : Indicates that integer-based primary key columns should have autoincrementing behavior, if supported by the underlying database. 
This will affect CREATE TABLE statements such that they will @@ -430,9 +378,8 @@ class Column(SchemaItem, sql.ColumnClause): self._set_casing_strategy(name, kwargs) self.onupdate = kwargs.pop('onupdate', None) self.autoincrement = kwargs.pop('autoincrement', True) + self.constraints = util.Set() self.__originating_column = self - if self.index is not None and self.unique is not None: - raise exceptions.ArgumentError("Column may not define both index and unique") self._foreign_keys = util.Set() if len(kwargs): raise exceptions.ArgumentError("Unknown arguments passed to Column: " + repr(kwargs.keys())) @@ -455,7 +402,10 @@ class Column(SchemaItem, sql.ColumnClause): return self.table.metadata def _get_engine(self): return self.table.engine - + + def append_foreign_key(self, fk): + fk._set_parent(self) + def __repr__(self): return "Column(%s)" % string.join( [repr(self.name)] + [repr(self.type)] + @@ -463,33 +413,33 @@ class Column(SchemaItem, sql.ColumnClause): ["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'hidden', 'default', 'onupdate']] , ',') - def append_item(self, item): - self._init_items(item) - - def _set_primary_key(self): - if self.primary_key: - return - self.primary_key = True - self.nullable = False - self.table.primary_key.append(self) - def _get_parent(self): return self.table + def _set_parent(self, table): if getattr(self, 'table', None) is not None: raise exceptions.ArgumentError("this Column already has a table!") - table.append_column(self) - if self.index or self.unique: - table.append_index_column(self, index=self.index, - unique=self.unique) - + if not self.hidden: + table._columns.add(self) + if self.primary_key: + table.primary_key.add(self) + self.table = table + + if self.index: + if isinstance(self.index, str): + raise exceptions.ArgumentError("The 'index' keyword argument on Column is boolean only. 
To create indexes with a specific name, append an explicit Index object to the Table's list of elements.") + Index('ix_%s' % self._label, self, unique=self.unique) + elif self.unique: + if isinstance(self.unique, str): + raise exceptions.ArgumentError("The 'unique' keyword argument on Column is boolean only. To create unique constraints or indexes with a specific name, append an explicit UniqueConstraint or Index object to the Table's list of elements.") + table.append_constraint(UniqueConstraint(self.key)) + + toinit = list(self.args) if self.default is not None: - self.default = ColumnDefault(self.default) - self._init_items(self.default) + toinit.append(ColumnDefault(self.default)) if self.onupdate is not None: - self.onupdate = ColumnDefault(self.onupdate, for_update=True) - self._init_items(self.onupdate) - self._init_items(*self.args) + toinit.append(ColumnDefault(self.onupdate, for_update=True)) + self._init_items(*toinit) self.args = None def copy(self): @@ -507,9 +457,9 @@ class Column(SchemaItem, sql.ColumnClause): c.orig_set = self.orig_set c.__originating_column = self.__originating_column if not c.hidden: - selectable.columns[c.key] = c + selectable.columns.add(c) if self.primary_key: - selectable.primary_key.append(c) + selectable.primary_key.add(c) [c._init_items(f) for f in fk] return c @@ -519,15 +469,18 @@ class Column(SchemaItem, sql.ColumnClause): return self.__originating_column._get_case_sensitive() case_sensitive = property(_case_sens) - def accept_schema_visitor(self, visitor): + def accept_schema_visitor(self, visitor, traverse=True): """traverses the given visitor to this Column's default and foreign key object, then calls visit_column on the visitor.""" - if self.default is not None: - self.default.accept_schema_visitor(visitor) - if self.onupdate is not None: - self.onupdate.accept_schema_visitor(visitor) - for f in self.foreign_keys: - f.accept_schema_visitor(visitor) + if traverse: + if self.default is not None: + 
self.default.accept_schema_visitor(visitor, traverse=True) + if self.onupdate is not None: + self.onupdate.accept_schema_visitor(visitor, traverse=True) + for f in self.foreign_keys: + f.accept_schema_visitor(visitor, traverse=True) + for constraint in self.constraints: + constraint.accept_schema_visitor(visitor, traverse=True) visitor.visit_column(self) @@ -538,7 +491,7 @@ class ForeignKey(SchemaItem): One or more ForeignKey objects are used within a ForeignKeyConstraint object which represents the table-level constraint definition.""" - def __init__(self, column, constraint=None): + def __init__(self, column, constraint=None, use_alter=False): """Construct a new ForeignKey object. "column" can be a schema.Column object representing the relationship, @@ -553,6 +506,7 @@ class ForeignKey(SchemaItem): self._colspec = column self._column = None self.constraint = constraint + self.use_alter = use_alter def __repr__(self): return "ForeignKey(%s)" % repr(self._get_colspec()) @@ -611,7 +565,7 @@ class ForeignKey(SchemaItem): column = property(lambda s: s._init_column()) - def accept_schema_visitor(self, visitor): + def accept_schema_visitor(self, visitor, traverse=True): """calls the visit_foreign_key method on the given visitor.""" visitor.visit_foreign_key(self) @@ -621,17 +575,13 @@ class ForeignKey(SchemaItem): self.parent = column if self.constraint is None and isinstance(self.parent.table, Table): - self.constraint = ForeignKeyConstraint([],[]) - self.parent.table.append_item(self.constraint) + self.constraint = ForeignKeyConstraint([],[], use_alter=self.use_alter) + self.parent.table.append_constraint(self.constraint) self.constraint._append_fk(self) - # if a foreign key was already set up for the parent column, replace it with - # this one - #if self.parent.foreign_key is not None: - # self.parent.table.foreign_keys.remove(self.parent.foreign_key) - #self.parent.foreign_key = self self.parent.foreign_keys.add(self) self.parent.table.foreign_keys.add(self) + class 
DefaultGenerator(SchemaItem): """Base class for column "default" values.""" def __init__(self, for_update=False, metadata=None): @@ -661,7 +611,7 @@ class PassiveDefault(DefaultGenerator): def __init__(self, arg, **kwargs): super(PassiveDefault, self).__init__(**kwargs) self.arg = arg - def accept_schema_visitor(self, visitor): + def accept_schema_visitor(self, visitor, traverse=True): return visitor.visit_passive_default(self) def __repr__(self): return "PassiveDefault(%s)" % repr(self.arg) @@ -672,7 +622,7 @@ class ColumnDefault(DefaultGenerator): def __init__(self, arg, **kwargs): super(ColumnDefault, self).__init__(**kwargs) self.arg = arg - def accept_schema_visitor(self, visitor): + def accept_schema_visitor(self, visitor, traverse=True): """calls the visit_column_default method on the given visitor.""" if self.for_update: return visitor.visit_column_onupdate(self) @@ -704,57 +654,66 @@ class Sequence(DefaultGenerator): return self def drop(self): self.get_engine().drop(self) - def accept_schema_visitor(self, visitor): + def accept_schema_visitor(self, visitor, traverse=True): """calls the visit_seauence method on the given visitor.""" return visitor.visit_sequence(self) class Constraint(SchemaItem): """represents a table-level Constraint such as a composite primary key, foreign key, or unique constraint. 
- Also follows list behavior with regards to the underlying set of columns.""" + Implements a hybrid of dict/setlike behavior with regards to the list of underying columns""" def __init__(self, name=None): self.name = name - self.columns = [] + self.columns = sql.ColumnCollection() def __contains__(self, x): return x in self.columns + def keys(self): + return self.columns.keys() def __add__(self, other): return self.columns + other def __iter__(self): return iter(self.columns) def __len__(self): return len(self.columns) - def __getitem__(self, index): - return self.columns[index] - def __setitem__(self, index, item): - self.columns[index] = item def copy(self): raise NotImplementedError() def _get_parent(self): return getattr(self, 'table', None) - + +class CheckConstraint(Constraint): + def __init__(self, sqltext, name=None): + super(CheckConstraint, self).__init__(name) + self.sqltext = sqltext + def accept_schema_visitor(self, visitor, traverse=True): + visitor.visit_check_constraint(self) + def _set_parent(self, parent): + self.parent = parent + parent.constraints.add(self) + class ForeignKeyConstraint(Constraint): """table-level foreign key constraint, represents a colleciton of ForeignKey objects.""" - def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None): + def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, use_alter=False): super(ForeignKeyConstraint, self).__init__(name) self.__colnames = columns self.__refcolnames = refcolumns - self.elements = [] + self.elements = util.Set() self.onupdate = onupdate self.ondelete = ondelete + self.use_alter = use_alter def _set_parent(self, table): self.table = table - table.constraints.append(self) + table.constraints.add(self) for (c, r) in zip(self.__colnames, self.__refcolnames): - self.append(c,r) - def accept_schema_visitor(self, visitor): + self.append_element(c,r) + def accept_schema_visitor(self, visitor, traverse=True): 
visitor.visit_foreign_key_constraint(self) - def append(self, col, refcol): + def append_element(self, col, refcol): fk = ForeignKey(refcol, constraint=self) fk._set_parent(self.table.c[col]) self._append_fk(fk) def _append_fk(self, fk): - self.columns.append(self.table.c[fk.parent.key]) - self.elements.append(fk) + self.columns.add(self.table.c[fk.parent.key]) + self.elements.add(fk) def copy(self): return ForeignKeyConstraint([x.parent.name for x in self.elements], [x._get_colspec() for x in self.elements], name=self.name, onupdate=self.onupdate, ondelete=self.ondelete) @@ -766,37 +725,37 @@ class PrimaryKeyConstraint(Constraint): self.table = table table.primary_key = self for c in self.__colnames: - self.append(table.c[c]) - def accept_schema_visitor(self, visitor): + self.append_column(table.c[c]) + def accept_schema_visitor(self, visitor, traverse=True): visitor.visit_primary_key_constraint(self) - def append(self, col): - # TODO: change "columns" to a key-sensitive set ? - for c in self.columns: - if c.key == col.key: - self.columns.remove(c) - self.columns.append(col) + def add(self, col): + self.append_column(col) + def append_column(self, col): + self.columns.add(col) col.primary_key=True def copy(self): return PrimaryKeyConstraint(name=self.name, *[c.key for c in self]) - + def __eq__(self, other): + return self.columns == other + class UniqueConstraint(Constraint): - def __init__(self, name=None, *columns): - super(Constraint, self).__init__(name) + def __init__(self, *columns, **kwargs): + super(UniqueConstraint, self).__init__(name=kwargs.pop('name', None)) self.__colnames = list(columns) def _set_parent(self, table): self.table = table - table.constraints.append(self) + table.constraints.add(self) for c in self.__colnames: - self.append(table.c[c]) - def append(self, col): - self.columns.append(col) - def accept_schema_visitor(self, visitor): + self.append_column(table.c[c]) + def append_column(self, col): + self.columns.add(col) + def 
accept_schema_visitor(self, visitor, traverse=True): visitor.visit_unique_constraint(self) class Index(SchemaItem): """Represents an index of columns from a database table """ - def __init__(self, name, *columns, **kw): + def __init__(self, name, *columns, **kwargs): """Constructs an index object. Arguments are: name : the name of the index @@ -811,7 +770,7 @@ class Index(SchemaItem): self.name = name self.columns = [] self.table = None - self.unique = kw.pop('unique', False) + self.unique = kwargs.pop('unique', False) self._init_items(*columns) def _derived_metadata(self): @@ -821,12 +780,15 @@ class Index(SchemaItem): self.append_column(column) def _get_parent(self): return self.table + def _set_parent(self, table): + self.table = table + table.indexes.add(self) + def append_column(self, column): # make sure all columns are from the same table # and no column is repeated if self.table is None: - self.table = column.table - self.table.append_index(self) + self._set_parent(column.table) elif column.table != self.table: # all columns muse be from same table raise exceptions.ArgumentError("All index columns must be from same table. 
" @@ -850,7 +812,7 @@ class Index(SchemaItem): connectable.drop(self) else: self.get_engine().drop(self) - def accept_schema_visitor(self, visitor): + def accept_schema_visitor(self, visitor, traverse=True): visitor.visit_index(self) def __str__(self): return repr(self) @@ -863,7 +825,6 @@ class Index(SchemaItem): class MetaData(SchemaItem): """represents a collection of Tables and their associated schema constructs.""" def __init__(self, name=None, **kwargs): - # a dictionary that stores Table objects keyed off their name (and possibly schema name) self.tables = {} self.name = name self._set_casing_strategy(name, kwargs) @@ -871,11 +832,18 @@ class MetaData(SchemaItem): return False def clear(self): self.tables.clear() - def table_iterator(self, reverse=True): - return self._sort_tables(self.tables.values(), reverse=reverse) + + def table_iterator(self, reverse=True, tables=None): + import sqlalchemy.sql_util + if tables is None: + tables = self.tables.values() + else: + tables = util.Set(tables).intersection(self.tables.values()) + sorter = sqlalchemy.sql_util.TableCollection(list(tables)) + return iter(sorter.sort(reverse=reverse)) def _get_parent(self): return None - def create_all(self, connectable=None, tables=None, engine=None): + def create_all(self, connectable=None, tables=None, checkfirst=True): """create all tables stored in this metadata. This will conditionally create tables depending on if they do not yet @@ -884,28 +852,13 @@ class MetaData(SchemaItem): connectable - a Connectable used to access the database; or use the engine bound to this MetaData. 
- tables - optional list of tables to create - - engine - deprecated argument.""" - if not tables: - tables = self.tables.values() - - if connectable is None: - connectable = engine - + tables - optional list of tables, which is a subset of the total + tables in the MetaData (others are ignored)""" if connectable is None: connectable = self.get_engine() - - def do(conn): - e = conn.engine - ts = self._sort_tables( tables ) - for table in ts: - if e.dialect.has_table(conn, table.name): - continue - conn.create(table) - connectable.run_callable(do) + connectable.create(self, checkfirst=checkfirst, tables=tables) - def drop_all(self, connectable=None, tables=None, engine=None): + def drop_all(self, connectable=None, tables=None, checkfirst=True): """drop all tables stored in this metadata. This will conditionally drop tables depending on if they currently @@ -914,33 +867,17 @@ class MetaData(SchemaItem): connectable - a Connectable used to access the database; or use the engine bound to this MetaData. 
- tables - optional list of tables to drop - - engine - deprecated argument.""" - if not tables: - tables = self.tables.values() - - if connectable is None: - connectable = engine - + tables - optional list of tables, which is a subset of the total + tables in the MetaData (others are ignored) + """ if connectable is None: connectable = self.get_engine() - - def do(conn): - e = conn.engine - ts = self._sort_tables( tables, reverse=True ) - for table in ts: - if e.dialect.has_table(conn, table.name): - conn.drop(table) - connectable.run_callable(do) + connectable.drop(self, checkfirst=checkfirst, tables=tables) - def _sort_tables(self, tables, reverse=False): - import sqlalchemy.sql_util - sorter = sqlalchemy.sql_util.TableCollection() - for t in tables: - sorter.add(t) - return sorter.sort(reverse=reverse) - + + def accept_schema_visitor(self, visitor, traverse=True): + visitor.visit_metadata(self) + def _derived_metadata(self): return self def _get_engine(self): @@ -1029,6 +966,8 @@ class SchemaVisitor(sql.ClauseVisitor): pass def visit_unique_constraint(self, constraint): pass + def visit_check_constraint(self, constraint): + pass default_metadata = DynamicMetaData('default') diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index c113edaa3..6f51ccbe9 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -658,6 +658,17 @@ class ColumnElement(Selectable, CompareMixin): else: return self +class ColumnCollection(util.OrderedProperties): + def add(self, column): + self[column.key] = column + def __eq__(self, other): + l = [] + for c in other: + for local in self: + if c.shares_lineage(local): + l.append(c==local) + return and_(*l) + class FromClause(Selectable): """represents an element that can be used within the FROM clause of a SELECT statement.""" def __init__(self, name=None): @@ -671,7 +682,7 @@ class FromClause(Selectable): visitor.visit_fromclause(self) def count(self, whereclause=None, **params): if len(self.primary_key): - col = 
self.primary_key[0] + col = list(self.primary_key)[0] else: col = list(self.columns)[0] return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params) @@ -735,8 +746,8 @@ class FromClause(Selectable): if hasattr(self, '_columns'): # TODO: put a mutex here ? this is a key place for threading probs return - self._columns = util.OrderedProperties() - self._primary_key = [] + self._columns = ColumnCollection() + self._primary_key = ColumnCollection() self._foreign_keys = util.Set() self._orig_cols = {} export = self._exportable_columns() @@ -1082,7 +1093,7 @@ class Join(FromClause): def _proxy_column(self, column): self._columns[column._label] = column if column.primary_key: - self._primary_key.append(column) + self._primary_key.add(column) for f in column.foreign_keys: self._foreign_keys.add(f) return column @@ -1257,9 +1268,9 @@ class TableClause(FromClause): def __init__(self, name, *columns): super(TableClause, self).__init__(name) self.name = self.fullname = name - self._columns = util.OrderedProperties() + self._columns = ColumnCollection() self._foreign_keys = util.Set() - self._primary_key = [] + self._primary_key = util.Set() for c in columns: self.append_column(c) self._oid_column = ColumnClause('oid', self, hidden=True) @@ -1282,16 +1293,6 @@ class TableClause(FromClause): return self._orig_cols original_columns = property(_orig_columns) - def _clear(self): - """clears all attributes on this TableClause so that new items can be added again""" - self.columns.clear() - self.foreign_keys[:] = [] - self.primary_key[:] = [] - try: - delattr(self, '_orig_cols') - except AttributeError: - pass - def accept_visitor(self, visitor): visitor.visit_table(self) def _exportable_columns(self): @@ -1305,7 +1306,7 @@ class TableClause(FromClause): data[self] = self def count(self, whereclause=None, **params): if len(self.primary_key): - col = self.primary_key[0] + col = list(self.primary_key)[0] else: col = list(self.columns)[0] return 
select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params) diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 5f243ae04..d5c6a3b92 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -87,6 +87,8 @@ class OrderedProperties(object): return len(self.__data) def __iter__(self): return self.__data.itervalues() + def __add__(self, other): + return list(self) + list(other) def __setitem__(self, key, object): self.__data[key] = object def __getitem__(self, key): |