summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2009-10-25 21:27:08 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2009-10-25 21:27:08 +0000
commiteb9763febe58655ca0f61fa758925c56b94ece9b (patch)
tree52b93cd7ef50ae799d16fd4bc9d1c5ff5fd34e41
parenta5f827b12dbceb1c6e8f8b787548b9de326fe076 (diff)
downloadsqlalchemy-eb9763febe58655ca0f61fa758925c56b94ece9b.tar.gz
- generalized Enum to issue a CHECK constraint + VARCHAR on default platform
- added native_enum=False flag to do the same on MySQL, PG, if desired
-rw-r--r--CHANGES13
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py11
-rw-r--r--lib/sqlalchemy/dialects/oracle/cx_oracle.py11
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py10
-rw-r--r--lib/sqlalchemy/dialects/sqlite/base.py2
-rw-r--r--lib/sqlalchemy/schema.py14
-rw-r--r--lib/sqlalchemy/sql/compiler.py29
-rw-r--r--lib/sqlalchemy/sql/expression.py6
-rw-r--r--lib/sqlalchemy/types.py46
-rw-r--r--test/dialect/test_mysql.py24
-rw-r--r--test/dialect/test_postgresql.py19
-rw-r--r--test/sql/test_types.py81
12 files changed, 226 insertions, 40 deletions
diff --git a/CHANGES b/CHANGES
index 7f217ef1d..3c9c52a3a 100644
--- a/CHANGES
+++ b/CHANGES
@@ -583,13 +583,12 @@ CHANGES
type. This means reflection now returns more accurate
information about reflected types.
- - Added a new Enum generic type, currently supported on
- Postgresql and MySQL. Enum is a schema-aware object
- to support databases which require specific DDL in
- order to use enum or equivalent; in the case of PG
- it handles the details of `CREATE TYPE`, and on
- other databases without native enum support can
- support generation of CHECK constraints.
+ - Added a new Enum generic type. Enum is a schema-aware object
+ to support databases which require specific DDL in order to
+ use enum or equivalent; in the case of PG it handles the
+ details of `CREATE TYPE`, and on other databases without
+ native enum support will by generate VARCHAR + an inline CHECK
+ constraint to enforce the enum.
[ticket:1109] [ticket:1511]
- PickleType now uses == for comparison of values when
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index d7ea358b5..e54b7687d 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1351,6 +1351,12 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
return ' '.join(colspec)
+ def visit_enum_constraint(self, constraint):
+ if not constraint.type.native_enum:
+ return super(MySQLDDLCompiler, self).visit_enum_constraint(constraint)
+ else:
+ return None
+
def post_create_table(self, table):
"""Build table-level CREATE options like ENGINE and COLLATE."""
@@ -1576,7 +1582,10 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
return self.visit_BLOB(type_)
def visit_enum(self, type_):
- return self.visit_ENUM(type_)
+ if not type_.native_enum:
+ return super(MySQLTypeCompiler, self).visit_enum(type_)
+ else:
+ return self.visit_ENUM(type_)
def visit_BINARY(self, type_):
if type_.length:
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
index 6108d3d66..e4d3b312b 100644
--- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py
+++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -251,13 +251,18 @@ class Oracle_cx_oracleExecutionContext(OracleExecutionContext):
for bind, name in self.compiled.bind_names.iteritems():
if name in self.out_parameters:
type = bind.type
- result_processor = type.dialect_impl(self.dialect).result_processor(self.dialect)
+ result_processor = type.dialect_impl(self.dialect).\
+ result_processor(self.dialect)
if result_processor is not None:
- out_parameters[name] = result_processor(self.out_parameters[name].getvalue())
+ out_parameters[name] = \
+ result_processor(self.out_parameters[name].getvalue())
else:
out_parameters[name] = self.out_parameters[name].getvalue()
else:
- result.out_parameters = dict((k, v.getvalue()) for k, v in self.out_parameters.items())
+ result.out_parameters = dict(
+ (k, v.getvalue())
+ for k, v in self.out_parameters.items()
+ )
return result
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 1f4858cdd..26c4a8a97 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -330,7 +330,10 @@ class PGDDLCompiler(compiler.DDLCompiler):
def visit_drop_sequence(self, drop):
return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
-
+ def visit_enum_constraint(self, constraint):
+ if not constraint.type.native_enum:
+ return super(PGDDLCompiler, self).visit_enum_constraint(constraint)
+
def visit_create_enum_type(self, create):
type_ = create.element
@@ -400,7 +403,10 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
return self.visit_TIMESTAMP(type_)
def visit_enum(self, type_):
- return self.visit_ENUM(type_)
+ if not type_.native_enum:
+ return super(PGTypeCompiler, self).visit_enum(type_)
+ else:
+ return self.visit_ENUM(type_)
def visit_ENUM(self, type_):
return self.dialect.identifier_preparer.format_type(type_)
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index c25e75f2c..86b2eacd3 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -236,7 +236,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
if not column.nullable:
colspec += " NOT NULL"
return colspec
-
+
class SQLiteTypeCompiler(compiler.GenericTypeCompiler):
def visit_binary(self, type_):
return self.visit_BLOB(type_)
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index b99f79a8e..44f53f235 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -28,7 +28,7 @@ expressions.
"""
import re, inspect
-from sqlalchemy import types, exc, util, dialects
+from sqlalchemy import exc, util, dialects
from sqlalchemy.sql import expression, visitors
URL = None
@@ -765,12 +765,12 @@ class Column(SchemaItem, expression.ColumnClause):
table.append_constraint(UniqueConstraint(self.key))
for fn in self._table_events:
- fn(table)
+ fn(table, self)
del self._table_events
def _on_table_attach(self, fn):
if self.table is not None:
- fn(self.table)
+ fn(self.table, self)
else:
self._table_events.add(fn)
@@ -819,7 +819,7 @@ class Column(SchemaItem, expression.ColumnClause):
if self.primary_key:
selectable.primary_key.add(c)
for fn in c._table_events:
- fn(selectable)
+ fn(selectable, c)
del c._table_events
return c
@@ -1032,7 +1032,7 @@ class ForeignKey(SchemaItem):
self.parent.foreign_keys.add(self)
self.parent._on_table_attach(self._set_table)
- def _set_table(self, table):
+ def _set_table(self, table, column):
if self.constraint is None and isinstance(table, Table):
self.constraint = ForeignKeyConstraint(
[], [], use_alter=self.use_alter, name=self.name,
@@ -1181,11 +1181,9 @@ class Sequence(DefaultGenerator):
def _set_parent(self, column):
super(Sequence, self)._set_parent(column)
-# column.sequence = self
-
column._on_table_attach(self._set_table)
- def _set_table(self, table):
+ def _set_table(self, table, column):
self.metadata = table.metadata
@property
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index c1b421843..088ca1969 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -964,7 +964,10 @@ class DDLCompiler(engine.Compiled):
for column in table.columns:
text += separator
separator = ", \n"
- text += "\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)
+ text += "\t" + self.get_column_specification(
+ column,
+ first_pk=column.primary_key and not first_pk
+ )
if column.primary_key:
first_pk = True
const = " ".join(self.process(constraint) for constraint in column.constraints)
@@ -976,15 +979,18 @@ class DDLCompiler(engine.Compiled):
if table.primary_key:
text += ", \n\t" + self.process(table.primary_key)
- const = ", \n\t".join(
- self.process(constraint) for constraint in table.constraints
+ const = ", \n\t".join(p for p in
+ (self.process(constraint) for constraint in table.constraints
if constraint is not table.primary_key
and constraint.inline_ddl
- and (not self.dialect.supports_alter or not getattr(constraint, 'use_alter', False))
+ and (
+ not self.dialect.supports_alter or
+ not getattr(constraint, 'use_alter', False)
+ )) if p is not None
)
if const:
text += ", \n\t" + const
-
+
text += "\n)%s\n\n" % self.post_create_table(table)
return text
@@ -1121,6 +1127,17 @@ class DDLCompiler(engine.Compiled):
text += self.define_constraint_deferrability(constraint)
return text
+ def visit_enum_constraint(self, constraint):
+ text = ""
+ if constraint.name is not None:
+ text += "CONSTRAINT %s " % \
+ self.preparer.format_constraint(constraint)
+ text += " CHECK (%s IN (%s))" % (
+ self.preparer.format_column(constraint.column),
+ ",".join("'%s'" % x for x in constraint.type.enums)
+ )
+ return text
+
def define_constraint_cascades(self, constraint):
text = ""
if constraint.ondelete is not None:
@@ -1247,7 +1264,7 @@ class GenericTypeCompiler(engine.TypeCompiler):
return self.visit_TEXT(type_)
def visit_enum(self, type_):
- raise NotImplementedError("Enum not supported generically")
+ return self.visit_VARCHAR(type_)
def visit_null(self, type_):
raise NotImplementedError("Can't generate DDL for the null type")
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index b71c1892b..9324ed6a0 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -29,12 +29,12 @@ to stay the same in future releases.
import itertools, re
from operator import attrgetter
-from sqlalchemy import util, exc, types as sqltypes
+from sqlalchemy import util, exc #, types as sqltypes
from sqlalchemy.sql import operators
from sqlalchemy.sql.visitors import Visitable, cloned_traverse
import operator
-functions, schema, sql_util = None, None, None
+functions, schema, sql_util, sqltypes = None, None, None, None
DefaultDialect, ClauseAdapter, Annotated = None, None, None
__all__ = [
@@ -3071,7 +3071,7 @@ class TableClause(_Immutable, FromClause):
__visit_name__ = 'table'
named_with_column = True
-
+
def __init__(self, name, *columns):
super(TableClause, self).__init__()
self.name = self.fullname = name
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index ba1a3f907..27918e15c 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -24,7 +24,10 @@ import inspect
import datetime as dt
from decimal import Decimal as _python_Decimal
-from sqlalchemy import exc
+from sqlalchemy import exc, schema
+from sqlalchemy.sql import expression
+import sys
+schema.types = expression.sqltypes =sys.modules['sqlalchemy.types']
from sqlalchemy.util import pickle
from sqlalchemy.sql.visitors import Visitable
import sqlalchemy.util as util
@@ -809,8 +812,8 @@ class SchemaType(object):
def _set_parent(self, column):
column._on_table_attach(self._set_table)
-
- def _set_table(self, table):
+
+ def _set_table(self, table, column):
table.append_ddl_listener('before-create', self._on_table_create)
table.append_ddl_listener('after-drop', self._on_table_drop)
if self.metadata is None:
@@ -863,9 +866,11 @@ class SchemaType(object):
class Enum(String, SchemaType):
"""Generic Enum Type.
- Currently supported on MySQL and Postgresql, the Enum type
- provides a set of possible string values which the column is constrained
- towards.
+ The Enum type provides a set of possible string values which the
+ column is constrained towards.
+
+ By default, uses the backend's native ENUM type if available,
+ else uses VARCHAR + a CHECK constraint.
Keyword arguments which don't apply to a specific backend are ignored
by that backend.
@@ -895,6 +900,10 @@ class Enum(String, SchemaType):
or an explicitly named constraint in order to generate the type and/or
a table that uses it.
+ :param native_enum: Use the database's native ENUM type when available.
+ Defaults to True. When False, uses VARCHAR + check constraint
+ for all backends.
+
:param schema: Schemaname of this type. For types that exist on the target
database as an independent schema construct (Postgresql), this
parameter specifies the named schema in which the type is present.
@@ -909,6 +918,7 @@ class Enum(String, SchemaType):
def __init__(self, *enums, **kw):
self.enums = enums
+ self.native_enum = kw.pop('native_enum', True)
convert_unicode= kw.pop('convert_unicode', None)
assert_unicode = kw.pop('assert_unicode', None)
if convert_unicode is None:
@@ -919,11 +929,27 @@ class Enum(String, SchemaType):
else:
convert_unicode = False
+ if self.enums:
+ length =max(len(x) for x in self.enums)
+ else:
+ length = 0
String.__init__(self,
+ length =length,
convert_unicode=convert_unicode,
assert_unicode=assert_unicode
)
SchemaType.__init__(self, **kw)
+
+ def _set_table(self, table, column):
+ if self.native_enum:
+ SchemaType._set_table(self, table, column)
+
+ # this constraint DDL object is conditionally
+ # compiled by MySQL, Postgresql based on
+ # the native_enum flag.
+ table.append_constraint(
+ EnumConstraint(self, column)
+ )
def adapt(self, impltype):
return impltype(name=self.name,
@@ -935,6 +961,14 @@ class Enum(String, SchemaType):
*self.enums
)
+class EnumConstraint(schema.CheckConstraint):
+ __visit_name__ = 'enum_constraint'
+
+ def __init__(self, type_, column, **kw):
+ super(EnumConstraint, self).__init__('', name=type_.name, **kw)
+ self.type = type_
+ self.column = column
+
class PickleType(MutableType, TypeDecorator):
"""Holds Python objects.
diff --git a/test/dialect/test_mysql.py b/test/dialect/test_mysql.py
index 64f65d8f6..49dde1520 100644
--- a/test/dialect/test_mysql.py
+++ b/test/dialect/test_mysql.py
@@ -7,18 +7,19 @@ import sets
# end Py2K
from sqlalchemy import *
-from sqlalchemy import sql, exc
+from sqlalchemy import sql, exc, schema
from sqlalchemy.dialects.mysql import base as mysql
from sqlalchemy.test.testing import eq_
from sqlalchemy.test import *
from sqlalchemy.test.engines import utf8_engine
-class TypesTest(TestBase, AssertsExecutionResults):
+class TypesTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
"Test MySQL column types"
__only_on__ = 'mysql'
-
+ __dialect__ = mysql.dialect()
+
@testing.uses_deprecated('Manually quoting ENUM value literals')
def test_basic(self):
meta1 = MetaData(testing.db)
@@ -643,6 +644,23 @@ class TypesTest(TestBase, AssertsExecutionResults):
finally:
metadata.drop_all()
+ def test_enum_compile(self):
+ e1 = Enum('x', 'y', 'z', name="somename")
+ t1 = Table('sometable', MetaData(), Column('somecolumn', e1))
+ self.assert_compile(
+ schema.CreateTable(t1),
+ "CREATE TABLE sometable (somecolumn ENUM('x','y','z'))"
+ )
+ t1 = Table('sometable', MetaData(),
+ Column('somecolumn', Enum('x', 'y', 'z', native_enum=False))
+ )
+ self.assert_compile(
+ schema.CreateTable(t1),
+ "CREATE TABLE sometable ("
+ "somecolumn VARCHAR(1), "
+ " CHECK (somecolumn IN ('x','y','z'))"
+ ")"
+ )
@testing.exclude('mysql', '<', (4,), "3.23 can't handle an ENUM of ''")
@testing.uses_deprecated('Manually quoting ENUM value literals')
diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py
index 4e9a324d4..626d54677 100644
--- a/test/dialect/test_postgresql.py
+++ b/test/dialect/test_postgresql.py
@@ -132,6 +132,25 @@ class EnumTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
postgresql.DropEnumType(e2),
"DROP TYPE someschema.somename"
)
+
+ t1 = Table('sometable', MetaData(), Column('somecolumn', e1))
+ self.assert_compile(
+ schema.CreateTable(t1),
+ "CREATE TABLE sometable ("
+ "somecolumn somename"
+ ")"
+ )
+ t1 = Table('sometable', MetaData(),
+ Column('somecolumn', Enum('x', 'y', 'z', native_enum=False))
+ )
+ self.assert_compile(
+ schema.CreateTable(t1),
+ "CREATE TABLE sometable ("
+ "somecolumn VARCHAR(1), "
+ " CHECK (somecolumn IN ('x','y','z'))"
+ ")"
+ )
+
@testing.fails_on('postgresql+zxjdbc',
'zxjdbc fails on ENUM: column "XXX" is of type XXX '
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index c844cf696..51dd4c12b 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -329,6 +329,87 @@ class UnicodeTest(TestBase, AssertsExecutionResults):
assert uni(unicodedata) == unicodedata.encode('utf-8')
+class EnumTest(TestBase):
+ @classmethod
+ def setup_class(cls):
+ global enum_table, non_native_enum_table, metadata
+ metadata = MetaData(testing.db)
+ enum_table = Table('enum_table', metadata,
+ Column("id", Integer, primary_key=True),
+ Column('someenum', Enum('one','two','three', name='myenum'))
+ )
+
+ non_native_enum_table = Table('non_native_enum_table', metadata,
+ Column("id", Integer, primary_key=True),
+ Column('someenum', Enum('one','two','three', native_enum=False)),
+ )
+
+ metadata.create_all()
+
+ def teardown(self):
+ enum_table.delete().execute()
+ non_native_enum_table.delete().execute()
+
+ @classmethod
+ def teardown_class(cls):
+ metadata.drop_all()
+
+ @testing.fails_on('postgresql+zxjdbc',
+ 'zxjdbc fails on ENUM: column "XXX" is of type XXX '
+ 'but expression is of type character varying')
+ @testing.fails_on('postgresql+pg8000',
+ 'zxjdbc fails on ENUM: column "XXX" is of type XXX '
+ 'but expression is of type text')
+ def test_round_trip(self):
+ enum_table.insert().execute([
+ {'id':1, 'someenum':'two'},
+ {'id':2, 'someenum':'two'},
+ {'id':3, 'someenum':'one'},
+ ])
+
+ eq_(
+ enum_table.select().order_by(enum_table.c.id).execute().fetchall(),
+ [
+ (1, 'two'),
+ (2, 'two'),
+ (3, 'one'),
+ ]
+ )
+
+ def test_non_native_round_trip(self):
+ non_native_enum_table.insert().execute([
+ {'id':1, 'someenum':'two'},
+ {'id':2, 'someenum':'two'},
+ {'id':3, 'someenum':'one'},
+ ])
+
+ eq_(
+ non_native_enum_table.select().
+ order_by(non_native_enum_table.c.id).execute().fetchall(),
+ [
+ (1, 'two'),
+ (2, 'two'),
+ (3, 'one'),
+ ]
+ )
+
+ @testing.fails_on('postgresql+zxjdbc',
+ 'zxjdbc fails on ENUM: column "XXX" is of type XXX '
+ 'but expression is of type character varying')
+ @testing.fails_on('mysql', "MySQL seems to issue a 'data truncated' warning.")
+ def test_constraint(self):
+ assert_raises(exc.DBAPIError,
+ enum_table.insert().execute,
+ {'id':4, 'someenum':'four'}
+ )
+
+ @testing.fails_on('mysql', "the CHECK constraint doesn't raise an exception for unknown reason")
+ def test_non_native_constraint(self):
+ assert_raises(exc.DBAPIError,
+ non_native_enum_table.insert().execute,
+ {'id':4, 'someenum':'four'}
+ )
+
class BinaryTest(TestBase, AssertsExecutionResults):
__excluded_on__ = (
('mysql', '<', (4, 1, 1)), # screwy varbinary types