summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-02-05 00:19:14 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-02-05 00:19:14 +0000
commit1d20ecbb6d0f0f8cfbb0a54e2a3aaf6cead23ecb (patch)
tree7092c12ed6a625137361ad35448936197a2d3ff6
parentf0aa20cab011488b1cdecb5ed9bc68fc1ef1f73e (diff)
downloadsqlalchemy-1d20ecbb6d0f0f8cfbb0a54e2a3aaf6cead23ecb.tar.gz
started PassiveDefault, which is a "database-side" default. mapper will go
fetch the most recently inserted row if a table has PassiveDefault's set on it
-rw-r--r--lib/sqlalchemy/databases/postgres.py4
-rw-r--r--lib/sqlalchemy/engine.py12
-rw-r--r--lib/sqlalchemy/mapping/mapper.py8
-rw-r--r--lib/sqlalchemy/schema.py12
4 files changed, 34 insertions, 2 deletions
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index 2d6adc165..9122c2afa 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -275,7 +275,9 @@ class PGCompiler(ansisql.ANSICompiler):
class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, **kwargs):
colspec = column.name
- if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+ if isinstance(column.default, schema.PassiveDefault):
+ colspec += " DEFAULT " + column.default.text
+ elif column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
colspec += " SERIAL"
else:
colspec += " " + column.type.get_col_spec()
diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py
index d99e5eb6c..b00d97de0 100644
--- a/lib/sqlalchemy/engine.py
+++ b/lib/sqlalchemy/engine.py
@@ -135,6 +135,11 @@ class DefaultRunner(schema.SchemaVisitor):
else:
return None
+ def visit_passive_default(self, default):
+ """passive defaults by definition return None on the app side,
+ and are post-fetched to get the DB-side value"""
+ return None
+
def visit_sequence(self, seq):
"""sequences are not supported by default"""
return None
@@ -452,10 +457,13 @@ class SQLEngine(schema.SchemaEngine):
else:
plist = [parameters]
drunner = self.defaultrunner(proxy)
+ self.context.lastrow_has_defaults = False
for param in plist:
last_inserted_ids = []
need_lastrowid=False
for c in compiled.statement.table.c:
+ if isinstance(c.default, schema.PassiveDefault):
+ self.context.lastrow_has_defaults = True
if not param.has_key(c.key) or param[c.key] is None:
newid = drunner.get_column_default(c)
if newid is not None:
@@ -471,7 +479,9 @@ class SQLEngine(schema.SchemaEngine):
else:
self.context.last_inserted_ids = last_inserted_ids
-
+ def lastrow_has_defaults(self):
+ return self.context.lastrow_has_defaults
+
def pre_exec(self, proxy, compiled, parameters, **kwargs):
"""called by execute_compiled before the compiled statement is executed."""
pass
diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py
index 7e11f5ebe..4516ae7b3 100644
--- a/lib/sqlalchemy/mapping/mapper.py
+++ b/lib/sqlalchemy/mapping/mapper.py
@@ -578,6 +578,14 @@ class Mapper(object):
if self._getattrbycolumn(obj, col) is None:
self._setattrbycolumn(obj, col, primary_key[i])
i+=1
+ if table.engine.lastrow_has_defaults():
+ clause = sql.and_()
+ for p in self.pks_by_table[table]:
+ clause.clauses.append(p == self._getattrbycolumn(obj, p))
+ row = table.select(clause).execute().fetchone()
+ for c in table.c:
+ if self._getattrbycolumn(obj, col) is None:
+ self._setattrbycolumn(obj, col, row[c])
self.extension.after_insert(self, obj)
def delete_obj(self, objects, uow):
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index a5e6e0777..de672dc9e 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -417,6 +417,15 @@ class DefaultGenerator(SchemaItem):
self.column.default = self
def __repr__(self):
return "DefaultGenerator()"
+
+class PassiveDefault(DefaultGenerator):
+ """a default that takes effect on the database side"""
+ def __init__(self, text):
+ self.text = text
+ def accept_visitor(self, visitor):
+ return visitor_visit_passive_default(self)
+ def __repr__(self):
+ return "PassiveDefault(%s)" % repr(self.text)
class ColumnDefault(DefaultGenerator):
"""A plain default value on a column. this could correspond to a constant,
@@ -477,6 +486,9 @@ class SchemaVisitor(object):
def visit_index(self, index):
"""visit an Index (not implemented yet)."""
pass
+ def visit_passive_default(self, default):
+ """visit a passive default"""
+ pass
def visit_column_default(self, default):
"""visit a ColumnDefault."""
pass