diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2005-12-16 07:18:27 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2005-12-16 07:18:27 +0000 |
commit | 6cdba110a49b701e36f93d82e8772d1909385175 (patch) | |
tree | c1b96732b7e8241db9d5a2262676db140d9f7030 /lib/sqlalchemy/databases/postgres.py | |
parent | 1f30247e22a4a3a14eb7f57261e289cc26e61bf3 (diff) | |
download | sqlalchemy-6cdba110a49b701e36f93d82e8772d1909385175.tar.gz |
factored "sequence" execution in postgres in oracle to be generalized to the SQLEngine, to also allow space for "defaults" that may be constants, python functions, or SQL functions/statements
Sequence schema object extends from a more generic "Default" object
ANSICompiled can convert positinal params back to a dictionary, but the whole issue of parameters and how the engine executes compiled objects with parameters should be revisited
mysql has fixes for its "rowid_column" being hidden else it screws up some query construction, also will not use AUTOINCREMENT unless the column is Integer
Diffstat (limited to 'lib/sqlalchemy/databases/postgres.py')
-rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 79 |
1 files changed, 28 insertions, 51 deletions
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 531e8af03..29590da0a 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -151,6 +151,9 @@ class PGSQLEngine(ansisql.ANSISQLEngine): def schemadropper(self, proxy, **params): return PGSchemaDropper(proxy, **params) + + def defaultrunner(self, proxy): + return PGDefaultRunner(proxy) def get_default_schema_name(self): if not hasattr(self, '_default_schema_name'): @@ -158,50 +161,22 @@ class PGSQLEngine(ansisql.ANSISQLEngine): return self._default_schema_name def last_inserted_ids(self): - # if we used sequences or already had all values for the last inserted row, - # return that list - if self.context.last_inserted_ids is not None: - return self.context.last_inserted_ids - - # else we have to use lastrowid and select the most recently inserted row - table = self.context.last_inserted_table - if self.context.lastrowid is not None and table is not None and len(table.primary_key): - row = sql.select(table.primary_key, table.rowid_column == self.context.lastrowid).execute().fetchone() - return [v for v in row] - else: - return None - - def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs): + return self.context.last_inserted_ids + + def pre_exec(self, proxy, statement, parameters, **kwargs): + return + + def post_exec(self, proxy, statement, parameters, compiled = None, **kwargs): if compiled is None: return - if getattr(compiled, "isinsert", False): - if isinstance(parameters, list): - plist = parameters - else: - plist = [parameters] - # inserts are usually one at a time. but if we got a list of parameters, - # it will calculate last_inserted_ids for just the last row in the list. - # TODO: why not make last_inserted_ids a 2D array since we have to explicitly sequence - # it or post-select anyway - for param in plist: - last_inserted_ids = [] - need_lastrowid=False - for primary_key in compiled.statement.table.primary_key: - if not param.has_key(primary_key.key) or param[primary_key.key] is None: - if primary_key.sequence is not None and not primary_key.sequence.optional: - if echo is True or self.echo: - self.log("select nextval('%s')" % primary_key.sequence.name) - cursor.execute("select nextval('%s')" % primary_key.sequence.name) - newid = cursor.fetchone()[0] - param[primary_key.key] = newid - last_inserted_ids.append(param[primary_key.key]) - else: - need_lastrowid = True - else: - last_inserted_ids.append(param[primary_key.key]) - if need_lastrowid: - self.context.last_inserted_ids = None - else: - self.context.last_inserted_ids = last_inserted_ids + if getattr(compiled, "isinsert", False) and self.context.last_inserted_ids is None: + table = compiled.statement.table + cursor = proxy() + if cursor.lastrowid is not None and table is not None and len(table.primary_key): + s = sql.select(table.primary_key, table.rowid_column == cursor.lastrowid) + c = s.compile() + cursor = proxy(str(c), c.get_params()) + row = cursor.fetchone() + self.context.last_inserted_ids = [v for v in row] def _executemany(self, c, statement, parameters): """we need accurate rowcounts for updates, inserts and deletes. psycopg2 is not nice enough @@ -212,12 +187,6 @@ class PGSQLEngine(ansisql.ANSISQLEngine): rowcount += c.rowcount self.context.rowcount = rowcount - def post_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs): - if compiled is None: return - if getattr(compiled, "isinsert", False): - table = compiled.statement.table - self.context.last_inserted_table = table - self.context.lastrowid = cursor.lastrowid def dbapi(self): return self.module @@ -237,7 +206,7 @@ class PGCompiler(ansisql.ANSICompiler): with autoincrement fields that require they not be present. so, put them all in for columns where sequence usage is defined.""" for c in insert.table.primary_key: - if c.sequence is not None and not c.sequence.optional: + if self.bindparams.get(c.key, None) is None and c.default is not None and not c.default.optional: self.bindparams[c.key] = None return ansisql.ANSICompiler.visit_insert(self, insert) @@ -254,7 +223,7 @@ class PGCompiler(ansisql.ANSICompiler): class PGSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): colspec = column.name - if column.primary_key and isinstance(column.type, types.Integer) and (column.sequence is None or column.sequence.optional): + if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or column.default.optional): colspec += " SERIAL" else: colspec += " " + column.type.get_col_spec() @@ -277,3 +246,11 @@ class PGSchemaDropper(ansisql.ANSISchemaDropper): if not sequence.optional: self.append("DROP SEQUENCE %s" % sequence.name) self.execute() + +class PGDefaultRunner(ansisql.ANSIDefaultRunner): + def visit_sequence(self, seq): + if not seq.optional: + c = self.proxy("select nextval('%s')" % seq.name) + return c.fetchone()[0] + else: + return None
\ No newline at end of file |