diff options
-rw-r--r-- | CHANGES | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/firebird.py | 87 | ||||
-rw-r--r-- | test/dialect/firebird.py | 109 |
3 files changed, 197 insertions, 4 deletions
@@ -66,6 +66,9 @@ user_defined_state incompatible; previously the extensions of last mapper defined would receive these events. +- firebird + - Added support for returning values from inserts (2.0+ only), + updates and deletes (2.1+ only). 0.4.6 ===== @@ -1573,7 +1576,7 @@ user_defined_state - The no-arg ResultProxy._row_processor() is now the class attribute `_process_row`. -- Added support for returning values from inserts and udpates for +- Added support for returning values from inserts and updates for PostgreSQL 8.2+. [ticket:797] - PG reflection, upon seeing the default schema name being used explicitly diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 948e001d5..d3662ccbf 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -78,6 +78,18 @@ connections are active, the following setting may alleviate the problem:: # Force SA to use a single connection per thread dialect.poolclass = pool.SingletonThreadPool +RETURNING support +----------------- + +Firebird 2.0 supports returning a result set from inserts, and 2.1 extends +that to deletes and updates. + +To use this pass the column/expression list to the ``firebird_returning`` +parameter when creating the queries:: + + raises = tbl.update(empl.c.sales > 100, values=dict(salary=empl.c.salary * 1.1), + firebird_returning=[empl.c.id, empl.c.salary]).execute().fetchall() + .. [#] Well, that is not the whole story, as the client may still ask a different (lower) dialect... @@ -87,7 +99,7 @@ connections are active, the following setting may alleviate the problem:: """ -import datetime +import datetime, re from sqlalchemy import exc, schema, types as sqltypes, sql, util from sqlalchemy.engine import base, default @@ -261,8 +273,44 @@ def descriptor(): ]} +SELECT_RE = re.compile( + r'\s*(?:SELECT|(UPDATE|INSERT|DELETE))', + re.I | re.UNICODE) + +RETURNING_RE = re.compile( + 'RETURNING', + re.I | re.UNICODE) + +# This finds if the RETURNING is not inside a quoted/commented values. Handles string literals, +# quoted identifiers, dollar quotes, SQL comments and C style multiline comments. This does not +# handle correctly nested C style quotes, lets hope no one does the following: +# UPDATE tbl SET x=y /* foo /* bar */ RETURNING */ +RETURNING_QUOTED_RE = re.compile( + """\s*(?:UPDATE|INSERT|DELETE)\s + (?: # handle quoted and commented tokens separately + [^'"$/-] # non quote/comment character + | -(?!-) # a dash that does not begin a comment + | /(?!\*) # a slash that does not begin a comment + | "(?:[^"]|"")*" # quoted literal + | '(?:[^']|'')*' # quoted string + | --[^\\n]*(?=\\n) # SQL comment, leave out line ending as that counts as whitespace + # for the returning token + | /\*([^*]|\*(?!/))*\*/ # C style comment, doesn't handle nesting + )* + \sRETURNING\s""", re.I | re.UNICODE | re.VERBOSE) + +RETURNING_KW_NAME = 'firebird_returning' + class FBExecutionContext(default.DefaultExecutionContext): - pass + def returns_rows_text(self, statement): + m = SELECT_RE.match(statement) + return m and (not m.group(1) or (RETURNING_RE.search(statement) + and RETURNING_QUOTED_RE.match(statement))) + + def returns_rows_compiled(self, compiled): + return (isinstance(compiled.statement, sql.expression.Selectable) or + ((compiled.isupdate or compiled.isinsert or compiler.isdelete) and + RETURNING_KW_NAME in compiled.statement.kwargs)) class FBDialect(default.DefaultDialect): @@ -629,6 +677,41 @@ class FBCompiler(sql.compiler.DefaultCompiler): return self.LENGTH_FUNCTION_NAME + '%(expr)s' return super(FBCompiler, self).function_string(func) + def _append_returning(self, text, stmt): + returning_cols = stmt.kwargs[RETURNING_KW_NAME] + def flatten_columnlist(collist): + for c in collist: + if isinstance(c, sql.expression.Selectable): + for co in c.columns: + yield co + else: + yield c + columns = [self.process(c, render_labels=True) + for c in flatten_columnlist(returning_cols)] + text += ' RETURNING ' + ', '.join(columns) + return text + + def visit_update(self, update_stmt): + text = super(FBCompiler, self).visit_update(update_stmt) + if RETURNING_KW_NAME in update_stmt.kwargs: + return self._append_returning(text, update_stmt) + else: + return text + + def visit_insert(self, insert_stmt): + text = super(FBCompiler, self).visit_insert(insert_stmt) + if RETURNING_KW_NAME in insert_stmt.kwargs: + return self._append_returning(text, insert_stmt) + else: + return text + + def visit_delete(self, delete_stmt): + text = super(FBCompiler, self).visit_delete(delete_stmt) + if RETURNING_KW_NAME in delete_stmt.kwargs: + return self._append_returning(text, delete_stmt) + else: + return text + class FBSchemaGenerator(sql.compiler.SchemaGenerator): """Firebird syntactic idiosincrasies""" diff --git a/test/dialect/firebird.py b/test/dialect/firebird.py index da6cc6970..4b4b9fd7f 100644 --- a/test/dialect/firebird.py +++ b/test/dialect/firebird.py @@ -89,8 +89,114 @@ class CompileTest(TestBase, AssertsCompiledSQL): self.assert_compile(func.substring('abc', 1, 2), "SUBSTRING(:substring_1 FROM :substring_2 FOR :substring_3)") self.assert_compile(func.substring('abc', 1), "SUBSTRING(:substring_1 FROM :substring_2)") -class MiscFBTests(TestBase): + def test_update_returning(self): + table1 = table('mytable', + column('myid', Integer), + column('name', String(128)), + column('description', String(128)), + ) + + u = update(table1, values=dict(name='foo'), firebird_returning=[table1.c.myid, table1.c.name]) + self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING mytable.myid, mytable.name") + + u = update(table1, values=dict(name='foo'), firebird_returning=[table1]) + self.assert_compile(u, "UPDATE mytable SET name=:name "\ + "RETURNING mytable.myid, mytable.name, mytable.description") + + u = update(table1, values=dict(name='foo'), firebird_returning=[func.length(table1.c.name)]) + self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name)") + + def test_insert_returning(self): + table1 = table('mytable', + column('myid', Integer), + column('name', String(128)), + column('description', String(128)), + ) + + i = insert(table1, values=dict(name='foo'), firebird_returning=[table1.c.myid, table1.c.name]) + self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING mytable.myid, mytable.name") + + i = insert(table1, values=dict(name='foo'), firebird_returning=[table1]) + self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) "\ + "RETURNING mytable.myid, mytable.name, mytable.description") + + i = insert(table1, values=dict(name='foo'), firebird_returning=[func.length(table1.c.name)]) + self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name)") + + +class ReturningTest(TestBase, AssertsExecutionResults): + __only_on__ = 'firebird' + + @testing.exclude('firebird', '<', (2, 1), '2.1+ feature') + def test_update_returning(self): + meta = MetaData(testing.db) + table = Table('tables', meta, + Column('id', Integer, Sequence('gen_tables_id'), primary_key=True), + Column('persons', Integer), + Column('full', Boolean) + ) + table.create() + try: + table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) + + result = table.update(table.c.persons > 4, dict(full=True), firebird_returning=[table.c.id]).execute() + self.assertEqual(result.fetchall(), [(1,)]) + + result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() + self.assertEqual(result2.fetchall(), [(1,True),(2,False)]) + finally: + table.drop() + @testing.exclude('firebird', '<', (2, 0), '2.0+ feature') + def test_insert_returning(self): + meta = MetaData(testing.db) + table = Table('tables', meta, + Column('id', Integer, Sequence('gen_tables_id'), primary_key=True), + Column('persons', Integer), + Column('full', Boolean) + ) + table.create() + try: + result = table.insert(firebird_returning=[table.c.id]).execute({'persons': 1, 'full': False}) + + self.assertEqual(result.fetchall(), [(1,)]) + + # Multiple inserts only return the last row + result2 = table.insert(firebird_returning=[table]).execute( + [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}]) + + self.assertEqual(result2.fetchall(), [(3,3,True)]) + + result3 = table.insert(firebird_returning=[table.c.id]).execute({'persons': 4, 'full': False}) + self.assertEqual([dict(row) for row in result3], [{'ID':4}]) + + result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, 1) returning persons') + self.assertEqual([dict(row) for row in result4], [{'PERSONS': 10}]) + finally: + table.drop() + + @testing.exclude('firebird', '<', (2, 1), '2.1+ feature') + def test_delete_returning(self): + meta = MetaData(testing.db) + table = Table('tables', meta, + Column('id', Integer, Sequence('gen_tables_id'), primary_key=True), + Column('persons', Integer), + Column('full', Boolean) + ) + table.create() + try: + table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) + + result = table.delete(table.c.persons > 4, dict(full=True), firebird_returning=[table.c.id]).execute() + self.assertEqual(result.fetchall(), [(1,)]) + + result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() + self.assertEqual(result2.fetchall(), [(2,False),]) + finally: + table.drop() + + +class MiscFBTests(TestBase): __only_on__ = 'firebird' def test_strlen(self): @@ -117,5 +223,6 @@ class MiscFBTests(TestBase): version = testing.db.dialect.server_version_info(testing.db.connect()) assert len(version) == 3, "Got strange version info: %s" % repr(version) + if __name__ == '__main__': testenv.main() |