diff options
author | Jason Kirtland <jek@discorporate.us> | 2007-10-23 07:38:07 +0000 |
---|---|---|
committer | Jason Kirtland <jek@discorporate.us> | 2007-10-23 07:38:07 +0000 |
commit | 6378c347994c902f7d4e65e54f2b76d01ce603d2 (patch) | |
tree | 1953746106c9fce1f53c16d2638923db5e7f9e7f /lib/sqlalchemy/databases/maxdb.py | |
parent | 21c6fa79b1e5b19c444c9cdc125d67825759330d (diff) | |
download | sqlalchemy-6378c347994c902f7d4e65e54f2b76d01ce603d2.tar.gz |
- Added initial version of MaxDB dialect.
- All optional test Sequences are now optional=True
Diffstat (limited to 'lib/sqlalchemy/databases/maxdb.py')
-rw-r--r-- | lib/sqlalchemy/databases/maxdb.py | 1083 |
1 files changed, 1083 insertions, 0 deletions
diff --git a/lib/sqlalchemy/databases/maxdb.py b/lib/sqlalchemy/databases/maxdb.py new file mode 100644 index 000000000..fcf04bec9 --- /dev/null +++ b/lib/sqlalchemy/databases/maxdb.py @@ -0,0 +1,1083 @@ +# maxdb.py +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Support for the MaxDB database. + +TODO: module docs! + +Overview +-------- + +The ``maxdb`` dialect is **experimental** and has only been tested on 7.6.03.007 +and 7.6.00.037. Of these, **only 7.6.03.007 will work** with SQLAlchemy's ORM. +The earlier version has severe ``LEFT JOIN`` limitations and will return +incorrect results from even very simple ORM queries. + +Only the native Python DB-API is currently supported. ODBC driver support +is a future enhancement. + +Implementation Notes +-------------------- + +Also check the DatabaseNotes page on the wiki for detailed information. + +For 'somecol.in_([])' to work, the IN operator's generation must be changed +to cast 'NULL' to a numeric, i.e. NUM(NULL). The DB-API doesn't accept a +bind parameter there, so that particular generation must inline the NULL value, +which depends on [ticket:807]. + +The DB-API is very picky about where bind params may be used in queries. + +Bind params for some functions (e.g. MOD) need type information supplied. +The dialect does not yet do this automatically. + +Max will occasionally throw up 'bad sql, compile again' exceptions for +perfectly valid SQL. The dialect does not currently handle these, more +research is needed. + +MaxDB 7.5 and Sap DB <= 7.4 reportedly do not support schemas. A very +slightly different version of this dialect would be required to support +those versions, and can easily be added if there is demand. Some other +required components such as an Max-aware 'old oracle style' join compiler +(thetas with (+) outer indicators) are already done and available for +integration- email the devel list if you're interested in working on +this. +""" + +import datetime, itertools, re, warnings + +from sqlalchemy import exceptions, schema, sql, util +from sqlalchemy.sql import operators as sql_operators, expression as sql_expr +from sqlalchemy.sql import compiler, visitors +from sqlalchemy.engine import base as engine_base, default +from sqlalchemy import types as sqltypes + + +__all__ = [ + 'MaxString', 'MaxUnicode', 'MaxChar', 'MaxText', 'MaxInteger', + 'MaxSmallInteger', 'MaxNumeric', 'MaxFloat', 'MaxTimestamp', + 'MaxDate', 'MaxTime', 'MaxBoolean', 'MaxBlob', + ] + + +class _StringType(sqltypes.String): + _type = None + + def __init__(self, length=None, encoding=None, **kw): + super(_StringType, self).__init__(length=length, **kw) + self.encoding = encoding + + def get_col_spec(self): + if self.length is None: + spec = 'LONG' + else: + spec = '%s(%s)' % (self._type, self.length) + + if self.encoding is not None: + spec = ' '.join([spec, self.encoding.upper()]) + return spec + + def bind_processor(self, dialect): + if self.encoding == 'unicode': + return None + else: + def process(value): + if isinstance(value, unicode): + return value.encode(dialect.encoding) + else: + return value + return process + + def result_processor(self, dialect): + def process(value): + while True: + if value is None: + return None + elif isinstance(value, unicode): + return value + elif isinstance(value, str): + if self.convert_unicode or dialect.convert_unicode: + return value.decode(dialect.encoding) + else: + return value + elif hasattr(value, 'read'): + # some sort of LONG, snarf and retry + value = value.read(value.remainingLength()) + continue + else: + # unexpected type, return as-is + return value + return process + + +class MaxString(_StringType): + _type = 'VARCHAR' + + def __init__(self, *a, **kw): + super(MaxString, self).__init__(*a, **kw) + + +class MaxUnicode(_StringType): + _type = 'VARCHAR' + + def __init__(self, length=None, **kw): + super(MaxUnicode, self).__init__(length=length, encoding='unicode') + + +class MaxChar(_StringType): + _type = 'CHAR' + + +class MaxText(_StringType): + _type = 'LONG' + + def __init__(self, *a, **kw): + super(MaxText, self).__init__(*a, **kw) + + def get_col_spec(self): + spec = 'LONG' + if self.encoding is not None: + spec = ' '.join((spec, self.encoding)) + elif self.convert_unicode: + spec = ' '.join((spec, 'UNICODE')) + + return spec + + +class MaxInteger(sqltypes.Integer): + def get_col_spec(self): + return 'INTEGER' + + +class MaxSmallInteger(MaxInteger): + def get_col_spec(self): + return 'SMALLINT' + + +class MaxNumeric(sqltypes.Numeric): + """The NUMERIC (also FIXED, DECIMAL) data type.""" + + def get_col_spec(self): + if self.length and self.precision: + return 'NUMERIC(%s, %s)' % (self.precision, self.length) + elif self.length: + return 'NUMERIC(%s)' % self.length + else: + return 'INTEGER' + + +class MaxFloat(sqltypes.Float): + """The FLOAT data type.""" + + def get_col_spec(self): + if self.precision is None: + return 'FLOAT' + else: + return 'FLOAT(%s)' % (self.precision,) + + +class MaxTimestamp(sqltypes.DateTime): + def get_col_spec(self): + return 'TIMESTAMP' + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + elif isinstance(value, basestring): + return value + elif dialect.datetimeformat == 'internal': + ms = getattr(value, 'microsecond', 0) + return value.strftime("%Y%m%d%H%M%S" + ("%06u" % ms)) + elif dialect.datetimeformat == 'iso': + ms = getattr(value, 'microsecond', 0) + return value.strftime("%Y-%m-%d %H:%M:%S." + ("%06u" % ms)) + else: + raise exceptions.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + elif dialect.datetimeformat == 'internal': + return datetime.datetime( + *[int(v) + for v in (value[0:4], value[4:6], value[6:8], + value[8:10], value[10:12], value[12:14], + value[14:])]) + elif dialect.datetimeformat == 'iso': + return datetime.datetime( + *[int(v) + for v in (value[0:4], value[5:7], value[8:10], + value[11:13], value[14:16], value[17:19], + value[20:])]) + else: + raise exceptions.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + +class MaxDate(sqltypes.Date): + def get_col_spec(self): + return 'DATE' + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + elif isinstance(value, basestring): + return value + elif dialect.datetimeformat == 'internal': + return value.strftime("%Y%m%d") + elif dialect.datetimeformat == 'iso': + return value.strftime("%Y-%m-%d") + else: + raise exceptions.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + elif dialect.datetimeformat == 'internal': + return datetime.date( + *[int(v) for v in (value[0:4], value[4:6], value[6:8])]) + elif dialect.datetimeformat == 'iso': + return datetime.date( + *[int(v) for v in (value[0:4], value[5:7], value[8:10])]) + else: + raise exceptions.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + +class MaxTime(sqltypes.Time): + def get_col_spec(self): + return 'TIME' + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + elif isinstance(value, basestring): + return value + elif dialect.datetimeformat == 'internal': + return value.strftime("%H%M%S") + elif dialect.datetimeformat == 'iso': + return value.strftime("%H-%M-%S") + else: + raise exceptions.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + elif dialect.datetimeformat == 'internal': + t = datetime.time( + *[int(v) for v in (value[0:4], value[4:6], value[6:8])]) + return t + elif dialect.datetimeformat == 'iso': + return datetime.time( + *[int(v) for v in (value[0:4], value[5:7], value[8:10])]) + else: + raise exceptions.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + +class MaxBoolean(sqltypes.Boolean): + def get_col_spec(self): + return 'BOOLEAN' + + +class MaxBlob(sqltypes.Binary): + def get_col_spec(self): + return 'LONG BYTE' + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return str(value) + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return value.read(value.remainingLength()) + return process + + +colspecs = { + sqltypes.Integer: MaxInteger, + sqltypes.Smallinteger: MaxSmallInteger, + sqltypes.Numeric: MaxNumeric, + sqltypes.Float: MaxFloat, + sqltypes.DateTime: MaxTimestamp, + sqltypes.Date: MaxDate, + sqltypes.Time: MaxTime, + sqltypes.String: MaxString, + sqltypes.Binary: MaxBlob, + sqltypes.Boolean: MaxBoolean, + sqltypes.TEXT: MaxText, + sqltypes.CHAR: MaxChar, + sqltypes.TIMESTAMP: MaxTimestamp, + sqltypes.BLOB: MaxBlob, + sqltypes.Unicode: MaxUnicode, + } + +ischema_names = { + 'boolean': MaxBoolean, + 'int': MaxInteger, + 'integer': MaxInteger, + 'varchar': MaxString, + 'char': MaxChar, + 'character': MaxChar, + 'fixed': MaxNumeric, + 'float': MaxFloat, + 'long': MaxText, + 'long binary': MaxBlob, + 'long unicode': MaxText, + 'long': MaxText, + 'timestamp': MaxTimestamp, + 'date': MaxDate, + 'time': MaxTime + } + + +class MaxDBExecutionContext(default.DefaultExecutionContext): + def post_exec(self): + # DB-API bug: if there were any functions as values, + # then do another select and pull CURRVAL from the + # autoincrement column's implicit sequence... ugh + if self.compiled.isinsert and not self.executemany: + table = self.compiled.statement.table + index, serial_col = _autoserial_column(table) + + if serial_col and (not self.compiled._safeserial or + not(self._last_inserted_ids) or + self._last_inserted_ids[index] in (None, 0)): + if table.schema: + sql = "SELECT %s.CURRVAL FROM DUAL" % ( + self.compiled.preparer.format_table(table)) + else: + sql = "SELECT CURRENT_SCHEMA.%s.CURRVAL FROM DUAL" % ( + self.compiled.preparer.format_table(table)) + + if self.connection.engine._should_log_info: + self.connection.engine.logger.info(sql) + rs = self.cursor.execute(sql) + id = rs.fetchone()[0] + + if self.connection.engine._should_log_debug: + self.connection.engine.logger.debug([id]) + if not self._last_inserted_ids: + # This shouldn't ever be > 1? Right? + self._last_inserted_ids = \ + [None] * len(table.primary_key.columns) + self._last_inserted_ids[index] = id + + super(MaxDBExecutionContext, self).post_exec() + + def get_result_proxy(self): + if self.cursor.description is not None: + for column in self.cursor.description: + if column[1] in ('Long Binary', 'Long', 'Long Unicode'): + return MaxDBResultProxy(self) + return engine_base.ResultProxy(self) + + +class MaxDBCachedColumnRow(engine_base.RowProxy): + """A RowProxy that only runs result_processors once per column.""" + + def __init__(self, parent, row): + super(MaxDBCachedColumnRow, self).__init__(parent, row) + self.columns = {} + self._row = row + self._parent = parent + + def _get_col(self, key): + if key not in self.columns: + self.columns[key] = self._parent._get_col(self._row, key) + return self.columns[key] + + def __iter__(self): + for i in xrange(len(self._row)): + yield self._get_col(i) + + def __repr__(self): + return repr(list(self)) + + def __eq__(self, other): + return ((other is self) or + (other == tuple([self._get_col(key) + for key in xrange(len(self._row))]))) + def __getitem__(self, key): + if isinstance(key, slice): + indices = key.indices(len(self._row)) + return tuple([self._get_col(i) for i in xrange(*indices)]) + else: + return self._get_col(key) + + def __getattr__(self, name): + try: + return self._get_col(name) + except KeyError: + raise AttributeError(name) + + +class MaxDBResultProxy(engine_base.ResultProxy): + _process_row = MaxDBCachedColumnRow + + +class MaxDBDialect(default.DefaultDialect): + supports_alter = True + supports_unicode_statements = True + max_identifier_length = 32 + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + preexecute_sequences = True + + # MaxDB-specific + datetimeformat = 'internal' + + def __init__(self, _raise_known_sql_errors=False, **kw): + super(MaxDBDialect, self).__init__(**kw) + self._raise_known = _raise_known_sql_errors + + def dbapi(cls): + from sapdb import dbapi as _dbapi + return _dbapi + dbapi = classmethod(dbapi) + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + opts.update(url.query) + return [], opts + + def type_descriptor(self, typeobj): + if isinstance(typeobj, type): + typeobj = typeobj() + if isinstance(typeobj, sqltypes.Unicode): + return typeobj.adapt(MaxUnicode) + else: + return sqltypes.adapt_type(typeobj, colspecs) + + def dbapi_type_map(self): + if self.dbapi is None: + return {} + else: + return { + 'Long Binary': MaxBlob(), + 'Long byte_t': MaxBlob(), + 'Long Unicode': MaxText(), + 'Timestamp': MaxTimestamp(), + 'Date': MaxDate(), + 'Time': MaxTime(), + datetime.datetime: MaxTimestamp(), + datetime.date: MaxDate(), + datetime.time: MaxTime(), + } + + def create_execution_context(self, connection, **kw): + return MaxDBExecutionContext(self, connection, **kw) + + def do_execute(self, cursor, statement, parameters, context=None): + res = cursor.execute(statement, parameters) + if isinstance(res, int) and context is not None: + context._rowcount = res + + def do_release_savepoint(self, connection, name): + # Does MaxDB truly support RELEASE SAVEPOINT <id>? All my attempts + # produce "SUBTRANS COMMIT/ROLLBACK not allowed without SUBTRANS + # BEGIN SQLSTATE: I7065" + # Note that ROLLBACK TO works fine. In theory, a RELEASE should + # just free up some transactional resources early, before the overall + # COMMIT/ROLLBACK so omitting it should be relatively ok. + pass + + def get_default_schema_name(self, connection): + try: + return self._default_schema_name + except AttributeError: + name = self.identifier_preparer._normalize_name( + connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar()) + self._default_schema_name = name + return name + + def has_table(self, connection, table_name, schema=None): + denormalize = self.identifier_preparer._denormalize_name + bind = [denormalize(table_name)] + if schema is None: + sql = ("SELECT tablename FROM TABLES " + "WHERE TABLES.TABLENAME=? AND" + " TABLES.SCHEMANAME=CURRENT_SCHEMA ") + else: + sql = ("SELECT tablename FROM TABLES " + "WHERE TABLES.TABLENAME = ? AND" + " TABLES.SCHEMANAME=? ") + bind.append(denormalize(schema)) + + rp = connection.execute(sql, bind) + found = bool(rp.fetchone()) + rp.close() + return found + + def table_names(self, connection, schema): + if schema is None: + sql = (" SELECT TABLENAME FROM TABLES WHERE " + " SCHEMANAME=CURRENT_SCHEMA ") + rs = connection.execute(sql) + else: + sql = (" SELECT TABLENAME FROM TABLES WHERE " + " SCHEMANAME=? ") + matchname = self.identifier_preparer._denormalize_name(schema) + rs = connection.execute(sql, matchname) + normalize = self.identifier_preparer._normalize_name + return [normalize(row[0]) for row in rs] + + def reflecttable(self, connection, table, include_columns): + denormalize = self.identifier_preparer._denormalize_name + normalize = self.identifier_preparer._normalize_name + + st = ('SELECT COLUMNNAME, MODE, DATATYPE, CODETYPE, LEN, DEC, ' + ' NULLABLE, "DEFAULT", DEFAULTFUNCTION ' + 'FROM COLUMNS ' + 'WHERE TABLENAME=? AND SCHEMANAME=%s ' + 'ORDER BY POS') + + fk = ('SELECT COLUMNNAME, FKEYNAME, ' + ' REFSCHEMANAME, REFTABLENAME, REFCOLUMNNAME, RULE, ' + ' (CASE WHEN REFSCHEMANAME = CURRENT_SCHEMA ' + ' THEN 1 ELSE 0 END) AS in_schema ' + 'FROM FOREIGNKEYCOLUMNS ' + 'WHERE TABLENAME=? AND SCHEMANAME=%s ' + 'ORDER BY FKEYNAME ') + + params = [denormalize(table.name)] + if not table.schema: + st = st % 'CURRENT_SCHEMA' + fk = fk % 'CURRENT_SCHEMA' + else: + st = st % '?' + fk = fk % '?' + params.append(denormalize(table.schema)) + + rows = connection.execute(st, params).fetchall() + if not rows: + raise exceptions.NoSuchTableError(table.fullname) + + include_columns = util.Set(include_columns or []) + + for row in rows: + (name, mode, col_type, encoding, length, precision, + nullable, constant_def, func_def) = row + + name = normalize(name) + + if include_columns and name not in include_columns: + continue + + type_args, type_kw = [], {} + if col_type == 'FIXED': + type_args = length, precision + elif col_type in 'FLOAT': + type_args = length, + elif col_type in ('CHAR', 'VARCHAR'): + type_args = length, + type_kw['encoding'] = encoding + elif col_type == 'LONG': + type_kw['encoding'] = encoding + + try: + type_cls = ischema_names[col_type.lower()] + type_instance = type_cls(*type_args, **type_kw) + except KeyError: + warnings.warn(RuntimeWarning( + "Did not recognize type '%s' of column '%s'" % + (col_type, name))) + type_instance = sqltypes.NullType + + col_kw = {'autoincrement': False} + col_kw['nullable'] = (nullable == 'YES') + col_kw['primary_key'] = (mode == 'KEY') + + if func_def is not None: + if func_def.startswith('SERIAL'): + # strip current numbering + col_kw['default'] = schema.PassiveDefault( + sql.text('SERIAL')) + col_kw['autoincrement'] = True + else: + col_kw['default'] = schema.PassiveDefault( + sql.text(func_def)) + elif constant_def is not None: + col_kw['default'] = schema.PassiveDefault(sql.text( + "'%s'" % constant_def.replace("'", "''"))) + + table.append_column(schema.Column(name, type_instance, **col_kw)) + + fk_sets = itertools.groupby(connection.execute(fk, params), + lambda row: row.FKEYNAME) + for fkeyname, fkey in fk_sets: + fkey = list(fkey) + if include_columns: + key_cols = util.Set([r.COLUMNNAME for r in fkey]) + if key_cols != include_columns: + continue + + columns, referants = [], [] + quote = self.identifier_preparer._maybe_quote_identifier + + for row in fkey: + columns.append(normalize(row.COLUMNNAME)) + if table.schema or not row.in_schema: + referants.append('.'.join( + [quote(normalize(row[c])) + for c in ('REFSCHEMANAME', 'REFTABLENAME', + 'REFCOLUMNNAME')])) + else: + referants.append('.'.join( + [quote(normalize(row[c])) + for c in ('REFTABLENAME', 'REFCOLUMNNAME')])) + + constraint_kw = {'name': fkeyname.lower()} + if fkey[0].RULE is not None: + rule = fkey[0].RULE + if rule.startswith('DELETE '): + rule = rule[7:] + constraint_kw['ondelete'] = rule + + table_kw = {} + if table.schema or not row.in_schema: + table_kw['schema'] = normalize(fkey[0].REFSCHEMANAME) + + ref_key = schema._get_table_key(normalize(fkey[0].REFTABLENAME), + table_kw.get('schema')) + if ref_key not in table.metadata.tables: + schema.Table(normalize(fkey[0].REFTABLENAME), + table.metadata, + autoload=True, autoload_with=connection, + **table_kw) + + constraint = schema.ForeignKeyConstraint(columns, referants, + **constraint_kw) + table.append_constraint(constraint) + + def has_sequence(self, connection, name): + # [ticket:726] makes this schema-aware. + denormalize = self.identifier_preparer._denormalize_name + sql = ("SELECT sequence_name FROM SEQUENCES " + "WHERE SEQUENCE_NAME=? ") + + rp = connection.execute(sql, denormalize(name)) + found = bool(rp.fetchone()) + rp.close() + return found + + +class MaxDBCompiler(compiler.DefaultCompiler): + operators = compiler.DefaultCompiler.operators.copy() + operators[sql_operators.mod] = lambda x, y: 'mod(%s, %s)' % (x, y) + + function_conversion = { + 'CURRENT_DATE': 'DATE', + 'CURRENT_TIME': 'TIME', + 'CURRENT_TIMESTAMP': 'TIMESTAMP', + } + + # These functions must be written without parens when called with no + # parameters. e.g. 'SELECT DATE FROM DUAL' not 'SELECT DATE() FROM DUAL' + bare_functions = util.Set([ + 'CURRENT_SCHEMA', 'DATE', 'TIME', 'TIMESTAMP', 'TIMEZONE', + 'TRANSACTION', 'USER', 'UID', 'USERGROUP', 'UTCDATE']) + + def default_from(self): + return ' FROM DUAL' + + def for_update_clause(self, select): + clause = select.for_update + if clause is True: + return " WITH LOCK EXCLUSIVE" + elif clause is None: + return "" + elif clause == "read": + return " WITH LOCK" + elif clause == "ignore": + return " WITH LOCK (IGNORE) EXCLUSIVE" + elif clause == "nowait": + return " WITH LOCK (NOWAIT) EXCLUSIVE" + elif isinstance(clause, basestring): + return " WITH LOCK %s" % clause.upper() + elif not clause: + return "" + else: + return " WITH LOCK EXCLUSIVE" + + def apply_function_parens(self, func): + if func.name.upper() in self.bare_functions: + return len(func.clauses) > 0 + else: + return True + + def visit_function(self, fn, **kw): + transform = self.function_conversion.get(fn.name.upper(), None) + if transform: + fn = fn._clone() + fn.name = transform + return super(MaxDBCompiler, self).visit_function(fn, **kw) + + def visit_cast(self, cast, **kwargs): + # MaxDB only supports casts * to NUMERIC, * to VARCHAR or + # date/time to VARCHAR. Casts of LONGs will fail. + if isinstance(cast.type, (sqltypes.Integer, sqltypes.Numeric)): + return "NUM(%s)" % self.process(cast.clause) + elif isinstance(cast.type, sqltypes.String): + return "CHR(%s)" % self.process(cast.clause) + else: + return self.process(cast.clause) + + def visit_sequence(self, sequence): + if sequence.optional: + return None + else: + return (self.dialect.identifier_preparer.format_sequence(sequence) + + ".NEXTVAL") + + class ColumnSnagger(visitors.ClauseVisitor): + def __init__(self): + self.count = 0 + self.column = None + def visit_column(self, column): + self.column = column + self.count += 1 + + def _find_labeled_columns(self, columns, use_labels=False): + labels = {} + for column in columns: + if isinstance(column, basestring): + continue + snagger = self.ColumnSnagger() + snagger.traverse(column) + if snagger.count == 1: + if isinstance(column, sql_expr._Label): + labels[unicode(snagger.column)] = column.name + elif use_labels: + labels[unicode(snagger.column)] = column._label + + return labels + + def order_by_clause(self, select): + order_by = self.process(select._order_by_clause) + + # ORDER BY clauses in DISTINCT queries must reference aliased + # inner columns by alias name, not true column name. + if order_by and getattr(select, '_distinct', False): + labels = self._find_labeled_columns(select.inner_columns, + select.use_labels) + if labels: + for needs_alias in labels.keys(): + r = re.compile(r'(^| )(%s)(,| |$)' % + re.escape(needs_alias)) + order_by = r.sub((r'\1%s\3' % labels[needs_alias]), + order_by) + + # No ORDER BY in subqueries. + if order_by: + if self.is_subquery(select): + # It's safe to simply drop the ORDER BY if there is no + # LIMIT. Right? Other dialects seem to get away with + # dropping order. + if select._limit: + raise exceptions.InvalidRequestError( + "MaxDB does not support ORDER BY in subqueries") + else: + return "" + return " ORDER BY " + order_by + else: + return "" + + def get_select_precolumns(self, select): + # Convert a subquery's LIMIT to TOP + sql = select._distinct and 'DISTINCT ' or '' + if self.is_subquery(select) and select._limit: + if select._offset: + raise exceptions.InvalidRequestError( + 'MaxDB does not support LIMIT with an offset.') + sql += 'TOP %s ' % select._limit + return sql + + def limit_clause(self, select): + # The docs say offsets are supported with LIMIT. But they're not. + # TODO: maybe emulate by adding a ROWNO/ROWNUM predicate? + if self.is_subquery(select): + # sub queries need TOP + return '' + elif select._offset: + raise exceptions.InvalidRequestError( + 'MaxDB does not support LIMIT with an offset.') + else: + return ' \n LIMIT %s' % (select._limit,) + + def visit_insert(self, insert): + self.isinsert = True + self._safeserial = True + + colparams = self._get_colparams(insert) + for value in (insert.parameters or {}).itervalues(): + if isinstance(value, sql_expr._Function): + self._safeserial = False + break + + return ''.join(('INSERT INTO ', + self.preparer.format_table(insert.table), + ' (', + ', '.join([self.preparer.format_column(c[0]) + for c in colparams]), + ') VALUES (', + ', '.join([c[1] for c in colparams]), + ')')) + + +class MaxDBDefaultRunner(engine_base.DefaultRunner): + def visit_sequence(self, seq): + if seq.optional: + return None + return self.execute_string("SELECT %s.NEXTVAL FROM DUAL" % ( + self.dialect.identifier_preparer.format_sequence(seq))) + + +class MaxDBIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = util.Set([ + 'abs', 'absolute', 'acos', 'adddate', 'addtime', 'all', 'alpha', + 'alter', 'any', 'ascii', 'asin', 'atan', 'atan2', 'avg', 'binary', + 'bit', 'boolean', 'byte', 'case', 'ceil', 'ceiling', 'char', + 'character', 'check', 'chr', 'column', 'concat', 'constraint', 'cos', + 'cosh', 'cot', 'count', 'cross', 'curdate', 'current', 'curtime', + 'database', 'date', 'datediff', 'day', 'dayname', 'dayofmonth', + 'dayofweek', 'dayofyear', 'dec', 'decimal', 'decode', 'default', + 'degrees', 'delete', 'digits', 'distinct', 'double', 'except', + 'exists', 'exp', 'expand', 'first', 'fixed', 'float', 'floor', 'for', + 'from', 'full', 'get_objectname', 'get_schema', 'graphic', 'greatest', + 'group', 'having', 'hex', 'hextoraw', 'hour', 'ifnull', 'ignore', + 'index', 'initcap', 'inner', 'insert', 'int', 'integer', 'internal', + 'intersect', 'into', 'join', 'key', 'last', 'lcase', 'least', 'left', + 'length', 'lfill', 'list', 'ln', 'locate', 'log', 'log10', 'long', + 'longfile', 'lower', 'lpad', 'ltrim', 'makedate', 'maketime', + 'mapchar', 'max', 'mbcs', 'microsecond', 'min', 'minute', 'mod', + 'month', 'monthname', 'natural', 'nchar', 'next', 'no', 'noround', + 'not', 'now', 'null', 'num', 'numeric', 'object', 'of', 'on', + 'order', 'packed', 'pi', 'power', 'prev', 'primary', 'radians', + 'real', 'reject', 'relative', 'replace', 'rfill', 'right', 'round', + 'rowid', 'rowno', 'rpad', 'rtrim', 'second', 'select', 'selupd', + 'serial', 'set', 'show', 'sign', 'sin', 'sinh', 'smallint', 'some', + 'soundex', 'space', 'sqrt', 'stamp', 'statistics', 'stddev', + 'subdate', 'substr', 'substring', 'subtime', 'sum', 'sysdba', + 'table', 'tan', 'tanh', 'time', 'timediff', 'timestamp', 'timezone', + 'to', 'toidentifier', 'transaction', 'translate', 'trim', 'trunc', + 'truncate', 'ucase', 'uid', 'unicode', 'union', 'update', 'upper', + 'user', 'usergroup', 'using', 'utcdate', 'utcdiff', 'value', 'values', + 'varchar', 'vargraphic', 'variance', 'week', 'weekofyear', 'when', + 'where', 'with', 'year', 'zoned' ]) + + def _normalize_name(self, name): + if name is None: + return None + if name.isupper(): + lc_name = name.lower() + if not self._requires_quotes(lc_name): + return lc_name + return name + + def _denormalize_name(self, name): + if name is None: + return None + elif (name.islower() and + not self._requires_quotes(name)): + return name.upper() + else: + return name + + def _maybe_quote_identifier(self, name): + if self._requires_quotes(name): + return self.quote_identifier(name) + else: + return name + + +class MaxDBSchemaGenerator(compiler.SchemaGenerator): + def get_column_specification(self, column, **kw): + colspec = [self.preparer.format_column(column), + column.type.dialect_impl(self.dialect).get_col_spec()] + + if not column.nullable: + colspec.append('NOT NULL') + + default = column.default + default_str = self.get_column_default_string(column) + + # No DDL default for columns specified with non-optional sequence- + # this defaulting behavior is entirely client-side. (And as a + # consequence, non-reflectable.) + if (default and isinstance(default, schema.Sequence) and + not default.optional): + pass + # Regular default + elif default_str is not None: + colspec.append('DEFAULT %s' % default_str) + # Assign DEFAULT SERIAL heuristically + elif column.primary_key and column.autoincrement: + # For SERIAL on a non-primary key member, use + # PassiveDefault(text('SERIAL')) + try: + first = [c for c in column.table.primary_key.columns + if (c.autoincrement and + (isinstance(c.type, sqltypes.Integer) or + (isinstance(c.type, MaxNumeric) and + c.type.precision)) and + not c.foreign_keys)].pop(0) + if column is first: + colspec.append('DEFAULT SERIAL') + except IndexError: + pass + + return ' '.join(colspec) + + def get_column_default_string(self, column): + if isinstance(column.default, schema.PassiveDefault): + if isinstance(column.default.arg, basestring): + if isinstance(column.type, sqltypes.Integer): + return str(column.default.arg) + else: + return "'%s'" % column.default.arg + else: + return unicode(self._compile(column.default.arg, None)) + else: + return None + + def visit_sequence(self, sequence): + """Creates a SEQUENCE. + + TODO: move to module doc? + + start + With an integer value, set the START WITH option. + + increment + An integer value to increment by. Default is the database default. + + maxdb_minvalue + maxdb_maxvalue + With an integer value, sets the corresponding sequence option. + + maxdb_no_minvalue + maxdb_no_maxvalue + Defaults to False. If true, sets the corresponding sequence option. + + maxdb_cycle + Defaults to False. If true, sets the CYCLE option. + + maxdb_cache + With an integer value, sets the CACHE option. + + maxdb_no_cache + Defaults to False. If true, sets NOCACHE. + """ + + if (not sequence.optional and + (not self.checkfirst or + not self.dialect.has_sequence(self.connection, sequence.name))): + + ddl = ['CREATE SEQUENCE', + self.preparer.format_sequence(sequence)] + + sequence.increment = 1 + + if sequence.increment is not None: + ddl.extend(('INCREMENT BY', str(sequence.increment))) + + if sequence.start is not None: + ddl.extend(('START WITH', str(sequence.start))) + + opts = dict([(pair[0][6:].lower(), pair[1]) + for pair in sequence.kwargs.items() + if pair[0].startswith('maxdb_')]) + + if 'maxvalue' in opts: + ddl.extend(('MAXVALUE', str(opts['maxvalue']))) + elif opts.get('no_maxvalue', False): + ddl.append('NOMAXVALUE') + if 'minvalue' in opts: + ddl.extend(('MINVALUE', str(opts['minvalue']))) + elif opts.get('no_minvalue', False): + ddl.append('NOMINVALUE') + + if opts.get('cycle', False): + ddl.append('CYCLE') + + if 'cache' in opts: + ddl.extend(('CACHE', str(opts['cache']))) + elif opts.get('no_cache', False): + ddl.append('NOCACHE') + + self.append(' '.join(ddl)) + self.execute() + + +class MaxDBSchemaDropper(compiler.SchemaDropper): + def visit_sequence(self, sequence): + if (not sequence.optional and + (not self.checkfirst or + self.dialect.has_sequence(self.connection, sequence.name))): + self.append("DROP SEQUENCE %s" % + self.preparer.format_sequence(sequence)) + self.execute() + + +def _autoserial_column(table): + """Finds the effective DEFAULT SERIAL column of a Table, if any.""" + + for index, col in enumerate(table.primary_key.columns): + if (isinstance(col.type, (sqltypes.Integer, sqltypes.Numeric)) and + col.autoincrement): + if isinstance(col.default, schema.Sequence): + if col.default.optional: + return index, col + elif (col.default is None or + (not isinstance(col.default, schema.PassiveDefault))): + return index, col + + return None, None + +def descriptor(): + return {'name': 'maxdb', + 'description': 'MaxDB', + 'arguments': [ + ('user', "Database Username", None), + ('password', "Database Password", None), + ('database', "Database Name", None), + ('host', "Hostname", None)]} + +dialect = MaxDBDialect +dialect.preparer = MaxDBIdentifierPreparer +dialect.statement_compiler = MaxDBCompiler +dialect.schemagenerator = MaxDBSchemaGenerator +dialect.schemadropper = MaxDBSchemaDropper +dialect.defaultrunner = MaxDBDefaultRunner + |