summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/schema.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/schema.py')
-rw-r--r--lib/sqlalchemy/schema.py41
1 files changed, 27 insertions, 14 deletions
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 60c42c25a..606bcf508 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -59,6 +59,7 @@ def _get_table_key(engine, name, schema):
class TableSingleton(type):
def __call__(self, name, engine, *args, **kwargs):
try:
+ name = str(name) # in case of incoming unicode
schema = kwargs.get('schema', None)
autoload = kwargs.pop('autoload', False)
redefine = kwargs.pop('redefine', False)
@@ -151,15 +152,15 @@ class Table(SchemaItem):
class Column(SchemaItem):
"""represents a column in a database table."""
def __init__(self, name, type, *args, **kwargs):
- self.name = name
+ self.name = str(name) # in case of incoming unicode
self.type = type
self.args = args
self.key = kwargs.pop('key', name)
self.primary_key = kwargs.pop('primary_key', False)
self.nullable = kwargs.pop('nullable', not self.primary_key)
self.hidden = kwargs.pop('hidden', False)
+ self.default = kwargs.pop('default', None)
self.foreign_key = None
- self.sequence = None
self._orig = None
if len(kwargs):
raise "Unknown arguments passed to Column: " + repr(kwargs.keys())
@@ -185,6 +186,8 @@ class Column(SchemaItem):
self._impl = self.table.engine.columnimpl(self)
+ if self.default is not None:
+ self._init_items(self.default)
self._init_items(*self.args)
self.args = None
@@ -194,7 +197,7 @@ class Column(SchemaItem):
fk = None
else:
fk = self.foreign_key.copy()
- return Column(self.name, self.type, fk, self.sequence, key = self.key, primary_key = self.primary_key)
+ return Column(self.name, self.type, fk, self.default, key = self.key, primary_key = self.primary_key)
def _make_proxy(self, selectable, name = None):
"""creates a copy of this Column, initialized the way this Column is"""
@@ -202,7 +205,7 @@ class Column(SchemaItem):
fk = None
else:
fk = self.foreign_key.copy()
- c = Column(name or self.name, self.type, fk, self.sequence, key = name or self.key, primary_key = self.primary_key, hidden=self.hidden)
+ c = Column(name or self.name, self.type, fk, self.default, key = name or self.key, primary_key = self.primary_key, hidden=self.hidden)
c.table = selectable
c._orig = self.original
if not c.hidden:
@@ -211,8 +214,8 @@ class Column(SchemaItem):
return c
def accept_visitor(self, visitor):
- if self.sequence is not None:
- self.sequence.accept_visitor(visitor)
+ if self.default is not None:
+ self.default.accept_visitor(visitor)
if self.foreign_key is not None:
self.foreign_key.accept_visitor(visitor)
visitor.visit_column(self)
@@ -280,23 +283,32 @@ class ForeignKey(SchemaItem):
visitor.visit_foreign_key(self)
def _set_parent(self, column):
- if not isinstance(column, Column):
- raise "hi" + repr(type(column))
self.parent = column
self.parent.foreign_key = self
self.parent.table.foreign_keys.append(self)
+
+class DefaultGenerator(SchemaItem):
+ """represents a "default value generator" for a particular column in a particular
+ table. This could correspond to a constant, a callable function, or a SQL clause."""
+ def _set_parent(self, column):
+ self.column = column
+ self.column.default = self
+ def accept_visitor(self, visitor):
+ pass
+
+class ColumnDefault(DefaultGenerator):
+ def __init__(self, arg):
+ self.arg = arg
+ def accept_visitor(self, visitor):
+ return visitor.visit_column_default(self)
-class Sequence(SchemaItem):
+class Sequence(DefaultGenerator):
"""represents a sequence, which applies to Oracle and Postgres databases."""
- def __init__(self, name, func = None, start = None, increment = None, optional=False):
+ def __init__(self, name, start = None, increment = None, optional=False):
self.name = name
- self.func = func
self.start = start
self.increment = increment
self.optional=optional
- def _set_parent(self, column):
- self.column = column
- self.column.sequence = self
def accept_visitor(self, visitor):
return visitor.visit_sequence(self)
@@ -317,6 +329,7 @@ class SchemaVisitor(object):
def visit_column(self, column):pass
def visit_foreign_key(self, join):pass
def visit_index(self, index):pass
+ def visit_column_default(self, default):pass
def visit_sequence(self, sequence):pass