diff options
Diffstat (limited to 'lib/sqlalchemy/databases/postgres.py')
-rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 48 |
1 files changed, 20 insertions, 28 deletions
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 74a3ef13f..29d84ad4d 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -4,11 +4,13 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import re, random, warnings, operator +import re, random, warnings -from sqlalchemy import sql, schema, ansisql, exceptions, util +from sqlalchemy import sql, schema, exceptions, util from sqlalchemy.engine import base, default -import sqlalchemy.types as sqltypes +from sqlalchemy.sql import compiler +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import types as sqltypes class PGInet(sqltypes.TypeEngine): @@ -220,9 +222,9 @@ class PGExecutionContext(default.DefaultExecutionContext): self._last_inserted_ids = [v for v in row] super(PGExecutionContext, self).post_exec() -class PGDialect(ansisql.ANSIDialect): +class PGDialect(default.DefaultDialect): def __init__(self, use_oids=False, server_side_cursors=False, **kwargs): - ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs) + default.DefaultDialect.__init__(self, default_paramstyle='pyformat', **kwargs) self.use_oids = use_oids self.server_side_cursors = server_side_cursors self.paramstyle = 'pyformat' @@ -249,15 +251,6 @@ class PGDialect(ansisql.ANSIDialect): def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - def compiler(self, statement, bindparams, **kwargs): - return PGCompiler(self, statement, bindparams, **kwargs) - - def schemagenerator(self, *args, **kwargs): - return PGSchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return PGSchemaDropper(self, *args, **kwargs) - def do_begin_twophase(self, connection, xid): self.do_begin(connection.connection) @@ -286,12 +279,6 @@ class PGDialect(ansisql.ANSIDialect): resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts")) return [row[0] for row in resultset] - def defaultrunner(self, context, **kwargs): - return PGDefaultRunner(context, **kwargs) - - def preparer(self): - return PGIdentifierPreparer(self) - def get_default_schema_name(self, connection): if not hasattr(self, '_default_schema_name'): self._default_schema_name = connection.scalar("select current_schema()", None) @@ -556,11 +543,11 @@ class PGDialect(ansisql.ANSIDialect): -class PGCompiler(ansisql.ANSICompiler): - operators = ansisql.ANSICompiler.operators.copy() +class PGCompiler(compiler.DefaultCompiler): + operators = compiler.DefaultCompiler.operators.copy() operators.update( { - operator.mod : '%%' + sql_operators.mod : '%%' } ) @@ -597,7 +584,7 @@ class PGCompiler(ansisql.ANSICompiler): else: return super(PGCompiler, self).for_update_clause(select) -class PGSchemaGenerator(ansisql.ANSISchemaGenerator): +class PGSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): @@ -620,13 +607,13 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() -class PGSchemaDropper(ansisql.ANSISchemaDropper): +class PGSchemaDropper(compiler.SchemaDropper): def visit_sequence(self, sequence): if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)): self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() -class PGDefaultRunner(ansisql.ANSIDefaultRunner): +class PGDefaultRunner(base.DefaultRunner): def get_column_default(self, column, isinsert=True): if column.primary_key: # passive defaults on primary keys have to be overridden @@ -642,7 +629,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) return self.connection.execute(exc).scalar() - return super(ansisql.ANSIDefaultRunner, self).get_column_default(column) + return super(PGDefaultRunner, self).get_column_default(column) def visit_sequence(self, seq): if not seq.optional: @@ -650,7 +637,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): else: return None -class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class PGIdentifierPreparer(compiler.IdentifierPreparer): def _fold_identifier_case(self, value): return value.lower() @@ -660,3 +647,8 @@ class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer): return value dialect = PGDialect +dialect.statement_compiler = PGCompiler +dialect.schemagenerator = PGSchemaGenerator +dialect.schemadropper = PGSchemaDropper +dialect.preparer = PGIdentifierPreparer +dialect.defaultrunner = PGDefaultRunner |