diff options
Diffstat (limited to 'lib/sqlalchemy/schema.py')
-rw-r--r-- | lib/sqlalchemy/schema.py | 41 |
1 files changed, 27 insertions, 14 deletions
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 60c42c25a..606bcf508 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -59,6 +59,7 @@ def _get_table_key(engine, name, schema): class TableSingleton(type): def __call__(self, name, engine, *args, **kwargs): try: + name = str(name) # in case of incoming unicode schema = kwargs.get('schema', None) autoload = kwargs.pop('autoload', False) redefine = kwargs.pop('redefine', False) @@ -151,15 +152,15 @@ class Table(SchemaItem): class Column(SchemaItem): """represents a column in a database table.""" def __init__(self, name, type, *args, **kwargs): - self.name = name + self.name = str(name) # in case of incoming unicode self.type = type self.args = args self.key = kwargs.pop('key', name) self.primary_key = kwargs.pop('primary_key', False) self.nullable = kwargs.pop('nullable', not self.primary_key) self.hidden = kwargs.pop('hidden', False) + self.default = kwargs.pop('default', None) self.foreign_key = None - self.sequence = None self._orig = None if len(kwargs): raise "Unknown arguments passed to Column: " + repr(kwargs.keys()) @@ -185,6 +186,8 @@ class Column(SchemaItem): self._impl = self.table.engine.columnimpl(self) + if self.default is not None: + self._init_items(self.default) self._init_items(*self.args) self.args = None @@ -194,7 +197,7 @@ class Column(SchemaItem): fk = None else: fk = self.foreign_key.copy() - return Column(self.name, self.type, fk, self.sequence, key = self.key, primary_key = self.primary_key) + return Column(self.name, self.type, fk, self.default, key = self.key, primary_key = self.primary_key) def _make_proxy(self, selectable, name = None): """creates a copy of this Column, initialized the way this Column is""" @@ -202,7 +205,7 @@ class Column(SchemaItem): fk = None else: fk = self.foreign_key.copy() - c = Column(name or self.name, self.type, fk, self.sequence, key = name or self.key, primary_key = self.primary_key, hidden=self.hidden) + c = Column(name or self.name, self.type, fk, self.default, key = name or self.key, primary_key = self.primary_key, hidden=self.hidden) c.table = selectable c._orig = self.original if not c.hidden: @@ -211,8 +214,8 @@ class Column(SchemaItem): return c def accept_visitor(self, visitor): - if self.sequence is not None: - self.sequence.accept_visitor(visitor) + if self.default is not None: + self.default.accept_visitor(visitor) if self.foreign_key is not None: self.foreign_key.accept_visitor(visitor) visitor.visit_column(self) @@ -280,23 +283,32 @@ class ForeignKey(SchemaItem): visitor.visit_foreign_key(self) def _set_parent(self, column): - if not isinstance(column, Column): - raise "hi" + repr(type(column)) self.parent = column self.parent.foreign_key = self self.parent.table.foreign_keys.append(self) + +class DefaultGenerator(SchemaItem): + """represents a "default value generator" for a particular column in a particular + table. This could correspond to a constant, a callable function, or a SQL clause.""" + def _set_parent(self, column): + self.column = column + self.column.default = self + def accept_visitor(self, visitor): + pass + +class ColumnDefault(DefaultGenerator): + def __init__(self, arg): + self.arg = arg + def accept_visitor(self, visitor): + return visitor.visit_column_default(self) -class Sequence(SchemaItem): +class Sequence(DefaultGenerator): """represents a sequence, which applies to Oracle and Postgres databases.""" - def __init__(self, name, func = None, start = None, increment = None, optional=False): + def __init__(self, name, start = None, increment = None, optional=False): self.name = name - self.func = func self.start = start self.increment = increment self.optional=optional - def _set_parent(self, column): - self.column = column - self.column.sequence = self def accept_visitor(self, visitor): return visitor.visit_sequence(self) @@ -317,6 +329,7 @@ class SchemaVisitor(object): def visit_column(self, column):pass def visit_foreign_key(self, join):pass def visit_index(self, index):pass + def visit_column_default(self, default):pass def visit_sequence(self, sequence):pass |