summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/databases/mysql.py19
-rw-r--r--test/dialect/mysql.py62
2 files changed, 74 insertions, 7 deletions
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index f0b18d3ac..2d9c3af4c 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -1729,13 +1729,16 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
if not column.nullable:
colspec.append('NOT NULL')
- # FIXME: #649 ASAP
- if column.primary_key:
- if (len(column.foreign_keys)==0
- and first_pk
- and column.autoincrement
- and isinstance(column.type, sqltypes.Integer)):
- colspec.append('AUTO_INCREMENT')
+ if column.primary_key and column.autoincrement:
+ try:
+ first = [c for c in column.table.primary_key.columns
+ if (c.autoincrement and
+ isinstance(c.type, sqltypes.Integer) and
+ not c.foreign_keys)].pop(0)
+ if column is first:
+ colspec.append('AUTO_INCREMENT')
+ except IndexError:
+ pass
return ' '.join(colspec)
@@ -1909,6 +1912,8 @@ class MySQLSchemaReflector(object):
# AUTO_INCREMENT
if spec.get('autoincr', False):
col_kw['autoincrement'] = True
+ elif issubclass(col_type, sqltypes.Integer):
+ col_kw['autoincrement'] = False
# DEFAULT
default = spec.get('default', None)
diff --git a/test/dialect/mysql.py b/test/dialect/mysql.py
index 03a87a0ba..ab3e49f93 100644
--- a/test/dialect/mysql.py
+++ b/test/dialect/mysql.py
@@ -600,6 +600,68 @@ class TypesTest(AssertMixin):
m.drop_all()
+ @testing.supported('mysql')
+ def test_autoincrement(self):
+ meta = MetaData(testbase.db)
+ try:
+ Table('ai_1', meta,
+ Column('int_y', Integer, primary_key=True),
+ Column('int_n', Integer, PassiveDefault('0'),
+ primary_key=True))
+ Table('ai_2', meta,
+ Column('int_y', Integer, primary_key=True),
+ Column('int_n', Integer, PassiveDefault('0'),
+ primary_key=True))
+ Table('ai_3', meta,
+ Column('int_n', Integer, PassiveDefault('0'),
+ primary_key=True, autoincrement=False),
+ Column('int_y', Integer, primary_key=True))
+ Table('ai_4', meta,
+ Column('int_n', Integer, PassiveDefault('0'),
+ primary_key=True, autoincrement=False),
+ Column('int_n2', Integer, PassiveDefault('0'),
+ primary_key=True, autoincrement=False))
+ Table('ai_5', meta,
+ Column('int_y', Integer, primary_key=True),
+ Column('int_n', Integer, PassiveDefault('0'),
+ primary_key=True, autoincrement=False))
+ Table('ai_6', meta,
+ Column('o1', String(1), PassiveDefault('x'),
+ primary_key=True),
+ Column('int_y', Integer, primary_key=True))
+ Table('ai_7', meta,
+ Column('o1', String(1), PassiveDefault('x'),
+ primary_key=True),
+ Column('o2', String(1), PassiveDefault('x'),
+ primary_key=True),
+ Column('int_y', Integer, primary_key=True))
+ Table('ai_8', meta,
+ Column('o1', String(1), PassiveDefault('x'),
+ primary_key=True),
+ Column('o2', String(1), PassiveDefault('x'),
+ primary_key=True))
+ meta.create_all()
+
+ table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4',
+ 'ai_5', 'ai_6', 'ai_7', 'ai_8']
+ mr = MetaData(testbase.db)
+ mr.reflect(only=table_names)
+
+ for tbl in [mr.tables[name] for name in table_names]:
+ for c in tbl.c:
+ if c.name.startswith('int_y'):
+ assert c.autoincrement
+ elif c.name.startswith('int_n'):
+ assert not c.autoincrement
+ tbl.insert().execute()
+ if 'int_y' in tbl.c:
+ assert select([tbl.c.int_y]).scalar() == 1
+ assert list(tbl.select().execute().fetchone()).count(1) == 1
+ else:
+ assert 1 not in list(tbl.select().execute().fetchone())
+ finally:
+ meta.drop_all()
+
def assert_eq(self, got, wanted):
if got != wanted:
print "Expected %s" % wanted