diff options
author | Paul Johnston <paj@pajhome.org.uk> | 2007-10-12 23:39:28 +0000 |
---|---|---|
committer | Paul Johnston <paj@pajhome.org.uk> | 2007-10-12 23:39:28 +0000 |
commit | aafe57ab5d3db83f982dee877d6427dfbe97dc2c (patch) | |
tree | 5830ce75b6dc7461145f04e3e2620a5e542417ab /lib/sqlalchemy/databases/access.py | |
parent | 2585a470c0c31254da8b3f51e927704e403d5d35 (diff) | |
download | sqlalchemy-aafe57ab5d3db83f982dee877d6427dfbe97dc2c.tar.gz |
A few fixes to the access dialect
Diffstat (limited to 'lib/sqlalchemy/databases/access.py')
-rw-r--r-- | lib/sqlalchemy/databases/access.py | 53 |
1 files changed, 22 insertions, 31 deletions
diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py index 9f4847c45..c6e6107bf 100644 --- a/lib/sqlalchemy/databases/access.py +++ b/lib/sqlalchemy/databases/access.py @@ -7,10 +7,9 @@ import random from sqlalchemy import sql, schema, types, exceptions, pool -from sqlalchemy.sql import compiler +from sqlalchemy.sql import compiler, expression from sqlalchemy.engine import default, base - class AcNumeric(types.Numeric): def result_processor(self, dialect): return None @@ -149,11 +148,13 @@ class AccessExecutionContext(default.DefaultExecutionContext): break if bool(tbl.has_sequence): - if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: - self.cursor.execute("SELECT @@identity AS lastrowid") - row = self.cursor.fetchone() - self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:] - # print "LAST ROW ID", self._last_inserted_ids + # TBD: for some reason _last_inserted_ids doesn't exist here + # (but it does at corresponding point in mssql???) + #if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: + self.cursor.execute("SELECT @@identity AS lastrowid") + row = self.cursor.fetchone() + self._last_inserted_ids = [int(row[0])] #+ self._last_inserted_ids[1:] + # print "LAST ROW ID", self._last_inserted_ids super(AccessExecutionContext, self).post_exec() @@ -177,7 +178,7 @@ class AccessDialect(default.DefaultDialect): } supports_sane_rowcount = False - + supports_sane_multi_rowcount = False def type_descriptor(self, typeobj): newobj = types.adapt_type(typeobj, self.colspecs) @@ -217,21 +218,6 @@ class AccessDialect(default.DefaultDialect): def last_inserted_ids(self): return self.context.last_inserted_ids - def compiler(self, statement, bindparams, **kwargs): - return AccessCompiler(self, statement, bindparams, **kwargs) - - def schemagenerator(self, *args, **kwargs): - return AccessSchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return AccessSchemaDropper(self, *args, **kwargs) - - def defaultrunner(self, connection, **kwargs): - return AccessDefaultRunner(connection, **kwargs) - - def preparer(self): - return AccessIdentifierPreparer(self) - def do_execute(self, cursor, statement, params, **kwargs): if params == {}: params = () @@ -254,7 +240,7 @@ class AccessDialect(default.DefaultDialect): except Exception, e: return False - def reflecttable(self, connection, table): + def reflecttable(self, connection, table, include_columns): # This is defined in the function, as it relies on win32com constants, # that aren't imported until dbapi method is called if not hasattr(self, 'ischema_names'): @@ -364,13 +350,11 @@ class AccessCompiler(compiler.DefaultCompiler): """Access uses "mod" instead of "%" """ return binary.operator == '%' and 'mod' or binary.operator - def visit_select(self, select): - """Label function calls, so they return a name in cursor.description""" - for i,c in enumerate(select._raw_columns): - if isinstance(c, sql._Function): - select._raw_columns[i] = c.label(c.name + "_" + hex(random.randint(0, 65535))[2:]) - - super(AccessCompiler, self).visit_select(select) + def label_select_column(self, select, column): + if isinstance(column, expression._Function): + return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:]) + else: + return super(AccessCompiler, self).label_select_column(select, column) function_rewrites = {'current_date': 'now', 'current_timestamp': 'now', @@ -418,9 +402,16 @@ class AccessDefaultRunner(base.DefaultRunner): pass class AccessIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = compiler.RESERVED_WORDS.copy() + reserved_words.update(['value', 'text']) def __init__(self, dialect): super(AccessIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') dialect = AccessDialect dialect.poolclass = pool.SingletonThreadPool +dialect.statement_compiler = AccessCompiler +dialect.schemagenerator = AccessSchemaGenerator +dialect.schemadropper = AccessSchemaDropper +dialect.preparer = AccessIdentifierPreparer +dialect.defaultrunner = AccessDefaultRunner |