diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-07-27 04:08:53 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-07-27 04:08:53 +0000 |
commit | ed4fc64bb0ac61c27bc4af32962fb129e74a36bf (patch) | |
tree | c1cf2fb7b1cafced82a8898e23d2a0bf5ced8526 /lib/sqlalchemy/databases | |
parent | 3a8e235af64e36b3b711df1f069d32359fe6c967 (diff) | |
download | sqlalchemy-ed4fc64bb0ac61c27bc4af32962fb129e74a36bf.tar.gz |
merging 0.4 branch to trunk. see CHANGES for details. 0.3 moves to maintenance branch in branches/rel_0_3.
Diffstat (limited to 'lib/sqlalchemy/databases')
-rw-r--r-- | lib/sqlalchemy/databases/firebird.py | 57 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/information_schema.py | 13 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/informix.py | 63 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 113 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 107 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 255 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 293 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/sqlite.py | 41 |
8 files changed, 482 insertions, 460 deletions
diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index a02781c84..07f07644f 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -5,15 +5,11 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php -import sys, StringIO, string, types +import warnings -from sqlalchemy import util +from sqlalchemy import util, sql, schema, ansisql, exceptions import sqlalchemy.engine.default as default -import sqlalchemy.sql as sql -import sqlalchemy.schema as schema -import sqlalchemy.ansisql as ansisql import sqlalchemy.types as sqltypes -import sqlalchemy.exceptions as exceptions _initialized_kb = False @@ -176,7 +172,7 @@ class FBDialect(ansisql.ANSIDialect): else: return False - def reflecttable(self, connection, table): + def reflecttable(self, connection, table, include_columns): #TODO: map these better column_func = { 14 : lambda r: sqltypes.String(r['FLEN']), # TEXT @@ -254,11 +250,20 @@ class FBDialect(ansisql.ANSIDialect): while row: name = row['FNAME'] - args = [lower_if_possible(name)] + python_name = lower_if_possible(name) + if include_columns and python_name not in include_columns: + continue + args = [python_name] kw = {} # get the data types and lengths - args.append(column_func[row['FTYPE']](row)) + coltype = column_func.get(row['FTYPE'], None) + if coltype is None: + warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (str(row['FTYPE']), name))) + coltype = sqltypes.NULLTYPE + else: + coltype = coltype(row) + args.append(coltype) # is it a primary key? kw['primary_key'] = name in pkfields @@ -301,39 +306,39 @@ class FBDialect(ansisql.ANSIDialect): class FBCompiler(ansisql.ANSICompiler): """Firebird specific idiosincrasies""" - def visit_alias(self, alias): + def visit_alias(self, alias, asfrom=False, **kwargs): # Override to not use the AS keyword which FB 1.5 does not like - self.froms[alias] = self.get_from_text(alias.original) + " " + self.preparer.format_alias(alias) - self.strings[alias] = self.get_str(alias.original) + if asfrom: + return self.process(alias.original, asfrom=True) + " " + self.preparer.format_alias(alias) + else: + return self.process(alias.original, asfrom=True) def visit_function(self, func): if len(func.clauses): - super(FBCompiler, self).visit_function(func) + return super(FBCompiler, self).visit_function(func) else: - self.strings[func] = func.name + return func.name - def visit_insert_column(self, column, parameters): - # all column primary key inserts must be explicitly present - if column.primary_key: - parameters[column.key] = None + def uses_sequences_for_inserts(self): + return True - def visit_select_precolumns(self, select): + def get_select_precolumns(self, select): """Called when building a ``SELECT`` statement, position is just before column list Firebird puts the limit and offset right after the ``SELECT``... """ result = "" - if select.limit: - result += " FIRST %d " % select.limit - if select.offset: - result +=" SKIP %d " % select.offset - if select.distinct: + if select._limit: + result += " FIRST %d " % select._limit + if select._offset: + result +=" SKIP %d " % select._offset + if select._distinct: result += " DISTINCT " return result def limit_clause(self, select): - """Already taken care of in the `visit_select_precolumns` method.""" + """Already taken care of in the `get_select_precolumns` method.""" return "" @@ -364,7 +369,7 @@ class FBSchemaDropper(ansisql.ANSISchemaDropper): class FBDefaultRunner(ansisql.ANSIDefaultRunner): def exec_default_sql(self, default): - c = sql.select([default.arg], from_obj=["rdb$database"]).compile(engine=self.connection) + c = sql.select([default.arg], from_obj=["rdb$database"]).compile(bind=self.connection) return self.connection.execute_compiled(c).scalar() def visit_sequence(self, seq): diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py index 81c44dcaa..93f47de15 100644 --- a/lib/sqlalchemy/databases/information_schema.py +++ b/lib/sqlalchemy/databases/information_schema.py @@ -1,4 +1,6 @@ -from sqlalchemy import sql, schema, exceptions, select, MetaData, Table, Column, String, Integer +import sqlalchemy.sql as sql +import sqlalchemy.exceptions as exceptions +from sqlalchemy import select, MetaData, Table, Column, String, Integer from sqlalchemy.schema import PassiveDefault, ForeignKeyConstraint ischema = MetaData() @@ -96,8 +98,7 @@ class ISchema(object): return self.cache[name] -def reflecttable(connection, table, ischema_names): - +def reflecttable(connection, table, include_columns, ischema_names): key_constraints = pg_key_constraints if table.schema is not None: @@ -128,7 +129,9 @@ def reflecttable(connection, table, ischema_names): row[columns.c.numeric_scale], row[columns.c.column_default] ) - + if include_columns and name not in include_columns: + continue + args = [] for a in (charlen, numericprec, numericscale): if a is not None: @@ -139,7 +142,7 @@ def reflecttable(connection, table, ischema_names): colargs= [] if default is not None: colargs.append(PassiveDefault(sql.text(default))) - table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) + table.append_column(Column(name, coltype, nullable=nullable, *colargs)) if not found_table: raise exceptions.NoSuchTableError(table.name) diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py index 2fb508280..f3a6cf60e 100644 --- a/lib/sqlalchemy/databases/informix.py +++ b/lib/sqlalchemy/databases/informix.py @@ -5,20 +5,11 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +import datetime, warnings -import sys, StringIO, string , random -import datetime -from decimal import Decimal - -import sqlalchemy.util as util -import sqlalchemy.sql as sql -import sqlalchemy.engine as engine +from sqlalchemy import sql, schema, ansisql, exceptions, pool import sqlalchemy.engine.default as default -import sqlalchemy.schema as schema -import sqlalchemy.ansisql as ansisql import sqlalchemy.types as sqltypes -import sqlalchemy.exceptions as exceptions -import sqlalchemy.pool as pool # for offset @@ -128,7 +119,7 @@ class InfoBoolean(sqltypes.Boolean): elif value is None: return None else: - return value and True or False + return value and True or False colspecs = { @@ -262,7 +253,7 @@ class InfoDialect(ansisql.ANSIDialect): cursor = connection.execute("""select tabname from systables where tabname=?""", table_name.lower() ) return bool( cursor.fetchone() is not None ) - def reflecttable(self, connection, table): + def reflecttable(self, connection, table, include_columns): c = connection.execute ("select distinct OWNER from systables where tabname=?", table.name.lower() ) rows = c.fetchall() if not rows : @@ -289,6 +280,10 @@ class InfoDialect(ansisql.ANSIDialect): raise exceptions.NoSuchTableError(table.name) for name , colattr , collength , default , colno in rows: + name = name.lower() + if include_columns and name not in include_columns: + continue + # in 7.31, coltype = 0x000 # ^^-- column type # ^-- 1 not null , 0 null @@ -306,14 +301,16 @@ class InfoDialect(ansisql.ANSIDialect): scale = 0 coltype = InfoNumeric(precision, scale) else: - coltype = ischema_names.get(coltype) + try: + coltype = ischema_names[coltype] + except KeyError: + warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, name))) + coltype = sqltypes.NULLTYPE colargs = [] if default is not None: colargs.append(schema.PassiveDefault(sql.text(default))) - name = name.lower() - table.append_column(schema.Column(name, coltype, nullable = (nullable == 0), *colargs)) # FK @@ -372,20 +369,20 @@ class InfoCompiler(ansisql.ANSICompiler): def default_from(self): return " from systables where tabname = 'systables' " - def visit_select_precolumns( self , select ): - s = select.distinct and "DISTINCT " or "" + def get_select_precolumns( self , select ): + s = select._distinct and "DISTINCT " or "" # only has limit - if select.limit: - off = select.offset or 0 - s += " FIRST %s " % ( select.limit + off ) + if select._limit: + off = select._offset or 0 + s += " FIRST %s " % ( select._limit + off ) else: s += "" return s def visit_select(self, select): - if select.offset: - self.offset = select.offset - self.limit = select.limit or 0 + if select._offset: + self.offset = select._offset + self.limit = select._limit or 0 # the column in order by clause must in select too def __label( c ): @@ -393,13 +390,14 @@ class InfoCompiler(ansisql.ANSICompiler): return c._label.lower() except: return '' - + + # TODO: dont modify the original select, generate a new one a = [ __label(c) for c in select._raw_columns ] for c in select.order_by_clause.clauses: if ( __label(c) not in a ) and getattr( c , 'name' , '' ) != 'oid': select.append_column( c ) - ansisql.ANSICompiler.visit_select(self, select) + return ansisql.ANSICompiler.visit_select(self, select) def limit_clause(self, select): return "" @@ -414,23 +412,20 @@ class InfoCompiler(ansisql.ANSICompiler): def visit_function( self , func ): if func.name.lower() == 'current_date': - self.strings[func] = "today" + return "today" elif func.name.lower() == 'current_time': - self.strings[func] = "CURRENT HOUR TO SECOND" + return "CURRENT HOUR TO SECOND" elif func.name.lower() in ( 'current_timestamp' , 'now' ): - self.strings[func] = "CURRENT YEAR TO SECOND" + return "CURRENT YEAR TO SECOND" else: - ansisql.ANSICompiler.visit_function( self , func ) + return ansisql.ANSICompiler.visit_function( self , func ) def visit_clauselist(self, list): try: li = [ c for c in list.clauses if c.name != 'oid' ] except: li = [ c for c in list.clauses ] - if list.parens: - self.strings[list] = "(" + string.join([s for s in [self.get_str(c) for c in li] if s is not None ], ', ') + ")" - else: - self.strings[list] = string.join([s for s in [self.get_str(c) for c in li] if s is not None], ', ') + return ', '.join([s for s in [self.process(c) for c in li] if s is not None]) class InfoSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, first_pk=False): diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index ba1c0fd9d..206291404 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -25,7 +25,7 @@ * Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on ``INSERT`` -* ``select.limit`` implemented as ``SELECT TOP n`` +* ``select._limit`` implemented as ``SELECT TOP n`` Known issues / TODO: @@ -39,16 +39,11 @@ Known issues / TODO: """ -import sys, StringIO, string, types, re, datetime, random +import datetime, random, warnings -import sqlalchemy.sql as sql -import sqlalchemy.engine as engine -import sqlalchemy.engine.default as default -import sqlalchemy.schema as schema -import sqlalchemy.ansisql as ansisql +from sqlalchemy import sql, schema, ansisql, exceptions import sqlalchemy.types as sqltypes -import sqlalchemy.exceptions as exceptions - +from sqlalchemy.engine import default class MSNumeric(sqltypes.Numeric): def convert_result_value(self, value, dialect): @@ -500,7 +495,7 @@ class MSSQLDialect(ansisql.ANSIDialect): row = c.fetchone() return row is not None - def reflecttable(self, connection, table): + def reflecttable(self, connection, table, include_columns): import sqlalchemy.databases.information_schema as ischema # Get base columns @@ -532,16 +527,22 @@ class MSSQLDialect(ansisql.ANSIDialect): row[columns.c.numeric_scale], row[columns.c.column_default] ) + if include_columns and name not in include_columns: + continue args = [] for a in (charlen, numericprec, numericscale): if a is not None: args.append(a) - coltype = self.ischema_names[type] + coltype = self.ischema_names.get(type, None) if coltype == MSString and charlen == -1: coltype = MSText() else: - if coltype == MSNVarchar and charlen == -1: + if coltype is None: + warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (type, name))) + coltype = sqltypes.NULLTYPE + + elif coltype == MSNVarchar and charlen == -1: charlen = None coltype = coltype(*args) colargs= [] @@ -812,12 +813,12 @@ class MSSQLCompiler(ansisql.ANSICompiler): super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs) self.tablealiases = {} - def visit_select_precolumns(self, select): + def get_select_precolumns(self, select): """ MS-SQL puts TOP, it's version of LIMIT here """ - s = select.distinct and "DISTINCT " or "" - if select.limit: - s += "TOP %s " % (select.limit,) - if select.offset: + s = select._distinct and "DISTINCT " or "" + if select._limit: + s += "TOP %s " % (select._limit,) + if select._offset: raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset') return s @@ -825,49 +826,50 @@ class MSSQLCompiler(ansisql.ANSICompiler): # Limit in mssql is after the select keyword return "" - def visit_table(self, table): + def _schema_aliased_table(self, table): + if getattr(table, 'schema', None) is not None: + if not self.tablealiases.has_key(table): + self.tablealiases[table] = table.alias() + return self.tablealiases[table] + else: + return None + + def visit_table(self, table, mssql_aliased=False, **kwargs): + if mssql_aliased: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + # alias schema-qualified tables - if getattr(table, 'schema', None) is not None and not self.tablealiases.has_key(table): - alias = table.alias() - self.tablealiases[table] = alias - self.traverse(alias) - self.froms[('alias', table)] = self.froms[table] - for c in alias.c: - self.traverse(c) - self.traverse(alias.oid_column) - self.tablealiases[alias] = self.froms[table] - self.froms[table] = self.froms[alias] + alias = self._schema_aliased_table(table) + if alias is not None: + return self.process(alias, mssql_aliased=True, **kwargs) else: - super(MSSQLCompiler, self).visit_table(table) + return super(MSSQLCompiler, self).visit_table(table, **kwargs) - def visit_alias(self, alias): + def visit_alias(self, alias, **kwargs): # translate for schema-qualified table aliases - if self.froms.has_key(('alias', alias.original)): - self.froms[alias] = self.froms[('alias', alias.original)] + " AS " + alias.name - self.strings[alias] = "" - else: - super(MSSQLCompiler, self).visit_alias(alias) + self.tablealiases[alias.original] = alias + return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) def visit_column(self, column): - # translate for schema-qualified table aliases - super(MSSQLCompiler, self).visit_column(column) - if column.table is not None and self.tablealiases.has_key(column.table): - self.strings[column] = \ - self.strings[self.tablealiases[column.table].corresponding_column(column)] + if column.table is not None: + # translate for schema-qualified table aliases + t = self._schema_aliased_table(column.table) + if t is not None: + return self.process(t.corresponding_column(column)) + return super(MSSQLCompiler, self).visit_column(column) def visit_binary(self, binary): """Move bind parameters to the right-hand side of an operator, where possible.""" - if isinstance(binary.left, sql._BindParamClause) and binary.operator == '=': - binary.left, binary.right = binary.right, binary.left - super(MSSQLCompiler, self).visit_binary(binary) - - def visit_select(self, select): - # label function calls, so they return a name in cursor.description - for i,c in enumerate(select._raw_columns): - if isinstance(c, sql._Function): - select._raw_columns[i] = c.label(c.name + "_" + hex(random.randint(0, 65535))[2:]) + if isinstance(binary.left, sql._BindParamClause) and binary.operator == operator.eq: + return self.process(sql._BinaryExpression(binary.right, binary.left, binary.operator)) + else: + return super(MSSQLCompiler, self).visit_binary(binary) - super(MSSQLCompiler, self).visit_select(select) + def label_select_column(self, select, column): + if isinstance(column, sql._Function): + return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:]) + else: + return super(MSSQLCompiler, self).label_select_column(select, column) function_rewrites = {'current_date': 'getdate', 'length': 'len', @@ -881,10 +883,10 @@ class MSSQLCompiler(ansisql.ANSICompiler): return '' def order_by_clause(self, select): - order_by = self.get_str(select.order_by_clause) + order_by = self.process(select._order_by_clause) # MSSQL only allows ORDER BY in subqueries if there is a LIMIT - if order_by and (not select.is_subquery or select.limit): + if order_by and (not self.is_subquery(select) or select._limit): return " ORDER BY " + order_by else: return "" @@ -916,10 +918,12 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): class MSSQLSchemaDropper(ansisql.ANSISchemaDropper): def visit_index(self, index): self.append("\nDROP INDEX %s.%s" % ( - self.preparer.quote_identifier(index.table.name), - self.preparer.quote_identifier(index.name))) + self.preparer.quote_identifier(index.table.name), + self.preparer.quote_identifier(index.name) + )) self.execute() + class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner): # TODO: does ms-sql have standalone sequences ? pass @@ -940,4 +944,3 @@ dialect = MSSQLDialect - diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index bac0e5e12..26800e32b 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -4,7 +4,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import re, datetime, inspect, warnings, weakref +import re, datetime, inspect, warnings, weakref, operator from sqlalchemy import sql, schema, ansisql from sqlalchemy.engine import default @@ -12,13 +12,13 @@ import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions import sqlalchemy.util as util from array import array as _array +from decimal import Decimal try: from threading import Lock except ImportError: from dummy_threading import Lock - RESERVED_WORDS = util.Set( ['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc', 'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both', @@ -60,7 +60,6 @@ RESERVED_WORDS = util.Set( 'accessible', 'linear', 'master_ssl_verify_server_cert', 'range', 'read_only', 'read_write', # 5.1 ]) - _per_connection_mutex = Lock() class _NumericType(object): @@ -137,7 +136,7 @@ class _StringType(object): class MSNumeric(sqltypes.Numeric, _NumericType): """MySQL NUMERIC type""" - def __init__(self, precision = 10, length = 2, **kw): + def __init__(self, precision = 10, length = 2, asdecimal=True, **kw): """Construct a NUMERIC. precision @@ -157,18 +156,27 @@ class MSNumeric(sqltypes.Numeric, _NumericType): """ _NumericType.__init__(self, **kw) - sqltypes.Numeric.__init__(self, precision, length) - + sqltypes.Numeric.__init__(self, precision, length, asdecimal=asdecimal) + def get_col_spec(self): if self.precision is None: return self._extend("NUMERIC") else: return self._extend("NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}) + def convert_bind_param(self, value, dialect): + return value + + def convert_result_value(self, value, dialect): + if not self.asdecimal and isinstance(value, Decimal): + return float(value) + else: + return value + class MSDecimal(MSNumeric): """MySQL DECIMAL type""" - def __init__(self, precision=10, length=2, **kw): + def __init__(self, precision=10, length=2, asdecimal=True, **kw): """Construct a DECIMAL. precision @@ -187,7 +195,7 @@ class MSDecimal(MSNumeric): underlying database API, which continue to be numeric. """ - super(MSDecimal, self).__init__(precision, length, **kw) + super(MSDecimal, self).__init__(precision, length, asdecimal=asdecimal, **kw) def get_col_spec(self): if self.precision is None: @@ -200,7 +208,7 @@ class MSDecimal(MSNumeric): class MSDouble(MSNumeric): """MySQL DOUBLE type""" - def __init__(self, precision=10, length=2, **kw): + def __init__(self, precision=10, length=2, asdecimal=True, **kw): """Construct a DOUBLE. precision @@ -222,7 +230,7 @@ class MSDouble(MSNumeric): if ((precision is None and length is not None) or (precision is not None and length is None)): raise exceptions.ArgumentError("You must specify both precision and length or omit both altogether.") - super(MSDouble, self).__init__(precision, length, **kw) + super(MSDouble, self).__init__(precision, length, asdecimal=asdecimal, **kw) def get_col_spec(self): if self.precision is not None and self.length is not None: @@ -235,7 +243,7 @@ class MSDouble(MSNumeric): class MSFloat(sqltypes.Float, _NumericType): """MySQL FLOAT type""" - def __init__(self, precision=10, length=None, **kw): + def __init__(self, precision=10, length=None, asdecimal=False, **kw): """Construct a FLOAT. precision @@ -257,7 +265,7 @@ class MSFloat(sqltypes.Float, _NumericType): if length is not None: self.length=length _NumericType.__init__(self, **kw) - sqltypes.Float.__init__(self, precision) + sqltypes.Float.__init__(self, precision, asdecimal=asdecimal) def get_col_spec(self): if hasattr(self, 'length') and self.length is not None: @@ -267,6 +275,10 @@ class MSFloat(sqltypes.Float, _NumericType): else: return self._extend("FLOAT") + def convert_bind_param(self, value, dialect): + return value + + class MSInteger(sqltypes.Integer, _NumericType): """MySQL INTEGER type""" @@ -955,7 +967,10 @@ class MySQLExecutionContext(default.DefaultExecutionContext): if self.compiled.isinsert: if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] - + + def is_select(self): + return re.match(r'SELECT|SHOW|DESCRIBE|XA RECOVER', self.statement.lstrip(), re.I) is not None + class MySQLDialect(ansisql.ANSIDialect): def __init__(self, **kwargs): ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs) @@ -1044,6 +1059,27 @@ class MySQLDialect(ansisql.ANSIDialect): except: pass + def do_begin_twophase(self, connection, xid): + connection.execute(sql.text("XA BEGIN :xid", bindparams=[sql.bindparam('xid',xid)])) + + def do_prepare_twophase(self, connection, xid): + connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)])) + connection.execute(sql.text("XA PREPARE :xid", bindparams=[sql.bindparam('xid',xid)])) + + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + if not is_prepared: + connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)])) + connection.execute(sql.text("XA ROLLBACK :xid", bindparams=[sql.bindparam('xid',xid)])) + + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + if not is_prepared: + self.do_prepare_twophase(connection, xid) + connection.execute(sql.text("XA COMMIT :xid", bindparams=[sql.bindparam('xid',xid)])) + + def do_recover_twophase(self, connection): + resultset = connection.execute(sql.text("XA RECOVER")) + return [row['data'][0:row['gtrid_length']] for row in resultset] + def is_disconnect(self, e): return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2013, 2014, 2045, 2055) @@ -1088,7 +1124,7 @@ class MySQLDialect(ansisql.ANSIDialect): version.append(n) return tuple(version) - def reflecttable(self, connection, table): + def reflecttable(self, connection, table, include_columns): """Load column definitions from the server.""" decode_from = self._detect_charset(connection) @@ -1111,6 +1147,9 @@ class MySQLDialect(ansisql.ANSIDialect): # leave column names as unicode name = name.decode(decode_from) + + if include_columns and name not in include_columns: + continue match = re.match(r'(\w+)(\(.*?\))?\s*(\w+)?\s*(\w+)?', type) col_type = match.group(1) @@ -1118,7 +1157,11 @@ class MySQLDialect(ansisql.ANSIDialect): extra_1 = match.group(3) extra_2 = match.group(4) - coltype = ischema_names.get(col_type, MSString) + try: + coltype = ischema_names[col_type] + except KeyError: + warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (col_type, name))) + coltype = sqltypes.NULLTYPE kw = {} if extra_1 is not None: @@ -1156,7 +1199,6 @@ class MySQLDialect(ansisql.ANSIDialect): if not row: raise exceptions.NoSuchTableError(table.fullname) desc = row[1].strip() - row.close() tabletype = '' lastparen = re.search(r'\)[^\)]*\Z', desc) @@ -1223,7 +1265,6 @@ class MySQLDialect(ansisql.ANSIDialect): cs = True else: cs = row[1] in ('0', 'OFF' 'off') - row.close() cache['lower_case_table_names'] = cs self.per_connection[raw_connection] = cache return cache.get('lower_case_table_names') @@ -1266,14 +1307,21 @@ class _MySQLPythonRowProxy(object): class MySQLCompiler(ansisql.ANSICompiler): - def visit_cast(self, cast): - + operators = ansisql.ANSICompiler.operators.copy() + operators.update( + { + sql.ColumnOperators.concat_op : lambda x, y:"concat(%s, %s)" % (x, y), + operator.mod : '%%' + } + ) + + def visit_cast(self, cast, **kwargs): if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)): - return super(MySQLCompiler, self).visit_cast(cast) + return super(MySQLCompiler, self).visit_cast(cast, **kwargs) else: # so just skip the CAST altogether for now. # TODO: put whatever MySQL does for CAST here. - self.strings[cast] = self.strings[cast.clause] + return self.process(cast.clause) def for_update_clause(self, select): if select.for_update == 'read': @@ -1283,20 +1331,15 @@ class MySQLCompiler(ansisql.ANSICompiler): def limit_clause(self, select): text = "" - if select.limit is not None: - text += " \n LIMIT " + str(select.limit) - if select.offset is not None: - if select.limit is None: - # striaght from the MySQL docs, I kid you not + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: + # straight from the MySQL docs, I kid you not text += " \n LIMIT 18446744073709551615" - text += " OFFSET " + str(select.offset) + text += " OFFSET " + str(select._offset) return text - def binary_operator_string(self, binary): - if binary.operator == '%': - return '%%' - else: - return ansisql.ANSICompiler.binary_operator_string(self, binary) class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, first_pk=False): diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 9d7d6a112..d3aa2e268 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -5,9 +5,9 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php -import sys, StringIO, string, re, warnings +import re, warnings, operator -from sqlalchemy import util, sql, engine, schema, ansisql, exceptions, logging +from sqlalchemy import util, sql, schema, ansisql, exceptions, logging from sqlalchemy.engine import default, base import sqlalchemy.types as sqltypes @@ -88,8 +88,11 @@ class OracleText(sqltypes.TEXT): def convert_result_value(self, value, dialect): if value is None: return None - else: + elif hasattr(value, 'read'): + # cx_oracle doesnt seem to be consistent with CLOB returning LOB or str return super(OracleText, self).convert_result_value(value.read(), dialect) + else: + return super(OracleText, self).convert_result_value(value, dialect) class OracleRaw(sqltypes.Binary): @@ -178,25 +181,31 @@ class OracleExecutionContext(default.DefaultExecutionContext): super(OracleExecutionContext, self).pre_exec() if self.dialect.auto_setinputsizes: self.set_input_sizes() + if self.compiled_parameters is not None and not isinstance(self.compiled_parameters, list): + for key in self.compiled_parameters: + (bindparam, name, value) = self.compiled_parameters.get_parameter(key) + if bindparam.isoutparam: + dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) + if not hasattr(self, 'out_parameters'): + self.out_parameters = {} + self.out_parameters[name] = self.cursor.var(dbtype) + self.parameters[name] = self.out_parameters[name] def get_result_proxy(self): + if hasattr(self, 'out_parameters'): + if self.compiled_parameters is not None: + for k in self.out_parameters: + type = self.compiled_parameters.get_type(k) + self.out_parameters[k] = type.dialect_impl(self.dialect).convert_result_value(self.out_parameters[k].getvalue(), self.dialect) + else: + for k in self.out_parameters: + self.out_parameters[k] = self.out_parameters[k].getvalue() + if self.cursor.description is not None: - if self.dialect.auto_convert_lobs and self.typemap is None: - typemap = {} - binary = False - for column in self.cursor.description: - type_code = column[1] - if type_code in self.dialect.ORACLE_BINARY_TYPES: - binary = True - typemap[column[0].lower()] = OracleBinary() - self.typemap = typemap - if binary: + for column in self.cursor.description: + type_code = column[1] + if type_code in self.dialect.ORACLE_BINARY_TYPES: return base.BufferedColumnResultProxy(self) - else: - for column in self.cursor.description: - type_code = column[1] - if type_code in self.dialect.ORACLE_BINARY_TYPES: - return base.BufferedColumnResultProxy(self) return base.ResultProxy(self) @@ -208,11 +217,26 @@ class OracleDialect(ansisql.ANSIDialect): self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' ) self.auto_setinputsizes = auto_setinputsizes self.auto_convert_lobs = auto_convert_lobs + if self.dbapi is not None: self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(self.dbapi, k)] else: self.ORACLE_BINARY_TYPES = [] + def dbapi_type_map(self): + if self.dbapi is None or not self.auto_convert_lobs: + return {} + else: + return { + self.dbapi.NUMBER: OracleInteger(), + self.dbapi.CLOB: OracleText(), + self.dbapi.BLOB: OracleBinary(), + self.dbapi.STRING: OracleString(), + self.dbapi.TIMESTAMP: OracleTimestamp(), + self.dbapi.BINARY: OracleRaw(), + datetime.datetime: OracleDate() + } + def dbapi(cls): import cx_Oracle return cx_Oracle @@ -251,7 +275,7 @@ class OracleDialect(ansisql.ANSIDialect): return 30 def oid_column_name(self, column): - if not isinstance(column.table, sql.TableClause) and not isinstance(column.table, sql.Select): + if not isinstance(column.table, (sql.TableClause, sql.Select)): return None else: return "rowid" @@ -341,7 +365,7 @@ class OracleDialect(ansisql.ANSIDialect): return name, owner, dblink raise - def reflecttable(self, connection, table): + def reflecttable(self, connection, table, include_columns): preparer = self.identifier_preparer if not preparer.should_quote(table): name = table.name.upper() @@ -363,6 +387,13 @@ class OracleDialect(ansisql.ANSIDialect): #print "ROW:" , row (colname, coltype, length, precision, scale, nullable, default) = (row[0], row[1], row[2], row[3], row[4], row[5]=='Y', row[6]) + # if name comes back as all upper, assume its case folded + if (colname.upper() == colname): + colname = colname.lower() + + if include_columns and colname not in include_columns: + continue + # INTEGER if the scale is 0 and precision is null # NUMBER if the scale and precision are both null # NUMBER(9,2) if the precision is 9 and the scale is 2 @@ -382,16 +413,13 @@ class OracleDialect(ansisql.ANSIDialect): try: coltype = ischema_names[coltype] except KeyError: - raise exceptions.AssertionError("Can't get coltype for type '%s' on colname '%s'" % (coltype, colname)) + warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, colname))) + coltype = sqltypes.NULLTYPE colargs = [] if default is not None: colargs.append(schema.PassiveDefault(sql.text(default))) - # if name comes back as all upper, assume its case folded - if (colname.upper() == colname): - colname = colname.lower() - table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs)) if not len(table.columns): @@ -458,16 +486,27 @@ class OracleDialect(ansisql.ANSIDialect): OracleDialect.logger = logging.class_logger(OracleDialect) +class _OuterJoinColumn(sql.ClauseElement): + __visit_name__ = 'outer_join_column' + def __init__(self, column): + self.column = column + class OracleCompiler(ansisql.ANSICompiler): """Oracle compiler modifies the lexical structure of Select statements to work under non-ANSI configured Oracle databases, if the use_ansi flag is False. """ + operators = ansisql.ANSICompiler.operators.copy() + operators.update( + { + operator.mod : lambda x, y:"mod(%s, %s)" % (x, y) + } + ) + def __init__(self, *args, **kwargs): super(OracleCompiler, self).__init__(*args, **kwargs) - # we have to modify SELECT objects a little bit, so store state here - self._select_state = {} + self.__wheres = {} def default_from(self): """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended. @@ -480,49 +519,46 @@ class OracleCompiler(ansisql.ANSICompiler): def apply_function_parens(self, func): return len(func.clauses) > 0 - def visit_join(self, join): + def visit_join(self, join, **kwargs): if self.dialect.use_ansi: - return ansisql.ANSICompiler.visit_join(self, join) - - self.froms[join] = self.get_from_text(join.left) + ", " + self.get_from_text(join.right) - where = self.wheres.get(join.left, None) + return ansisql.ANSICompiler.visit_join(self, join, **kwargs) + + (where, parentjoin) = self.__wheres.get(join, (None, None)) + + class VisitOn(sql.ClauseVisitor): + def visit_binary(s, binary): + if binary.operator == operator.eq: + if binary.left.table is join.right: + binary.left = _OuterJoinColumn(binary.left) + elif binary.right.table is join.right: + binary.right = _OuterJoinColumn(binary.right) + if where is not None: - self.wheres[join] = sql.and_(where, join.onclause) + self.__wheres[join.left] = self.__wheres[parentjoin] = (sql.and_(VisitOn().traverse(join.onclause, clone=True), where), parentjoin) else: - self.wheres[join] = join.onclause -# self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause) - self.strings[join] = self.froms[join] - - if join.isouter: - # if outer join, push on the right side table as the current "outertable" - self._outertable = join.right - - # now re-visit the onclause, which will be used as a where clause - # (the first visit occured via the Join object itself right before it called visit_join()) - self.traverse(join.onclause) - - self._outertable = None - - self.wheres[join].accept_visitor(self) + self.__wheres[join.left] = self.__wheres[join] = (VisitOn().traverse(join.onclause, clone=True), join) - def visit_insert_sequence(self, column, sequence, parameters): - """This is the `sequence` equivalent to ``ANSICompiler``'s - `visit_insert_column_default` which ensures that the column is - present in the generated column list. - """ - - parameters.setdefault(column.key, None) + return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True) + + def get_whereclause(self, f): + if f in self.__wheres: + return self.__wheres[f][0] + else: + return None + + def visit_outer_join_column(self, vc): + return self.process(vc.column) + "(+)" + + def uses_sequences_for_inserts(self): + return True - def visit_alias(self, alias): + def visit_alias(self, alias, asfrom=False, **kwargs): """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??""" - - self.froms[alias] = self.get_from_text(alias.original) + " " + alias.name - self.strings[alias] = self.get_str(alias.original) - - def visit_column(self, column): - ansisql.ANSICompiler.visit_column(self, column) - if not self.dialect.use_ansi and getattr(self, '_outertable', None) is not None and column.table is self._outertable: - self.strings[column] = self.strings[column] + "(+)" + + if asfrom: + return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + alias.name + else: + return self.process(alias.original, **kwargs) def visit_insert(self, insert): """``INSERT`` s are required to have the primary keys be explicitly present. @@ -539,76 +575,35 @@ class OracleCompiler(ansisql.ANSICompiler): def _TODO_visit_compound_select(self, select): """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" + pass - if getattr(select, '_oracle_visit', False): - # cancel out the compiled order_by on the select - if hasattr(select, "order_by_clause"): - self.strings[select.order_by_clause] = "" - ansisql.ANSICompiler.visit_compound_select(self, select) - return - - if select.limit is not None or select.offset is not None: - select._oracle_visit = True - # to use ROW_NUMBER(), an ORDER BY is required. - orderby = self.strings[select.order_by_clause] - if not orderby: - orderby = select.oid_column - self.traverse(orderby) - orderby = self.strings[orderby] - class SelectVisitor(sql.NoColumnVisitor): - def visit_select(self, select): - select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")) - SelectVisitor().traverse(select) - limitselect = sql.select([c for c in select.c if c.key!='ora_rn']) - if select.offset is not None: - limitselect.append_whereclause("ora_rn>%d" % select.offset) - if select.limit is not None: - limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset)) - else: - limitselect.append_whereclause("ora_rn<=%d" % select.limit) - self.traverse(limitselect) - self.strings[select] = self.strings[limitselect] - self.froms[select] = self.froms[limitselect] - else: - ansisql.ANSICompiler.visit_compound_select(self, select) - - def visit_select(self, select): + def visit_select(self, select, **kwargs): """Look for ``LIMIT`` and OFFSET in a select statement, and if so tries to wrap it in a subquery with ``row_number()`` criterion. """ - # TODO: put a real copy-container on Select and copy, or somehow make this - # not modify the Select statement - if self._select_state.get((select, 'visit'), False): - # cancel out the compiled order_by on the select - if hasattr(select, "order_by_clause"): - self.strings[select.order_by_clause] = "" - ansisql.ANSICompiler.visit_select(self, select) - return - - if select.limit is not None or select.offset is not None: - self._select_state[(select, 'visit')] = True + if not getattr(select, '_oracle_visit', None) and (select._limit is not None or select._offset is not None): # to use ROW_NUMBER(), an ORDER BY is required. - orderby = self.strings[select.order_by_clause] + orderby = self.process(select._order_by_clause) if not orderby: orderby = select.oid_column self.traverse(orderby) - orderby = self.strings[orderby] - if not hasattr(select, '_oracle_visit'): - select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")) - select._oracle_visit = True + orderby = self.process(orderby) + + oldselect = select + select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None) + select._oracle_visit = True + limitselect = sql.select([c for c in select.c if c.key!='ora_rn']) - if select.offset is not None: - limitselect.append_whereclause("ora_rn>%d" % select.offset) - if select.limit is not None: - limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset)) + if select._offset is not None: + limitselect.append_whereclause("ora_rn>%d" % select._offset) + if select._limit is not None: + limitselect.append_whereclause("ora_rn<=%d" % (select._limit + select._offset)) else: - limitselect.append_whereclause("ora_rn<=%d" % select.limit) - self.traverse(limitselect) - self.strings[select] = self.strings[limitselect] - self.froms[select] = self.froms[limitselect] + limitselect.append_whereclause("ora_rn<=%d" % select._limit) + return self.process(limitselect) else: - ansisql.ANSICompiler.visit_select(self, select) + return ansisql.ANSICompiler.visit_select(self, select, **kwargs) def limit_clause(self, select): return "" @@ -619,12 +614,6 @@ class OracleCompiler(ansisql.ANSICompiler): else: return super(OracleCompiler, self).for_update_clause(select) - def visit_binary(self, binary): - if binary.operator == '%': - self.strings[binary] = ("MOD(%s,%s)"%(self.get_str(binary.left), self.get_str(binary.right))) - else: - return ansisql.ANSICompiler.visit_binary(self, binary) - class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): @@ -639,22 +628,22 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): return colspec def visit_sequence(self, sequence): - if not self.dialect.has_sequence(self.connection, sequence.name): + if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() class OracleSchemaDropper(ansisql.ANSISchemaDropper): def visit_sequence(self, sequence): - if self.dialect.has_sequence(self.connection, sequence.name): + if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name): self.append("DROP SEQUENCE %s" % sequence.name) self.execute() class OracleDefaultRunner(ansisql.ANSIDefaultRunner): def exec_default_sql(self, default): - c = sql.select([default.arg], from_obj=["DUAL"]).compile(engine=self.connection) - return self.connection.execute_compiled(c).scalar() + c = sql.select([default.arg], from_obj=["DUAL"]).compile(bind=self.connection) + return self.connection.execute(c).scalar() def visit_sequence(self, seq): - return self.connection.execute_text("SELECT " + seq.name + ".nextval FROM DUAL").scalar() + return self.connection.execute("SELECT " + seq.name + ".nextval FROM DUAL").scalar() dialect = OracleDialect diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index d3726fc1f..b192c4778 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -4,12 +4,13 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import datetime, string, types, re, random, warnings +import re, random, warnings, operator -from sqlalchemy import util, sql, schema, ansisql, exceptions +from sqlalchemy import sql, schema, ansisql, exceptions from sqlalchemy.engine import base, default import sqlalchemy.types as sqltypes from sqlalchemy.databases import information_schema as ischema +from decimal import Decimal try: import mx.DateTime.DateTime as mxDateTime @@ -28,6 +29,15 @@ class PGNumeric(sqltypes.Numeric): else: return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} + def convert_bind_param(self, value, dialect): + return value + + def convert_result_value(self, value, dialect): + if not self.asdecimal and isinstance(value, Decimal): + return float(value) + else: + return value + class PGFloat(sqltypes.Float): def get_col_spec(self): if not self.precision: @@ -35,6 +45,7 @@ class PGFloat(sqltypes.Float): else: return "FLOAT(%(precision)s)" % {'precision': self.precision} + class PGInteger(sqltypes.Integer): def get_col_spec(self): return "INTEGER" @@ -47,74 +58,15 @@ class PGBigInteger(PGInteger): def get_col_spec(self): return "BIGINT" -class PG2DateTime(sqltypes.DateTime): +class PGDateTime(sqltypes.DateTime): def get_col_spec(self): return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" -class PG1DateTime(sqltypes.DateTime): - def convert_bind_param(self, value, dialect): - if value is not None: - if isinstance(value, datetime.datetime): - seconds = float(str(value.second) + "." - + str(value.microsecond)) - mx_datetime = mxDateTime(value.year, value.month, value.day, - value.hour, value.minute, - seconds) - return dialect.dbapi.TimestampFromMx(mx_datetime) - return dialect.dbapi.TimestampFromMx(value) - else: - return None - - def convert_result_value(self, value, dialect): - if value is None: - return None - second_parts = str(value.second).split(".") - seconds = int(second_parts[0]) - microseconds = int(float(second_parts[1])) - return datetime.datetime(value.year, value.month, value.day, - value.hour, value.minute, seconds, - microseconds) - - def get_col_spec(self): - return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" - -class PG2Date(sqltypes.Date): - def get_col_spec(self): - return "DATE" - -class PG1Date(sqltypes.Date): - def convert_bind_param(self, value, dialect): - # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime - # this one doesnt seem to work with the "emulation" mode - if value is not None: - return dialect.dbapi.DateFromMx(value) - else: - return None - - def convert_result_value(self, value, dialect): - # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime - return value - +class PGDate(sqltypes.Date): def get_col_spec(self): return "DATE" -class PG2Time(sqltypes.Time): - def get_col_spec(self): - return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" - -class PG1Time(sqltypes.Time): - def convert_bind_param(self, value, dialect): - # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime - # this one doesnt seem to work with the "emulation" mode - if value is not None: - return psycopg.TimeFromMx(value) - else: - return None - - def convert_result_value(self, value, dialect): - # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime - return value - +class PGTime(sqltypes.Time): def get_col_spec(self): return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" @@ -142,28 +94,55 @@ class PGBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOLEAN" -pg2_colspecs = { +class PGArray(sqltypes.TypeEngine, sqltypes.Concatenable): + def __init__(self, item_type): + if isinstance(item_type, type): + item_type = item_type() + self.item_type = item_type + + def dialect_impl(self, dialect): + impl = self.__class__.__new__(self.__class__) + impl.__dict__.update(self.__dict__) + impl.item_type = self.item_type.dialect_impl(dialect) + return impl + def convert_bind_param(self, value, dialect): + if value is None: + return value + def convert_item(item): + if isinstance(item, (list,tuple)): + return [convert_item(child) for child in item] + else: + return self.item_type.convert_bind_param(item, dialect) + return [convert_item(item) for item in value] + def convert_result_value(self, value, dialect): + if value is None: + return value + def convert_item(item): + if isinstance(item, list): + return [convert_item(child) for child in item] + else: + return self.item_type.convert_result_value(item, dialect) + # Could specialcase when item_type.convert_result_value is the default identity func + return [convert_item(item) for item in value] + def get_col_spec(self): + return self.item_type.get_col_spec() + '[]' + +colspecs = { sqltypes.Integer : PGInteger, sqltypes.Smallinteger : PGSmallInteger, sqltypes.Numeric : PGNumeric, sqltypes.Float : PGFloat, - sqltypes.DateTime : PG2DateTime, - sqltypes.Date : PG2Date, - sqltypes.Time : PG2Time, + sqltypes.DateTime : PGDateTime, + sqltypes.Date : PGDate, + sqltypes.Time : PGTime, sqltypes.String : PGString, sqltypes.Binary : PGBinary, sqltypes.Boolean : PGBoolean, sqltypes.TEXT : PGText, sqltypes.CHAR: PGChar, } -pg1_colspecs = pg2_colspecs.copy() -pg1_colspecs.update({ - sqltypes.DateTime : PG1DateTime, - sqltypes.Date : PG1Date, - sqltypes.Time : PG1Time - }) - -pg2_ischema_names = { + +ischema_names = { 'integer' : PGInteger, 'bigint' : PGBigInteger, 'smallint' : PGSmallInteger, @@ -175,24 +154,17 @@ pg2_ischema_names = { 'real' : PGFloat, 'inet': PGInet, 'double precision' : PGFloat, - 'timestamp' : PG2DateTime, - 'timestamp with time zone' : PG2DateTime, - 'timestamp without time zone' : PG2DateTime, - 'time with time zone' : PG2Time, - 'time without time zone' : PG2Time, - 'date' : PG2Date, - 'time': PG2Time, + 'timestamp' : PGDateTime, + 'timestamp with time zone' : PGDateTime, + 'timestamp without time zone' : PGDateTime, + 'time with time zone' : PGTime, + 'time without time zone' : PGTime, + 'date' : PGDate, + 'time': PGTime, 'bytea' : PGBinary, 'boolean' : PGBoolean, 'interval':PGInterval, } -pg1_ischema_names = pg2_ischema_names.copy() -pg1_ischema_names.update({ - 'timestamp with time zone' : PG1DateTime, - 'timestamp without time zone' : PG1DateTime, - 'date' : PG1Date, - 'time' : PG1Time - }) def descriptor(): return {'name':'postgres', @@ -206,11 +178,11 @@ def descriptor(): class PGExecutionContext(default.DefaultExecutionContext): - def is_select(self): - return re.match(r'SELECT', self.statement.lstrip(), re.I) and not re.search(r'FOR UPDATE\s*$', self.statement, re.I) - + def _is_server_side(self): + return self.dialect.server_side_cursors and self.is_select() and not re.search(r'FOR UPDATE(?: NOWAIT)?\s*$', self.statement, re.I) + def create_cursor(self): - if self.dialect.server_side_cursors and self.is_select(): + if self._is_server_side(): # use server-side cursors: # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html ident = "c" + hex(random.randint(0, 65535))[2:] @@ -219,7 +191,7 @@ class PGExecutionContext(default.DefaultExecutionContext): return self.connection.connection.cursor() def get_result_proxy(self): - if self.dialect.server_side_cursors and self.is_select(): + if self._is_server_side(): return base.BufferedRowResultProxy(self) else: return base.ResultProxy(self) @@ -242,31 +214,18 @@ class PGDialect(ansisql.ANSIDialect): ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs) self.use_oids = use_oids self.server_side_cursors = server_side_cursors - if self.dbapi is None or not hasattr(self.dbapi, '__version__') or self.dbapi.__version__.startswith('2'): - self.version = 2 - else: - self.version = 1 self.use_information_schema = use_information_schema self.paramstyle = 'pyformat' def dbapi(cls): - try: - import psycopg2 as psycopg - except ImportError, e: - try: - import psycopg - except ImportError, e2: - raise e + import psycopg2 as psycopg return psycopg dbapi = classmethod(dbapi) def create_connect_args(self, url): opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) if opts.has_key('port'): - if self.version == 2: - opts['port'] = int(opts['port']) - else: - opts['port'] = str(opts['port']) + opts['port'] = int(opts['port']) opts.update(url.query) return ([], opts) @@ -278,10 +237,7 @@ class PGDialect(ansisql.ANSIDialect): return 63 def type_descriptor(self, typeobj): - if self.version == 2: - return sqltypes.adapt_type(typeobj, pg2_colspecs) - else: - return sqltypes.adapt_type(typeobj, pg1_colspecs) + return sqltypes.adapt_type(typeobj, colspecs) def compiler(self, statement, bindparams, **kwargs): return PGCompiler(self, statement, bindparams, **kwargs) @@ -292,8 +248,36 @@ class PGDialect(ansisql.ANSIDialect): def schemadropper(self, *args, **kwargs): return PGSchemaDropper(self, *args, **kwargs) - def defaultrunner(self, connection, **kwargs): - return PGDefaultRunner(connection, **kwargs) + def do_begin_twophase(self, connection, xid): + self.do_begin(connection.connection) + + def do_prepare_twophase(self, connection, xid): + connection.execute(sql.text("PREPARE TRANSACTION %(tid)s", bindparams=[sql.bindparam('tid', xid)])) + + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + if is_prepared: + if recover: + #FIXME: ugly hack to get out of transaction context when commiting recoverable transactions + # Must find out a way how to make the dbapi not open a transaction. + connection.execute(sql.text("ROLLBACK")) + connection.execute(sql.text("ROLLBACK PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)])) + else: + self.do_rollback(connection.connection) + + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + if is_prepared: + if recover: + connection.execute(sql.text("ROLLBACK")) + connection.execute(sql.text("COMMIT PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)])) + else: + self.do_commit(connection.connection) + + def do_recover_twophase(self, connection): + resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts")) + return [row[0] for row in resultset] + + def defaultrunner(self, context, **kwargs): + return PGDefaultRunner(context, **kwargs) def preparer(self): return PGIdentifierPreparer(self) @@ -351,14 +335,9 @@ class PGDialect(ansisql.ANSIDialect): else: return False - def reflecttable(self, connection, table): - if self.version == 2: - ischema_names = pg2_ischema_names - else: - ischema_names = pg1_ischema_names - + def reflecttable(self, connection, table, include_columns): if self.use_information_schema: - ischema.reflecttable(connection, table, ischema_names) + ischema.reflecttable(connection, table, include_columns, ischema_names) else: preparer = self.identifier_preparer if table.schema is not None: @@ -387,7 +366,7 @@ class PGDialect(ansisql.ANSIDialect): ORDER BY a.attnum """ % schema_where_clause - s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type=sqltypes.Unicode), sql.bindparam('schema', type=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode}) + s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode}) c = connection.execute(s, table_name=table.name, schema=table.schema) rows = c.fetchall() @@ -398,9 +377,13 @@ class PGDialect(ansisql.ANSIDialect): domains = self._load_domains(connection) for name, format_type, default, notnull, attnum, table_oid in rows: + if include_columns and name not in include_columns: + continue + ## strip (30) from character varying(30) - attype = re.search('([^\(]+)', format_type).group(1) + attype = re.search('([^\([]+)', format_type).group(1) nullable = not notnull + is_array = format_type.endswith('[]') try: charlen = re.search('\(([\d,]+)\)', format_type).group(1) @@ -453,6 +436,8 @@ class PGDialect(ansisql.ANSIDialect): if coltype: coltype = coltype(*args, **kwargs) + if is_array: + coltype = PGArray(coltype) else: warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (attype, name))) coltype = sqltypes.NULLTYPE @@ -517,7 +502,6 @@ class PGDialect(ansisql.ANSIDialect): table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname)) def _load_domains(self, connection): - ## Load data types for domains: SQL_DOMAINS = """ SELECT t.typname as "name", @@ -554,49 +538,46 @@ class PGDialect(ansisql.ANSIDialect): - class PGCompiler(ansisql.ANSICompiler): - def visit_insert_column(self, column, parameters): - # all column primary key inserts must be explicitly present - if column.primary_key: - parameters[column.key] = None + operators = ansisql.ANSICompiler.operators.copy() + operators.update( + { + operator.mod : '%%' + } + ) - def visit_insert_sequence(self, column, sequence, parameters): - """this is the 'sequence' equivalent to ANSICompiler's 'visit_insert_column_default' which ensures - that the column is present in the generated column list""" - parameters.setdefault(column.key, None) + def uses_sequences_for_inserts(self): + return True def limit_clause(self, select): text = "" - if select.limit is not None: - text += " \n LIMIT " + str(select.limit) - if select.offset is not None: - if select.limit is None: + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: text += " \n LIMIT ALL" - text += " OFFSET " + str(select.offset) + text += " OFFSET " + str(select._offset) return text - def visit_select_precolumns(self, select): - if select.distinct: - if type(select.distinct) == bool: + def get_select_precolumns(self, select): + if select._distinct: + if type(select._distinct) == bool: return "DISTINCT " - if type(select.distinct) == list: + if type(select._distinct) == list: dist_set = "DISTINCT ON (" - for col in select.distinct: + for col in select._distinct: dist_set += self.strings[col] + ", " dist_set = dist_set[:-2] + ") " return dist_set - return "DISTINCT ON (" + str(select.distinct) + ") " + return "DISTINCT ON (" + str(select._distinct) + ") " else: return "" - def binary_operator_string(self, binary): - if isinstance(binary.type, sqltypes.String) and binary.operator == '+': - return '||' - elif binary.operator == '%': - return '%%' + def for_update_clause(self, select): + if select.for_update == 'nowait': + return " FOR UPDATE NOWAIT" else: - return ansisql.ANSICompiler.binary_operator_string(self, binary) + return super(PGCompiler, self).for_update_clause(select) class PGSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): @@ -617,13 +598,13 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): return colspec def visit_sequence(self, sequence): - if not sequence.optional and (not self.dialect.has_sequence(self.connection, sequence.name)): + if not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() class PGSchemaDropper(ansisql.ANSISchemaDropper): def visit_sequence(self, sequence): - if not sequence.optional and (self.dialect.has_sequence(self.connection, sequence.name)): + if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)): self.append("DROP SEQUENCE %s" % sequence.name) self.execute() @@ -632,7 +613,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): if column.primary_key: # passive defaults on primary keys have to be overridden if isinstance(column.default, schema.PassiveDefault): - return self.connection.execute_text("select %s" % column.default.arg).scalar() + return self.connection.execute("select %s" % column.default.arg).scalar() elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): sch = column.table.schema # TODO: this has to build into the Sequence object so we can get the quoting @@ -641,7 +622,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) else: exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) - return self.connection.execute_text(exc).scalar() + return self.connection.execute(exc).scalar() return super(ansisql.ANSIDefaultRunner, self).get_column_default(column) diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 816b1b76a..725ea23e2 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -5,9 +5,9 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php -import sys, StringIO, string, types, re +import re -from sqlalchemy import sql, engine, schema, ansisql, exceptions, pool, PassiveDefault +from sqlalchemy import schema, ansisql, exceptions, pool, PassiveDefault import sqlalchemy.engine.default as default import sqlalchemy.types as sqltypes import datetime,time, warnings @@ -126,6 +126,7 @@ colspecs = { pragma_names = { 'INTEGER' : SLInteger, + 'INT' : SLInteger, 'SMALLINT' : SLSmallInteger, 'VARCHAR' : SLString, 'CHAR' : SLChar, @@ -150,8 +151,9 @@ class SQLiteExecutionContext(default.DefaultExecutionContext): if self.compiled.isinsert: if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] - - super(SQLiteExecutionContext, self).post_exec() + + def is_select(self): + return re.match(r'SELECT|PRAGMA', self.statement.lstrip(), re.I) is not None class SQLiteDialect(ansisql.ANSIDialect): @@ -233,7 +235,7 @@ class SQLiteDialect(ansisql.ANSIDialect): return (row is not None) - def reflecttable(self, connection, table): + def reflecttable(self, connection, table, include_columns): c = connection.execute("PRAGMA table_info(%s)" % self.preparer().format_table(table), {}) found_table = False while True: @@ -244,6 +246,8 @@ class SQLiteDialect(ansisql.ANSIDialect): found_table = True (name, type, nullable, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4] is not None, row[5]) name = re.sub(r'^\"|\"$', '', name) + if include_columns and name not in include_columns: + continue match = re.match(r'(\w+)(\(.*?\))?', type) if match: coltype = match.group(1) @@ -253,7 +257,12 @@ class SQLiteDialect(ansisql.ANSIDialect): args = '' #print "coltype: " + repr(coltype) + " args: " + repr(args) - coltype = pragma_names.get(coltype, SLString) + try: + coltype = pragma_names[coltype] + except KeyError: + warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, name))) + coltype = sqltypes.NULLTYPE + if args is not None: args = re.findall(r'(\d+)', args) #print "args! " +repr(args) @@ -318,21 +327,21 @@ class SQLiteDialect(ansisql.ANSIDialect): class SQLiteCompiler(ansisql.ANSICompiler): def visit_cast(self, cast): if self.dialect.supports_cast: - super(SQLiteCompiler, self).visit_cast(cast) + return super(SQLiteCompiler, self).visit_cast(cast) else: if len(self.select_stack): # not sure if we want to set the typemap here... self.typemap.setdefault("CAST", cast.type) - self.strings[cast] = self.strings[cast.clause] + return self.process(cast.clause) def limit_clause(self, select): text = "" - if select.limit is not None: - text += " \n LIMIT " + str(select.limit) - if select.offset is not None: - if select.limit is None: + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: text += " \n LIMIT -1" - text += " OFFSET " + str(select.offset) + text += " OFFSET " + str(select._offset) else: text += " OFFSET 0" return text @@ -341,12 +350,6 @@ class SQLiteCompiler(ansisql.ANSICompiler): # sqlite has no "FOR UPDATE" AFAICT return '' - def binary_operator_string(self, binary): - if isinstance(binary.type, sqltypes.String) and binary.operator == '+': - return '||' - else: - return ansisql.ANSICompiler.binary_operator_string(self, binary) - class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): |