diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/base.py')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 69 |
1 files changed, 34 insertions, 35 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 19d9224e2..0bc5f08b0 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -332,40 +332,6 @@ class PGDDLCompiler(compiler.DDLCompiler): return text -class PGDefaultRunner(base.DefaultRunner): - - def get_column_default(self, column, isinsert=True): - if column.primary_key: - if (isinstance(column.server_default, schema.DefaultClause) and - column.server_default.arg is not None): - - # pre-execute passive defaults on primary key columns - return self.execute_string("select %s" % column.server_default.arg) - - elif column is column.table._autoincrement_column \ - and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): - - # execute the sequence associated with a SERIAL primary key column. - # for non-primary-key SERIAL, the ID just generates server side. - sch = column.table.schema - - if sch is not None: - exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) - else: - exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) - - return self.execute_string(exc) - - return super(PGDefaultRunner, self).get_column_default(column) - - def visit_sequence(self, seq): - if not seq.optional: - return self.execute_string(("select nextval('%s')" % \ - self.dialect.identifier_preparer.format_sequence(seq))) - else: - return None - - class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_INET(self, type_): return "INET" @@ -438,6 +404,39 @@ class PGInspector(reflection.Inspector): info_cache=self.info_cache) + +class PGExecutionContext(default.DefaultExecutionContext): + def fire_sequence(self, seq): + if not seq.optional: + return self._execute_scalar(("select nextval('%s')" % \ + self.dialect.identifier_preparer.format_sequence(seq))) + else: + return None + + def get_insert_default(self, column): + if column.primary_key: + if (isinstance(column.server_default, schema.DefaultClause) and + column.server_default.arg is not None): + + # pre-execute passive defaults on primary key columns + return self._execute_scalar("select %s" % column.server_default.arg) + + elif column is column.table._autoincrement_column \ + and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + + # execute the sequence associated with a SERIAL primary key column. + # for non-primary-key SERIAL, the ID just generates server side. + sch = column.table.schema + + if sch is not None: + exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) + else: + exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) + + return self._execute_scalar(exc) + + return super(PGExecutionContext, self).get_insert_default(column) + class PGDialect(default.DefaultDialect): name = 'postgresql' supports_alter = True @@ -459,7 +458,7 @@ class PGDialect(default.DefaultDialect): ddl_compiler = PGDDLCompiler type_compiler = PGTypeCompiler preparer = PGIdentifierPreparer - defaultrunner = PGDefaultRunner + execution_ctx_cls = PGExecutionContext inspector = PGInspector isolation_level = None |