diff options
author | Jason Kirtland <jek@discorporate.us> | 2007-08-12 01:11:44 +0000 |
---|---|---|
committer | Jason Kirtland <jek@discorporate.us> | 2007-08-12 01:11:44 +0000 |
commit | 138eee02f58cd363716d709e27cedc76df0faf9a (patch) | |
tree | 22108659206ce9d83ebb17733894490e43ac48d7 | |
parent | 2d8b5bb4f36e5624f25b170391fe42d3bfbeb623 (diff) | |
download | sqlalchemy-138eee02f58cd363716d709e27cedc76df0faf9a.tar.gz |
Allow auto_increment on any pk column, not just the first.
-rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 19 | ||||
-rw-r--r-- | test/dialect/mysql.py | 62 |
2 files changed, 74 insertions, 7 deletions
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index f0b18d3ac..2d9c3af4c 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -1729,13 +1729,16 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): if not column.nullable: colspec.append('NOT NULL') - # FIXME: #649 ASAP - if column.primary_key: - if (len(column.foreign_keys)==0 - and first_pk - and column.autoincrement - and isinstance(column.type, sqltypes.Integer)): - colspec.append('AUTO_INCREMENT') + if column.primary_key and column.autoincrement: + try: + first = [c for c in column.table.primary_key.columns + if (c.autoincrement and + isinstance(c.type, sqltypes.Integer) and + not c.foreign_keys)].pop(0) + if column is first: + colspec.append('AUTO_INCREMENT') + except IndexError: + pass return ' '.join(colspec) @@ -1909,6 +1912,8 @@ class MySQLSchemaReflector(object): # AUTO_INCREMENT if spec.get('autoincr', False): col_kw['autoincrement'] = True + elif issubclass(col_type, sqltypes.Integer): + col_kw['autoincrement'] = False # DEFAULT default = spec.get('default', None) diff --git a/test/dialect/mysql.py b/test/dialect/mysql.py index 03a87a0ba..ab3e49f93 100644 --- a/test/dialect/mysql.py +++ b/test/dialect/mysql.py @@ -600,6 +600,68 @@ class TypesTest(AssertMixin): m.drop_all() + @testing.supported('mysql') + def test_autoincrement(self): + meta = MetaData(testbase.db) + try: + Table('ai_1', meta, + Column('int_y', Integer, primary_key=True), + Column('int_n', Integer, PassiveDefault('0'), + primary_key=True)) + Table('ai_2', meta, + Column('int_y', Integer, primary_key=True), + Column('int_n', Integer, PassiveDefault('0'), + primary_key=True)) + Table('ai_3', meta, + Column('int_n', Integer, PassiveDefault('0'), + primary_key=True, autoincrement=False), + Column('int_y', Integer, primary_key=True)) + Table('ai_4', meta, + Column('int_n', Integer, PassiveDefault('0'), + primary_key=True, autoincrement=False), + Column('int_n2', Integer, PassiveDefault('0'), + primary_key=True, autoincrement=False)) + Table('ai_5', meta, + Column('int_y', Integer, primary_key=True), + Column('int_n', Integer, PassiveDefault('0'), + primary_key=True, autoincrement=False)) + Table('ai_6', meta, + Column('o1', String(1), PassiveDefault('x'), + primary_key=True), + Column('int_y', Integer, primary_key=True)) + Table('ai_7', meta, + Column('o1', String(1), PassiveDefault('x'), + primary_key=True), + Column('o2', String(1), PassiveDefault('x'), + primary_key=True), + Column('int_y', Integer, primary_key=True)) + Table('ai_8', meta, + Column('o1', String(1), PassiveDefault('x'), + primary_key=True), + Column('o2', String(1), PassiveDefault('x'), + primary_key=True)) + meta.create_all() + + table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4', + 'ai_5', 'ai_6', 'ai_7', 'ai_8'] + mr = MetaData(testbase.db) + mr.reflect(only=table_names) + + for tbl in [mr.tables[name] for name in table_names]: + for c in tbl.c: + if c.name.startswith('int_y'): + assert c.autoincrement + elif c.name.startswith('int_n'): + assert not c.autoincrement + tbl.insert().execute() + if 'int_y' in tbl.c: + assert select([tbl.c.int_y]).scalar() == 1 + assert list(tbl.select().execute().fetchone()).count(1) == 1 + else: + assert 1 not in list(tbl.select().execute().fetchone()) + finally: + meta.drop_all() + def assert_eq(self, got, wanted): if got != wanted: print "Expected %s" % wanted |