diff options
author | Paul Johnston <paj@pajhome.org.uk> | 2007-08-05 23:13:25 +0000 |
---|---|---|
committer | Paul Johnston <paj@pajhome.org.uk> | 2007-08-05 23:13:25 +0000 |
commit | 98230f7c32ff0821984afc1aa4b736fa594c390d (patch) | |
tree | 13bf5e69d2edfc96b995a374280572186dc7b7b7 /lib/sqlalchemy/databases/access.py | |
parent | baedc0f7954e24f854c86208391d373953430adc (diff) | |
download | sqlalchemy-98230f7c32ff0821984afc1aa4b736fa594c390d.tar.gz |
Add initial version of MS Access support
Diffstat (limited to 'lib/sqlalchemy/databases/access.py')
-rw-r--r-- | lib/sqlalchemy/databases/access.py | 417 |
1 files changed, 417 insertions, 0 deletions
diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py new file mode 100644 index 000000000..ec16d6611 --- /dev/null +++ b/lib/sqlalchemy/databases/access.py @@ -0,0 +1,417 @@ +# access.py +# Copyright (C) 2007 Paul Johnston, paj@pajhome.org.uk +# Portions derived from jet2sql.py by Matt Keranen, mksql@yahoo.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import sys, string, re, datetime, random +from sqlalchemy import sql, engine, schema, ansisql, types, exceptions, pool +import sqlalchemy.engine.default as default + + +class AcNumeric(types.Numeric): + def convert_result_value(self, value, dialect): + return value + + def convert_bind_param(self, value, dialect): + if value is None: + # Not sure that this exception is needed + return value + else: + return str(value) + + def get_col_spec(self): + return "NUMERIC" + +class AcFloat(types.Float): + def get_col_spec(self): + return "FLOAT" + + def convert_bind_param(self, value, dialect): + """By converting to string, we can use Decimal types round-trip.""" + if not value is None: + return str(value) + return None + +class AcInteger(types.Integer): + def get_col_spec(self): + return "INTEGER" + +class AcTinyInteger(types.Integer): + def get_col_spec(self): + return "TINYINT" + +class AcSmallInteger(types.Smallinteger): + def get_col_spec(self): + return "SMALLINT" + +class AcDateTime(types.DateTime): + def __init__(self, *a, **kw): + super(AcDateTime, self).__init__(False) + + def get_col_spec(self): + return "DATETIME" + +class AcDate(types.Date): + def __init__(self, *a, **kw): + super(AcDate, self).__init__(False) + + def get_col_spec(self): + return "DATETIME" + +class AcText(types.TEXT): + def get_col_spec(self): + return "MEMO" + +class AcString(types.String): + def get_col_spec(self): + return "TEXT" + (self.length and ("(%d)" % self.length) or "") + +class AcUnicode(types.Unicode): + def get_col_spec(self): + return "TEXT" + (self.length and ("(%d)" % self.length) or "") + + def convert_bind_param(self, value, dialect): + return value + + def convert_result_value(self, value, dialect): + return value + +class AcChar(types.CHAR): + def get_col_spec(self): + return "TEXT" + (self.length and ("(%d)" % self.length) or "") + +class AcBinary(types.Binary): + def get_col_spec(self): + return "BINARY" + +class AcBoolean(types.Boolean): + def get_col_spec(self): + return "YESNO" + + def convert_result_value(self, value, dialect): + if value is None: + return None + return value and True or False + + def convert_bind_param(self, value, dialect): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + +class AcTimeStamp(types.TIMESTAMP): + def get_col_spec(self): + return "TIMESTAMP" + +def descriptor(): + return {'name':'access', + 'description':'Microsoft Access', + 'arguments':[ + ('user',"Database user name",None), + ('password',"Database password",None), + ('db',"Path to database file",None), + ]} + +class AccessExecutionContext(default.DefaultExecutionContext): + def _has_implicit_sequence(self, column): + if column.primary_key and column.autoincrement: + if isinstance(column.type, types.Integer) and not column.foreign_key: + if column.default is None or (isinstance(column.default, schema.Sequence) and \ + column.default.optional): + return True + return False + + def post_exec(self): + """If we inserted into a row with a COUNTER column, fetch the ID""" + + if self.compiled.isinsert: + tbl = self.compiled.statement.table + if not hasattr(tbl, 'has_sequence'): + tbl.has_sequence = None + for column in tbl.c: + if getattr(column, 'sequence', False) or self._has_implicit_sequence(column): + tbl.has_sequence = column + break + + if bool(tbl.has_sequence): + if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: + self.cursor.execute("SELECT @@identity AS lastrowid") + row = self.cursor.fetchone() + self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:] + # print "LAST ROW ID", self._last_inserted_ids + + super(AccessExecutionContext, self).post_exec() + + +class AccessDialect(ansisql.ANSIDialect): + colspecs = { + types.Unicode : AcUnicode, + types.Integer : AcInteger, + types.Smallinteger: AcSmallInteger, + types.Numeric : AcNumeric, + types.Float : AcFloat, + types.DateTime : AcDateTime, + types.Date : AcDate, + types.String : AcString, + types.Binary : AcBinary, + types.Boolean : AcBoolean, + types.TEXT : AcText, + types.CHAR: AcChar, + types.TIMESTAMP: AcTimeStamp, + } + + def type_descriptor(self, typeobj): + newobj = types.adapt_type(typeobj, self.colspecs) + return newobj + + def __init__(self, **params): + super(AccessDialect, self).__init__(**params) + self.text_as_varchar = False + self._dtbs = None + + def dbapi(cls): + import win32com.client + win32com.client.gencache.EnsureModule('{00025E01-0000-0000-C000-000000000046}', 0, 5, 0) + + global const, daoEngine + const = win32com.client.constants + daoEngine = win32com.client.Dispatch('DAO.DBEngine.36') + + import pyodbc as module + return module + dbapi = classmethod(dbapi) + + def create_connect_args(self, url): + opts = url.translate_connect_args(['host', 'database', 'username', 'password', 'port']) + connectors = ["Driver={Microsoft Access Driver (*.mdb)}"] + connectors.append("Dbq=%s" % opts["database"]) + user = opts.get("user") + if user: + connectors.append("UID=%s" % user) + connectors.append("PWD=%s" % opts.get("password", "")) + return [[";".join (connectors)], {}] + + def create_execution_context(self, *args, **kwargs): + return AccessExecutionContext(self, *args, **kwargs) + + def supports_sane_rowcount(self): + return False + + def last_inserted_ids(self): + return self.context.last_inserted_ids + + def compiler(self, statement, bindparams, **kwargs): + return AccessCompiler(self, statement, bindparams, **kwargs) + + def schemagenerator(self, *args, **kwargs): + return AccessSchemaGenerator(self, *args, **kwargs) + + def schemadropper(self, *args, **kwargs): + return AccessSchemaDropper(self, *args, **kwargs) + + def defaultrunner(self, connection, **kwargs): + return AccessDefaultRunner(connection, **kwargs) + + def preparer(self): + return AccessIdentifierPreparer(self) + + def do_execute(self, cursor, statement, params, **kwargs): + if params == {}: + params = () + super(AccessDialect, self).do_execute(cursor, statement, params, **kwargs) + + def _execute(self, c, statement, parameters): + try: + if parameters == {}: + parameters = () + c.execute(statement, parameters) + self.context.rowcount = c.rowcount + except Exception, e: + raise exceptions.SQLError(statement, parameters, e) + + def has_table(self, connection, tablename, schema=None): + # This approach seems to be more reliable that using DAO + try: + connection.execute('select top 1 * from [%s]' % tablename) + return True + except Exception, e: + return False + + def reflecttable(self, connection, table): + # This is defined in the function, as it relies on win32com constants, + # that aren't imported until dbapi method is called + if not hasattr(self, 'ischema_names'): + self.ischema_names = { + const.dbByte: AcBinary, + const.dbInteger: AcInteger, + const.dbLong: AcInteger, + const.dbSingle: AcFloat, + const.dbDouble: AcFloat, + const.dbDate: AcDateTime, + const.dbLongBinary: AcBinary, + const.dbMemo: AcText, + const.dbBoolean: AcBoolean, + const.dbText: AcUnicode, # All Access strings are unicode + } + + # A fresh DAO connection is opened for each reflection + # This is necessary, so we get the latest updates + opts = connection.engine.url.translate_connect_args(['host', 'database', 'username', 'password', 'port']) + dtbs = daoEngine.OpenDatabase(opts['database']) + + try: + for tbl in dtbs.TableDefs: + if tbl.Name.lower() == table.name.lower(): + break + else: + raise exceptions.NoSuchTableError(table.name) + + for col in tbl.Fields: + coltype = self.ischema_names[col.Type] + if col.Type == const.dbText: + coltype = coltype(col.Size) + + colargs = \ + { + 'nullable': not(col.Required or col.Attributes & const.dbAutoIncrField), + } + default = col.DefaultValue + + if col.Attributes & const.dbAutoIncrField: + colargs['default'] = schema.Sequence(col.Name + '_seq') + elif default: + if col.Type == const.dbBoolean: + default = default == 'Yes' and '1' or '0' + colargs['default'] = schema.PassiveDefault(sql.text(default)) + + table.append_column(schema.Column(col.Name, coltype, **colargs)) + + # TBD: check constraints + + # Find primary key columns first + for idx in tbl.Indexes: + if idx.Primary: + for col in idx.Fields: + thecol = table.c[col.Name] + table.primary_key.add(thecol) + if isinstance(thecol.type, AcInteger) and \ + not (thecol.default and isinstance(thecol.default.arg, schema.Sequence)): + thecol.autoincrement = False + + # Then add other indexes + for idx in tbl.Indexes: + if not idx.Primary: + if len(idx.Fields) == 1: + col = table.c[idx.Fields[0].Name] + if not col.primary_key: + col.index = True + col.unique = idx.Unique + else: + pass # TBD: multi-column indexes + + + for fk in dtbs.Relations: + if fk.ForeignTable != table.name: + continue + scols = [c.ForeignName for c in fk.Fields] + rcols = ['%s.%s' % (fk.Table, c.Name) for c in fk.Fields] + table.append_constraint(schema.ForeignKeyConstraint(scols, rcols)) + + finally: + dtbs.Close() + + def table_names(self, connection, schema): + # A fresh DAO connection is opened for each reflection + # This is necessary, so we get the latest updates + opts = connection.engine.url.translate_connect_args(['host', 'database', 'username', 'password', 'port']) + dtbs = daoEngine.OpenDatabase(opts['database']) + + names = [t.Name for t in dtbs.TableDefs if t.Name[:4] != "MSys" and t.Name[:4] <> "~TMP"] + dtbs.Close() + return names + + +class AccessCompiler(ansisql.ANSICompiler): + def visit_select_precolumns(self, select): + """Access puts TOP, it's version of LIMIT here """ + s = select.distinct and "DISTINCT " or "" + if select.limit: + s += "TOP %s " % (select.limit) + if select.offset: + raise exceptions.InvalidRequestError('Access does not support LIMIT with an offset') + return s + + def limit_clause(self, select): + """Limit in access is after the select keyword""" + return "" + + def binary_operator_string(self, binary): + """Access uses "mod" instead of "%" """ + return binary.operator == '%' and 'mod' or binary.operator + + def visit_select(self, select): + """Label function calls, so they return a name in cursor.description""" + for i,c in enumerate(select._raw_columns): + if isinstance(c, sql._Function): + select._raw_columns[i] = c.label(c.name + "_" + hex(random.randint(0, 65535))[2:]) + + super(AccessCompiler, self).visit_select(select) + + function_rewrites = {'current_date': 'now', + 'current_timestamp': 'now', + 'length': 'len', + } + def visit_function(self, func): + """Access function names differ from the ANSI SQL names; rewrite common ones""" + func.name = self.function_rewrites.get(func.name, func.name) + super(AccessCompiler, self).visit_function(func) + + def for_update_clause(self, select): + """FOR UPDATE is not supported by Access; silently ignore""" + return '' + + +class AccessSchemaGenerator(ansisql.ANSISchemaGenerator): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() + + # install a sequence if we have an implicit IDENTITY column + if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ + column.autoincrement and isinstance(column.type, types.Integer) and not column.foreign_key: + if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional): + column.sequence = schema.Sequence(column.name + '_seq') + + if not column.nullable: + colspec += " NOT NULL" + + if hasattr(column, 'sequence'): + column.table.has_sequence = column + colspec = self.preparer.format_column(column) + " counter" + else: + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + +class AccessSchemaDropper(ansisql.ANSISchemaDropper): + def visit_index(self, index): + self.append("\nDROP INDEX [%s].[%s]" % (index.table.name, index.name)) + self.execute() + +class AccessDefaultRunner(ansisql.ANSIDefaultRunner): + pass + +class AccessIdentifierPreparer(ansisql.ANSIIdentifierPreparer): + def __init__(self, dialect): + super(AccessIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') + + +dialect = AccessDialect +dialect.poolclass = pool.SingletonThreadPool |