summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-08-21 04:38:51 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-08-21 04:38:51 +0000
commit4ca1de17832b49423dd448f7becc600a459e8cbf (patch)
treea1700ec3a0937efe9f898d9892a7312dee3341c9
parent435f7b1c381c6f8af34c1ca97d42e333d2b40f7c (diff)
downloadsqlalchemy-4ca1de17832b49423dd448f7becc600a459e8cbf.tar.gz
working on sequence quoting support....
-rw-r--r--lib/sqlalchemy/ansisql.py19
-rw-r--r--lib/sqlalchemy/databases/postgres.py4
-rw-r--r--lib/sqlalchemy/schema.py24
3 files changed, 37 insertions, 10 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index 53c6db6c4..f77e855e4 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -774,7 +774,15 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor):
self.__strings[column] = self._quote_identifier(column.name)
else:
self.__strings[column] = column.name
-
+
+ def visit_sequence(self, sequence):
+ if sequence in self.__visited:
+ return
+ if sequence.quote or self._requires_quotes(sequence.name, sequence.natural_case):
+ self.__strings[sequence] = self._quote_identifier(sequence.name)
+ else:
+ self.__strings[sequence] = sequence.name
+
def __analyze_identifiers(self, obj):
"""insure that each object we encounter is analyzed only once for its lifetime."""
if obj in self.__visited:
@@ -782,7 +790,11 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor):
if isinstance(obj, schema.SchemaItem):
obj.accept_schema_visitor(self)
self.__visited[obj] = True
-
+
+ def __prepare_sequence(self, sequence):
+ self.__analyze_identifiers(sequence)
+ return self.__strings.get(sequence, sequence.name)
+
def __prepare_table(self, table, use_schema=False):
self.__analyze_identifiers(table)
tablename = self.__strings.get(table, (table.name, None))[0]
@@ -798,6 +810,9 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor):
else:
return self.__strings.get(column, column.name)
+ def format_sequence(self, sequence):
+ return self.__prepare_sequence(sequence)
+
def format_table(self, table, use_schema=True):
"""Prepare a quoted table and schema name"""
return self.__prepare_table(table, use_schema=use_schema)
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index 887c593e1..1bcf83409 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -530,6 +530,8 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
return c.fetchone()[0]
elif isinstance(column.type, sqltypes.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
sch = column.table.schema
+ # TODO: this has to build into the Sequence object so we can get the quoting
+ # logic from it
if sch is not None:
exc = "select nextval('%s.%s_%s_seq')" % (sch, column.table.name, column.name)
else:
@@ -543,7 +545,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
def visit_sequence(self, seq):
if not seq.optional:
- c = self.proxy("select nextval('%s')" % seq.name)
+ c = self.proxy("select nextval('%s')" % seq.name) #TODO: self.dialect.preparer.format_sequence(seq))
return c.fetchone()[0]
else:
return None
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 2bf1627dd..169308445 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -169,8 +169,11 @@ class Table(SchemaItem, sql.TableClause):
self.owner = kwargs.pop('owner', None)
self.quote = kwargs.pop('quote', False)
self.quote_schema = kwargs.pop('quote_schema', False)
- self.natural_case = kwargs.pop('natural_case', True)
- self.natural_case_schema = kwargs.pop('natural_case_schema', True)
+ default_natural_case = metadata.natural_case
+ if default_natural_case is None:
+ default_natural_case = True
+ self.natural_case = kwargs.pop('natural_case', default_natural_case)
+ self.natural_case_schema = kwargs.pop('natural_case_schema', default_natural_case)
self.kwargs = kwargs
def _set_primary_key(self, pk):
@@ -403,6 +406,8 @@ class Column(SchemaItem, sql.ColumnClause):
if getattr(self, 'table', None) is not None:
raise exceptions.ArgumentError("this Column already has a table!")
table.append_column(self)
+ if self.table.metadata.natural_case is not None:
+ self.natural_case = self.table.metadata.natural_case
if self.index or self.unique:
table.append_index_column(self, index=self.index,
unique=self.unique)
@@ -595,12 +600,14 @@ class ColumnDefault(DefaultGenerator):
class Sequence(DefaultGenerator):
"""represents a sequence, which applies to Oracle and Postgres databases."""
- def __init__(self, name, start = None, increment = None, optional=False, **kwargs):
+ def __init__(self, name, start = None, increment = None, optional=False, quote=False, natural_case=True, **kwargs):
super(Sequence, self).__init__(**kwargs)
self.name = name
self.start = start
self.increment = increment
self.optional=optional
+ self.natural_case = natural_case
+ self.quote = quote
def __repr__(self):
return "Sequence(%s)" % string.join(
[repr(self.name)] +
@@ -609,6 +616,8 @@ class Sequence(DefaultGenerator):
def _set_parent(self, column):
super(Sequence, self)._set_parent(column)
column.sequence = self
+ if column.metadata.natural_case is not None:
+ self.natural_case = column.metadata.natural_case
def create(self):
self.engine.create(self)
return self
@@ -763,10 +772,11 @@ class Index(SchemaItem):
class MetaData(SchemaItem):
"""represents a collection of Tables and their associated schema constructs."""
- def __init__(self, name=None):
+ def __init__(self, name=None, natural_case=None, **kwargs):
# a dictionary that stores Table objects keyed off their name (and possibly schema name)
self.tables = {}
self.name = name
+ self.natural_case = natural_case
def is_bound(self):
return False
def clear(self):
@@ -850,7 +860,7 @@ class MetaData(SchemaItem):
class BoundMetaData(MetaData):
"""builds upon MetaData to provide the capability to bind to an Engine implementation."""
def __init__(self, engine_or_url, name=None, **kwargs):
- super(BoundMetaData, self).__init__(name)
+ super(BoundMetaData, self).__init__(name, **kwargs)
if isinstance(engine_or_url, str):
self._engine = sqlalchemy.create_engine(engine_or_url, **kwargs)
else:
@@ -861,8 +871,8 @@ class BoundMetaData(MetaData):
class DynamicMetaData(MetaData):
"""builds upon MetaData to provide the capability to bind to multiple Engine implementations
on a dynamically alterable, thread-local basis."""
- def __init__(self, name=None, threadlocal=True):
- super(DynamicMetaData, self).__init__(name)
+ def __init__(self, name=None, threadlocal=True, **kwargs):
+ super(DynamicMetaData, self).__init__(name, **kwargs)
if threadlocal:
self.context = util.ThreadLocal()
else: