summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/databases/postgres.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2005-12-16 07:18:27 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2005-12-16 07:18:27 +0000
commit6cdba110a49b701e36f93d82e8772d1909385175 (patch)
treec1b96732b7e8241db9d5a2262676db140d9f7030 /lib/sqlalchemy/databases/postgres.py
parent1f30247e22a4a3a14eb7f57261e289cc26e61bf3 (diff)
downloadsqlalchemy-6cdba110a49b701e36f93d82e8772d1909385175.tar.gz
factored "sequence" execution in postgres in oracle to be generalized to the SQLEngine, to also allow space for "defaults" that may be constants, python functions, or SQL functions/statements
Sequence schema object extends from a more generic "Default" object ANSICompiled can convert positinal params back to a dictionary, but the whole issue of parameters and how the engine executes compiled objects with parameters should be revisited mysql has fixes for its "rowid_column" being hidden else it screws up some query construction, also will not use AUTOINCREMENT unless the column is Integer
Diffstat (limited to 'lib/sqlalchemy/databases/postgres.py')
-rw-r--r--lib/sqlalchemy/databases/postgres.py79
1 files changed, 28 insertions, 51 deletions
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index 531e8af03..29590da0a 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -151,6 +151,9 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
def schemadropper(self, proxy, **params):
return PGSchemaDropper(proxy, **params)
+
+ def defaultrunner(self, proxy):
+ return PGDefaultRunner(proxy)
def get_default_schema_name(self):
if not hasattr(self, '_default_schema_name'):
@@ -158,50 +161,22 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
return self._default_schema_name
def last_inserted_ids(self):
- # if we used sequences or already had all values for the last inserted row,
- # return that list
- if self.context.last_inserted_ids is not None:
- return self.context.last_inserted_ids
-
- # else we have to use lastrowid and select the most recently inserted row
- table = self.context.last_inserted_table
- if self.context.lastrowid is not None and table is not None and len(table.primary_key):
- row = sql.select(table.primary_key, table.rowid_column == self.context.lastrowid).execute().fetchone()
- return [v for v in row]
- else:
- return None
-
- def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
+ return self.context.last_inserted_ids
+
+ def pre_exec(self, proxy, statement, parameters, **kwargs):
+ return
+
+ def post_exec(self, proxy, statement, parameters, compiled = None, **kwargs):
if compiled is None: return
- if getattr(compiled, "isinsert", False):
- if isinstance(parameters, list):
- plist = parameters
- else:
- plist = [parameters]
- # inserts are usually one at a time. but if we got a list of parameters,
- # it will calculate last_inserted_ids for just the last row in the list.
- # TODO: why not make last_inserted_ids a 2D array since we have to explicitly sequence
- # it or post-select anyway
- for param in plist:
- last_inserted_ids = []
- need_lastrowid=False
- for primary_key in compiled.statement.table.primary_key:
- if not param.has_key(primary_key.key) or param[primary_key.key] is None:
- if primary_key.sequence is not None and not primary_key.sequence.optional:
- if echo is True or self.echo:
- self.log("select nextval('%s')" % primary_key.sequence.name)
- cursor.execute("select nextval('%s')" % primary_key.sequence.name)
- newid = cursor.fetchone()[0]
- param[primary_key.key] = newid
- last_inserted_ids.append(param[primary_key.key])
- else:
- need_lastrowid = True
- else:
- last_inserted_ids.append(param[primary_key.key])
- if need_lastrowid:
- self.context.last_inserted_ids = None
- else:
- self.context.last_inserted_ids = last_inserted_ids
+ if getattr(compiled, "isinsert", False) and self.context.last_inserted_ids is None:
+ table = compiled.statement.table
+ cursor = proxy()
+ if cursor.lastrowid is not None and table is not None and len(table.primary_key):
+ s = sql.select(table.primary_key, table.rowid_column == cursor.lastrowid)
+ c = s.compile()
+ cursor = proxy(str(c), c.get_params())
+ row = cursor.fetchone()
+ self.context.last_inserted_ids = [v for v in row]
def _executemany(self, c, statement, parameters):
"""we need accurate rowcounts for updates, inserts and deletes. psycopg2 is not nice enough
@@ -212,12 +187,6 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
rowcount += c.rowcount
self.context.rowcount = rowcount
- def post_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
- if compiled is None: return
- if getattr(compiled, "isinsert", False):
- table = compiled.statement.table
- self.context.last_inserted_table = table
- self.context.lastrowid = cursor.lastrowid
def dbapi(self):
return self.module
@@ -237,7 +206,7 @@ class PGCompiler(ansisql.ANSICompiler):
with autoincrement fields that require they not be present. so,
put them all in for columns where sequence usage is defined."""
for c in insert.table.primary_key:
- if c.sequence is not None and not c.sequence.optional:
+ if self.bindparams.get(c.key, None) is None and c.default is not None and not c.default.optional:
self.bindparams[c.key] = None
return ansisql.ANSICompiler.visit_insert(self, insert)
@@ -254,7 +223,7 @@ class PGCompiler(ansisql.ANSICompiler):
class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, **kwargs):
colspec = column.name
- if column.primary_key and isinstance(column.type, types.Integer) and (column.sequence is None or column.sequence.optional):
+ if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or column.default.optional):
colspec += " SERIAL"
else:
colspec += " " + column.type.get_col_spec()
@@ -277,3 +246,11 @@ class PGSchemaDropper(ansisql.ANSISchemaDropper):
if not sequence.optional:
self.append("DROP SEQUENCE %s" % sequence.name)
self.execute()
+
+class PGDefaultRunner(ansisql.ANSIDefaultRunner):
+ def visit_sequence(self, seq):
+ if not seq.optional:
+ c = self.proxy("select nextval('%s')" % seq.name)
+ return c.fetchone()[0]
+ else:
+ return None \ No newline at end of file