summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/databases
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-07-27 04:08:53 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-07-27 04:08:53 +0000
commited4fc64bb0ac61c27bc4af32962fb129e74a36bf (patch)
treec1cf2fb7b1cafced82a8898e23d2a0bf5ced8526 /lib/sqlalchemy/databases
parent3a8e235af64e36b3b711df1f069d32359fe6c967 (diff)
downloadsqlalchemy-ed4fc64bb0ac61c27bc4af32962fb129e74a36bf.tar.gz
merging 0.4 branch to trunk. see CHANGES for details. 0.3 moves to maintenance branch in branches/rel_0_3.
Diffstat (limited to 'lib/sqlalchemy/databases')
-rw-r--r--lib/sqlalchemy/databases/firebird.py57
-rw-r--r--lib/sqlalchemy/databases/information_schema.py13
-rw-r--r--lib/sqlalchemy/databases/informix.py63
-rw-r--r--lib/sqlalchemy/databases/mssql.py113
-rw-r--r--lib/sqlalchemy/databases/mysql.py107
-rw-r--r--lib/sqlalchemy/databases/oracle.py255
-rw-r--r--lib/sqlalchemy/databases/postgres.py293
-rw-r--r--lib/sqlalchemy/databases/sqlite.py41
8 files changed, 482 insertions, 460 deletions
diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py
index a02781c84..07f07644f 100644
--- a/lib/sqlalchemy/databases/firebird.py
+++ b/lib/sqlalchemy/databases/firebird.py
@@ -5,15 +5,11 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import sys, StringIO, string, types
+import warnings
-from sqlalchemy import util
+from sqlalchemy import util, sql, schema, ansisql, exceptions
import sqlalchemy.engine.default as default
-import sqlalchemy.sql as sql
-import sqlalchemy.schema as schema
-import sqlalchemy.ansisql as ansisql
import sqlalchemy.types as sqltypes
-import sqlalchemy.exceptions as exceptions
_initialized_kb = False
@@ -176,7 +172,7 @@ class FBDialect(ansisql.ANSIDialect):
else:
return False
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
#TODO: map these better
column_func = {
14 : lambda r: sqltypes.String(r['FLEN']), # TEXT
@@ -254,11 +250,20 @@ class FBDialect(ansisql.ANSIDialect):
while row:
name = row['FNAME']
- args = [lower_if_possible(name)]
+ python_name = lower_if_possible(name)
+ if include_columns and python_name not in include_columns:
+ continue
+ args = [python_name]
kw = {}
# get the data types and lengths
- args.append(column_func[row['FTYPE']](row))
+ coltype = column_func.get(row['FTYPE'], None)
+ if coltype is None:
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (str(row['FTYPE']), name)))
+ coltype = sqltypes.NULLTYPE
+ else:
+ coltype = coltype(row)
+ args.append(coltype)
# is it a primary key?
kw['primary_key'] = name in pkfields
@@ -301,39 +306,39 @@ class FBDialect(ansisql.ANSIDialect):
class FBCompiler(ansisql.ANSICompiler):
"""Firebird specific idiosincrasies"""
- def visit_alias(self, alias):
+ def visit_alias(self, alias, asfrom=False, **kwargs):
# Override to not use the AS keyword which FB 1.5 does not like
- self.froms[alias] = self.get_from_text(alias.original) + " " + self.preparer.format_alias(alias)
- self.strings[alias] = self.get_str(alias.original)
+ if asfrom:
+ return self.process(alias.original, asfrom=True) + " " + self.preparer.format_alias(alias)
+ else:
+ return self.process(alias.original, asfrom=True)
def visit_function(self, func):
if len(func.clauses):
- super(FBCompiler, self).visit_function(func)
+ return super(FBCompiler, self).visit_function(func)
else:
- self.strings[func] = func.name
+ return func.name
- def visit_insert_column(self, column, parameters):
- # all column primary key inserts must be explicitly present
- if column.primary_key:
- parameters[column.key] = None
+ def uses_sequences_for_inserts(self):
+ return True
- def visit_select_precolumns(self, select):
+ def get_select_precolumns(self, select):
"""Called when building a ``SELECT`` statement, position is just
before column list Firebird puts the limit and offset right
after the ``SELECT``...
"""
result = ""
- if select.limit:
- result += " FIRST %d " % select.limit
- if select.offset:
- result +=" SKIP %d " % select.offset
- if select.distinct:
+ if select._limit:
+ result += " FIRST %d " % select._limit
+ if select._offset:
+ result +=" SKIP %d " % select._offset
+ if select._distinct:
result += " DISTINCT "
return result
def limit_clause(self, select):
- """Already taken care of in the `visit_select_precolumns` method."""
+ """Already taken care of in the `get_select_precolumns` method."""
return ""
@@ -364,7 +369,7 @@ class FBSchemaDropper(ansisql.ANSISchemaDropper):
class FBDefaultRunner(ansisql.ANSIDefaultRunner):
def exec_default_sql(self, default):
- c = sql.select([default.arg], from_obj=["rdb$database"]).compile(engine=self.connection)
+ c = sql.select([default.arg], from_obj=["rdb$database"]).compile(bind=self.connection)
return self.connection.execute_compiled(c).scalar()
def visit_sequence(self, seq):
diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py
index 81c44dcaa..93f47de15 100644
--- a/lib/sqlalchemy/databases/information_schema.py
+++ b/lib/sqlalchemy/databases/information_schema.py
@@ -1,4 +1,6 @@
-from sqlalchemy import sql, schema, exceptions, select, MetaData, Table, Column, String, Integer
+import sqlalchemy.sql as sql
+import sqlalchemy.exceptions as exceptions
+from sqlalchemy import select, MetaData, Table, Column, String, Integer
from sqlalchemy.schema import PassiveDefault, ForeignKeyConstraint
ischema = MetaData()
@@ -96,8 +98,7 @@ class ISchema(object):
return self.cache[name]
-def reflecttable(connection, table, ischema_names):
-
+def reflecttable(connection, table, include_columns, ischema_names):
key_constraints = pg_key_constraints
if table.schema is not None:
@@ -128,7 +129,9 @@ def reflecttable(connection, table, ischema_names):
row[columns.c.numeric_scale],
row[columns.c.column_default]
)
-
+ if include_columns and name not in include_columns:
+ continue
+
args = []
for a in (charlen, numericprec, numericscale):
if a is not None:
@@ -139,7 +142,7 @@ def reflecttable(connection, table, ischema_names):
colargs= []
if default is not None:
colargs.append(PassiveDefault(sql.text(default)))
- table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
+ table.append_column(Column(name, coltype, nullable=nullable, *colargs))
if not found_table:
raise exceptions.NoSuchTableError(table.name)
diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py
index 2fb508280..f3a6cf60e 100644
--- a/lib/sqlalchemy/databases/informix.py
+++ b/lib/sqlalchemy/databases/informix.py
@@ -5,20 +5,11 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+import datetime, warnings
-import sys, StringIO, string , random
-import datetime
-from decimal import Decimal
-
-import sqlalchemy.util as util
-import sqlalchemy.sql as sql
-import sqlalchemy.engine as engine
+from sqlalchemy import sql, schema, ansisql, exceptions, pool
import sqlalchemy.engine.default as default
-import sqlalchemy.schema as schema
-import sqlalchemy.ansisql as ansisql
import sqlalchemy.types as sqltypes
-import sqlalchemy.exceptions as exceptions
-import sqlalchemy.pool as pool
# for offset
@@ -128,7 +119,7 @@ class InfoBoolean(sqltypes.Boolean):
elif value is None:
return None
else:
- return value and True or False
+ return value and True or False
colspecs = {
@@ -262,7 +253,7 @@ class InfoDialect(ansisql.ANSIDialect):
cursor = connection.execute("""select tabname from systables where tabname=?""", table_name.lower() )
return bool( cursor.fetchone() is not None )
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
c = connection.execute ("select distinct OWNER from systables where tabname=?", table.name.lower() )
rows = c.fetchall()
if not rows :
@@ -289,6 +280,10 @@ class InfoDialect(ansisql.ANSIDialect):
raise exceptions.NoSuchTableError(table.name)
for name , colattr , collength , default , colno in rows:
+ name = name.lower()
+ if include_columns and name not in include_columns:
+ continue
+
# in 7.31, coltype = 0x000
# ^^-- column type
# ^-- 1 not null , 0 null
@@ -306,14 +301,16 @@ class InfoDialect(ansisql.ANSIDialect):
scale = 0
coltype = InfoNumeric(precision, scale)
else:
- coltype = ischema_names.get(coltype)
+ try:
+ coltype = ischema_names[coltype]
+ except KeyError:
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, name)))
+ coltype = sqltypes.NULLTYPE
colargs = []
if default is not None:
colargs.append(schema.PassiveDefault(sql.text(default)))
- name = name.lower()
-
table.append_column(schema.Column(name, coltype, nullable = (nullable == 0), *colargs))
# FK
@@ -372,20 +369,20 @@ class InfoCompiler(ansisql.ANSICompiler):
def default_from(self):
return " from systables where tabname = 'systables' "
- def visit_select_precolumns( self , select ):
- s = select.distinct and "DISTINCT " or ""
+ def get_select_precolumns( self , select ):
+ s = select._distinct and "DISTINCT " or ""
# only has limit
- if select.limit:
- off = select.offset or 0
- s += " FIRST %s " % ( select.limit + off )
+ if select._limit:
+ off = select._offset or 0
+ s += " FIRST %s " % ( select._limit + off )
else:
s += ""
return s
def visit_select(self, select):
- if select.offset:
- self.offset = select.offset
- self.limit = select.limit or 0
+ if select._offset:
+ self.offset = select._offset
+ self.limit = select._limit or 0
# the column in order by clause must in select too
def __label( c ):
@@ -393,13 +390,14 @@ class InfoCompiler(ansisql.ANSICompiler):
return c._label.lower()
except:
return ''
-
+
+ # TODO: dont modify the original select, generate a new one
a = [ __label(c) for c in select._raw_columns ]
for c in select.order_by_clause.clauses:
if ( __label(c) not in a ) and getattr( c , 'name' , '' ) != 'oid':
select.append_column( c )
- ansisql.ANSICompiler.visit_select(self, select)
+ return ansisql.ANSICompiler.visit_select(self, select)
def limit_clause(self, select):
return ""
@@ -414,23 +412,20 @@ class InfoCompiler(ansisql.ANSICompiler):
def visit_function( self , func ):
if func.name.lower() == 'current_date':
- self.strings[func] = "today"
+ return "today"
elif func.name.lower() == 'current_time':
- self.strings[func] = "CURRENT HOUR TO SECOND"
+ return "CURRENT HOUR TO SECOND"
elif func.name.lower() in ( 'current_timestamp' , 'now' ):
- self.strings[func] = "CURRENT YEAR TO SECOND"
+ return "CURRENT YEAR TO SECOND"
else:
- ansisql.ANSICompiler.visit_function( self , func )
+ return ansisql.ANSICompiler.visit_function( self , func )
def visit_clauselist(self, list):
try:
li = [ c for c in list.clauses if c.name != 'oid' ]
except:
li = [ c for c in list.clauses ]
- if list.parens:
- self.strings[list] = "(" + string.join([s for s in [self.get_str(c) for c in li] if s is not None ], ', ') + ")"
- else:
- self.strings[list] = string.join([s for s in [self.get_str(c) for c in li] if s is not None], ', ')
+ return ', '.join([s for s in [self.process(c) for c in li] if s is not None])
class InfoSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, first_pk=False):
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
index ba1c0fd9d..206291404 100644
--- a/lib/sqlalchemy/databases/mssql.py
+++ b/lib/sqlalchemy/databases/mssql.py
@@ -25,7 +25,7 @@
* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on ``INSERT``
-* ``select.limit`` implemented as ``SELECT TOP n``
+* ``select._limit`` implemented as ``SELECT TOP n``
Known issues / TODO:
@@ -39,16 +39,11 @@ Known issues / TODO:
"""
-import sys, StringIO, string, types, re, datetime, random
+import datetime, random, warnings
-import sqlalchemy.sql as sql
-import sqlalchemy.engine as engine
-import sqlalchemy.engine.default as default
-import sqlalchemy.schema as schema
-import sqlalchemy.ansisql as ansisql
+from sqlalchemy import sql, schema, ansisql, exceptions
import sqlalchemy.types as sqltypes
-import sqlalchemy.exceptions as exceptions
-
+from sqlalchemy.engine import default
class MSNumeric(sqltypes.Numeric):
def convert_result_value(self, value, dialect):
@@ -500,7 +495,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
row = c.fetchone()
return row is not None
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
import sqlalchemy.databases.information_schema as ischema
# Get base columns
@@ -532,16 +527,22 @@ class MSSQLDialect(ansisql.ANSIDialect):
row[columns.c.numeric_scale],
row[columns.c.column_default]
)
+ if include_columns and name not in include_columns:
+ continue
args = []
for a in (charlen, numericprec, numericscale):
if a is not None:
args.append(a)
- coltype = self.ischema_names[type]
+ coltype = self.ischema_names.get(type, None)
if coltype == MSString and charlen == -1:
coltype = MSText()
else:
- if coltype == MSNVarchar and charlen == -1:
+ if coltype is None:
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (type, name)))
+ coltype = sqltypes.NULLTYPE
+
+ elif coltype == MSNVarchar and charlen == -1:
charlen = None
coltype = coltype(*args)
colargs= []
@@ -812,12 +813,12 @@ class MSSQLCompiler(ansisql.ANSICompiler):
super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs)
self.tablealiases = {}
- def visit_select_precolumns(self, select):
+ def get_select_precolumns(self, select):
""" MS-SQL puts TOP, it's version of LIMIT here """
- s = select.distinct and "DISTINCT " or ""
- if select.limit:
- s += "TOP %s " % (select.limit,)
- if select.offset:
+ s = select._distinct and "DISTINCT " or ""
+ if select._limit:
+ s += "TOP %s " % (select._limit,)
+ if select._offset:
raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset')
return s
@@ -825,49 +826,50 @@ class MSSQLCompiler(ansisql.ANSICompiler):
# Limit in mssql is after the select keyword
return ""
- def visit_table(self, table):
+ def _schema_aliased_table(self, table):
+ if getattr(table, 'schema', None) is not None:
+ if not self.tablealiases.has_key(table):
+ self.tablealiases[table] = table.alias()
+ return self.tablealiases[table]
+ else:
+ return None
+
+ def visit_table(self, table, mssql_aliased=False, **kwargs):
+ if mssql_aliased:
+ return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
# alias schema-qualified tables
- if getattr(table, 'schema', None) is not None and not self.tablealiases.has_key(table):
- alias = table.alias()
- self.tablealiases[table] = alias
- self.traverse(alias)
- self.froms[('alias', table)] = self.froms[table]
- for c in alias.c:
- self.traverse(c)
- self.traverse(alias.oid_column)
- self.tablealiases[alias] = self.froms[table]
- self.froms[table] = self.froms[alias]
+ alias = self._schema_aliased_table(table)
+ if alias is not None:
+ return self.process(alias, mssql_aliased=True, **kwargs)
else:
- super(MSSQLCompiler, self).visit_table(table)
+ return super(MSSQLCompiler, self).visit_table(table, **kwargs)
- def visit_alias(self, alias):
+ def visit_alias(self, alias, **kwargs):
# translate for schema-qualified table aliases
- if self.froms.has_key(('alias', alias.original)):
- self.froms[alias] = self.froms[('alias', alias.original)] + " AS " + alias.name
- self.strings[alias] = ""
- else:
- super(MSSQLCompiler, self).visit_alias(alias)
+ self.tablealiases[alias.original] = alias
+ return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
def visit_column(self, column):
- # translate for schema-qualified table aliases
- super(MSSQLCompiler, self).visit_column(column)
- if column.table is not None and self.tablealiases.has_key(column.table):
- self.strings[column] = \
- self.strings[self.tablealiases[column.table].corresponding_column(column)]
+ if column.table is not None:
+ # translate for schema-qualified table aliases
+ t = self._schema_aliased_table(column.table)
+ if t is not None:
+ return self.process(t.corresponding_column(column))
+ return super(MSSQLCompiler, self).visit_column(column)
def visit_binary(self, binary):
"""Move bind parameters to the right-hand side of an operator, where possible."""
- if isinstance(binary.left, sql._BindParamClause) and binary.operator == '=':
- binary.left, binary.right = binary.right, binary.left
- super(MSSQLCompiler, self).visit_binary(binary)
-
- def visit_select(self, select):
- # label function calls, so they return a name in cursor.description
- for i,c in enumerate(select._raw_columns):
- if isinstance(c, sql._Function):
- select._raw_columns[i] = c.label(c.name + "_" + hex(random.randint(0, 65535))[2:])
+ if isinstance(binary.left, sql._BindParamClause) and binary.operator == operator.eq:
+ return self.process(sql._BinaryExpression(binary.right, binary.left, binary.operator))
+ else:
+ return super(MSSQLCompiler, self).visit_binary(binary)
- super(MSSQLCompiler, self).visit_select(select)
+ def label_select_column(self, select, column):
+ if isinstance(column, sql._Function):
+ return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:])
+ else:
+ return super(MSSQLCompiler, self).label_select_column(select, column)
function_rewrites = {'current_date': 'getdate',
'length': 'len',
@@ -881,10 +883,10 @@ class MSSQLCompiler(ansisql.ANSICompiler):
return ''
def order_by_clause(self, select):
- order_by = self.get_str(select.order_by_clause)
+ order_by = self.process(select._order_by_clause)
# MSSQL only allows ORDER BY in subqueries if there is a LIMIT
- if order_by and (not select.is_subquery or select.limit):
+ if order_by and (not self.is_subquery(select) or select._limit):
return " ORDER BY " + order_by
else:
return ""
@@ -916,10 +918,12 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
class MSSQLSchemaDropper(ansisql.ANSISchemaDropper):
def visit_index(self, index):
self.append("\nDROP INDEX %s.%s" % (
- self.preparer.quote_identifier(index.table.name),
- self.preparer.quote_identifier(index.name)))
+ self.preparer.quote_identifier(index.table.name),
+ self.preparer.quote_identifier(index.name)
+ ))
self.execute()
+
class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner):
# TODO: does ms-sql have standalone sequences ?
pass
@@ -940,4 +944,3 @@ dialect = MSSQLDialect
-
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index bac0e5e12..26800e32b 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -4,7 +4,7 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import re, datetime, inspect, warnings, weakref
+import re, datetime, inspect, warnings, weakref, operator
from sqlalchemy import sql, schema, ansisql
from sqlalchemy.engine import default
@@ -12,13 +12,13 @@ import sqlalchemy.types as sqltypes
import sqlalchemy.exceptions as exceptions
import sqlalchemy.util as util
from array import array as _array
+from decimal import Decimal
try:
from threading import Lock
except ImportError:
from dummy_threading import Lock
-
RESERVED_WORDS = util.Set(
['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc',
'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both',
@@ -60,7 +60,6 @@ RESERVED_WORDS = util.Set(
'accessible', 'linear', 'master_ssl_verify_server_cert', 'range',
'read_only', 'read_write', # 5.1
])
-
_per_connection_mutex = Lock()
class _NumericType(object):
@@ -137,7 +136,7 @@ class _StringType(object):
class MSNumeric(sqltypes.Numeric, _NumericType):
"""MySQL NUMERIC type"""
- def __init__(self, precision = 10, length = 2, **kw):
+ def __init__(self, precision = 10, length = 2, asdecimal=True, **kw):
"""Construct a NUMERIC.
precision
@@ -157,18 +156,27 @@ class MSNumeric(sqltypes.Numeric, _NumericType):
"""
_NumericType.__init__(self, **kw)
- sqltypes.Numeric.__init__(self, precision, length)
-
+ sqltypes.Numeric.__init__(self, precision, length, asdecimal=asdecimal)
+
def get_col_spec(self):
if self.precision is None:
return self._extend("NUMERIC")
else:
return self._extend("NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length})
+ def convert_bind_param(self, value, dialect):
+ return value
+
+ def convert_result_value(self, value, dialect):
+ if not self.asdecimal and isinstance(value, Decimal):
+ return float(value)
+ else:
+ return value
+
class MSDecimal(MSNumeric):
"""MySQL DECIMAL type"""
- def __init__(self, precision=10, length=2, **kw):
+ def __init__(self, precision=10, length=2, asdecimal=True, **kw):
"""Construct a DECIMAL.
precision
@@ -187,7 +195,7 @@ class MSDecimal(MSNumeric):
underlying database API, which continue to be numeric.
"""
- super(MSDecimal, self).__init__(precision, length, **kw)
+ super(MSDecimal, self).__init__(precision, length, asdecimal=asdecimal, **kw)
def get_col_spec(self):
if self.precision is None:
@@ -200,7 +208,7 @@ class MSDecimal(MSNumeric):
class MSDouble(MSNumeric):
"""MySQL DOUBLE type"""
- def __init__(self, precision=10, length=2, **kw):
+ def __init__(self, precision=10, length=2, asdecimal=True, **kw):
"""Construct a DOUBLE.
precision
@@ -222,7 +230,7 @@ class MSDouble(MSNumeric):
if ((precision is None and length is not None) or
(precision is not None and length is None)):
raise exceptions.ArgumentError("You must specify both precision and length or omit both altogether.")
- super(MSDouble, self).__init__(precision, length, **kw)
+ super(MSDouble, self).__init__(precision, length, asdecimal=asdecimal, **kw)
def get_col_spec(self):
if self.precision is not None and self.length is not None:
@@ -235,7 +243,7 @@ class MSDouble(MSNumeric):
class MSFloat(sqltypes.Float, _NumericType):
"""MySQL FLOAT type"""
- def __init__(self, precision=10, length=None, **kw):
+ def __init__(self, precision=10, length=None, asdecimal=False, **kw):
"""Construct a FLOAT.
precision
@@ -257,7 +265,7 @@ class MSFloat(sqltypes.Float, _NumericType):
if length is not None:
self.length=length
_NumericType.__init__(self, **kw)
- sqltypes.Float.__init__(self, precision)
+ sqltypes.Float.__init__(self, precision, asdecimal=asdecimal)
def get_col_spec(self):
if hasattr(self, 'length') and self.length is not None:
@@ -267,6 +275,10 @@ class MSFloat(sqltypes.Float, _NumericType):
else:
return self._extend("FLOAT")
+ def convert_bind_param(self, value, dialect):
+ return value
+
+
class MSInteger(sqltypes.Integer, _NumericType):
"""MySQL INTEGER type"""
@@ -955,7 +967,10 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
if self.compiled.isinsert:
if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
-
+
+ def is_select(self):
+ return re.match(r'SELECT|SHOW|DESCRIBE|XA RECOVER', self.statement.lstrip(), re.I) is not None
+
class MySQLDialect(ansisql.ANSIDialect):
def __init__(self, **kwargs):
ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs)
@@ -1044,6 +1059,27 @@ class MySQLDialect(ansisql.ANSIDialect):
except:
pass
+ def do_begin_twophase(self, connection, xid):
+ connection.execute(sql.text("XA BEGIN :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+ def do_prepare_twophase(self, connection, xid):
+ connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)]))
+ connection.execute(sql.text("XA PREPARE :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+ def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+ if not is_prepared:
+ connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)]))
+ connection.execute(sql.text("XA ROLLBACK :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+ def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+ if not is_prepared:
+ self.do_prepare_twophase(connection, xid)
+ connection.execute(sql.text("XA COMMIT :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+ def do_recover_twophase(self, connection):
+ resultset = connection.execute(sql.text("XA RECOVER"))
+ return [row['data'][0:row['gtrid_length']] for row in resultset]
+
def is_disconnect(self, e):
return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2013, 2014, 2045, 2055)
@@ -1088,7 +1124,7 @@ class MySQLDialect(ansisql.ANSIDialect):
version.append(n)
return tuple(version)
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
"""Load column definitions from the server."""
decode_from = self._detect_charset(connection)
@@ -1111,6 +1147,9 @@ class MySQLDialect(ansisql.ANSIDialect):
# leave column names as unicode
name = name.decode(decode_from)
+
+ if include_columns and name not in include_columns:
+ continue
match = re.match(r'(\w+)(\(.*?\))?\s*(\w+)?\s*(\w+)?', type)
col_type = match.group(1)
@@ -1118,7 +1157,11 @@ class MySQLDialect(ansisql.ANSIDialect):
extra_1 = match.group(3)
extra_2 = match.group(4)
- coltype = ischema_names.get(col_type, MSString)
+ try:
+ coltype = ischema_names[col_type]
+ except KeyError:
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (col_type, name)))
+ coltype = sqltypes.NULLTYPE
kw = {}
if extra_1 is not None:
@@ -1156,7 +1199,6 @@ class MySQLDialect(ansisql.ANSIDialect):
if not row:
raise exceptions.NoSuchTableError(table.fullname)
desc = row[1].strip()
- row.close()
tabletype = ''
lastparen = re.search(r'\)[^\)]*\Z', desc)
@@ -1223,7 +1265,6 @@ class MySQLDialect(ansisql.ANSIDialect):
cs = True
else:
cs = row[1] in ('0', 'OFF' 'off')
- row.close()
cache['lower_case_table_names'] = cs
self.per_connection[raw_connection] = cache
return cache.get('lower_case_table_names')
@@ -1266,14 +1307,21 @@ class _MySQLPythonRowProxy(object):
class MySQLCompiler(ansisql.ANSICompiler):
- def visit_cast(self, cast):
-
+ operators = ansisql.ANSICompiler.operators.copy()
+ operators.update(
+ {
+ sql.ColumnOperators.concat_op : lambda x, y:"concat(%s, %s)" % (x, y),
+ operator.mod : '%%'
+ }
+ )
+
+ def visit_cast(self, cast, **kwargs):
if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)):
- return super(MySQLCompiler, self).visit_cast(cast)
+ return super(MySQLCompiler, self).visit_cast(cast, **kwargs)
else:
# so just skip the CAST altogether for now.
# TODO: put whatever MySQL does for CAST here.
- self.strings[cast] = self.strings[cast.clause]
+ return self.process(cast.clause)
def for_update_clause(self, select):
if select.for_update == 'read':
@@ -1283,20 +1331,15 @@ class MySQLCompiler(ansisql.ANSICompiler):
def limit_clause(self, select):
text = ""
- if select.limit is not None:
- text += " \n LIMIT " + str(select.limit)
- if select.offset is not None:
- if select.limit is None:
- # striaght from the MySQL docs, I kid you not
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
+ # straight from the MySQL docs, I kid you not
text += " \n LIMIT 18446744073709551615"
- text += " OFFSET " + str(select.offset)
+ text += " OFFSET " + str(select._offset)
return text
- def binary_operator_string(self, binary):
- if binary.operator == '%':
- return '%%'
- else:
- return ansisql.ANSICompiler.binary_operator_string(self, binary)
class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, first_pk=False):
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py
index 9d7d6a112..d3aa2e268 100644
--- a/lib/sqlalchemy/databases/oracle.py
+++ b/lib/sqlalchemy/databases/oracle.py
@@ -5,9 +5,9 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import sys, StringIO, string, re, warnings
+import re, warnings, operator
-from sqlalchemy import util, sql, engine, schema, ansisql, exceptions, logging
+from sqlalchemy import util, sql, schema, ansisql, exceptions, logging
from sqlalchemy.engine import default, base
import sqlalchemy.types as sqltypes
@@ -88,8 +88,11 @@ class OracleText(sqltypes.TEXT):
def convert_result_value(self, value, dialect):
if value is None:
return None
- else:
+ elif hasattr(value, 'read'):
+ # cx_oracle doesnt seem to be consistent with CLOB returning LOB or str
return super(OracleText, self).convert_result_value(value.read(), dialect)
+ else:
+ return super(OracleText, self).convert_result_value(value, dialect)
class OracleRaw(sqltypes.Binary):
@@ -178,25 +181,31 @@ class OracleExecutionContext(default.DefaultExecutionContext):
super(OracleExecutionContext, self).pre_exec()
if self.dialect.auto_setinputsizes:
self.set_input_sizes()
+ if self.compiled_parameters is not None and not isinstance(self.compiled_parameters, list):
+ for key in self.compiled_parameters:
+ (bindparam, name, value) = self.compiled_parameters.get_parameter(key)
+ if bindparam.isoutparam:
+ dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
+ if not hasattr(self, 'out_parameters'):
+ self.out_parameters = {}
+ self.out_parameters[name] = self.cursor.var(dbtype)
+ self.parameters[name] = self.out_parameters[name]
def get_result_proxy(self):
+ if hasattr(self, 'out_parameters'):
+ if self.compiled_parameters is not None:
+ for k in self.out_parameters:
+ type = self.compiled_parameters.get_type(k)
+ self.out_parameters[k] = type.dialect_impl(self.dialect).convert_result_value(self.out_parameters[k].getvalue(), self.dialect)
+ else:
+ for k in self.out_parameters:
+ self.out_parameters[k] = self.out_parameters[k].getvalue()
+
if self.cursor.description is not None:
- if self.dialect.auto_convert_lobs and self.typemap is None:
- typemap = {}
- binary = False
- for column in self.cursor.description:
- type_code = column[1]
- if type_code in self.dialect.ORACLE_BINARY_TYPES:
- binary = True
- typemap[column[0].lower()] = OracleBinary()
- self.typemap = typemap
- if binary:
+ for column in self.cursor.description:
+ type_code = column[1]
+ if type_code in self.dialect.ORACLE_BINARY_TYPES:
return base.BufferedColumnResultProxy(self)
- else:
- for column in self.cursor.description:
- type_code = column[1]
- if type_code in self.dialect.ORACLE_BINARY_TYPES:
- return base.BufferedColumnResultProxy(self)
return base.ResultProxy(self)
@@ -208,11 +217,26 @@ class OracleDialect(ansisql.ANSIDialect):
self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' )
self.auto_setinputsizes = auto_setinputsizes
self.auto_convert_lobs = auto_convert_lobs
+
if self.dbapi is not None:
self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(self.dbapi, k)]
else:
self.ORACLE_BINARY_TYPES = []
+ def dbapi_type_map(self):
+ if self.dbapi is None or not self.auto_convert_lobs:
+ return {}
+ else:
+ return {
+ self.dbapi.NUMBER: OracleInteger(),
+ self.dbapi.CLOB: OracleText(),
+ self.dbapi.BLOB: OracleBinary(),
+ self.dbapi.STRING: OracleString(),
+ self.dbapi.TIMESTAMP: OracleTimestamp(),
+ self.dbapi.BINARY: OracleRaw(),
+ datetime.datetime: OracleDate()
+ }
+
def dbapi(cls):
import cx_Oracle
return cx_Oracle
@@ -251,7 +275,7 @@ class OracleDialect(ansisql.ANSIDialect):
return 30
def oid_column_name(self, column):
- if not isinstance(column.table, sql.TableClause) and not isinstance(column.table, sql.Select):
+ if not isinstance(column.table, (sql.TableClause, sql.Select)):
return None
else:
return "rowid"
@@ -341,7 +365,7 @@ class OracleDialect(ansisql.ANSIDialect):
return name, owner, dblink
raise
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
preparer = self.identifier_preparer
if not preparer.should_quote(table):
name = table.name.upper()
@@ -363,6 +387,13 @@ class OracleDialect(ansisql.ANSIDialect):
#print "ROW:" , row
(colname, coltype, length, precision, scale, nullable, default) = (row[0], row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
+ # if name comes back as all upper, assume its case folded
+ if (colname.upper() == colname):
+ colname = colname.lower()
+
+ if include_columns and colname not in include_columns:
+ continue
+
# INTEGER if the scale is 0 and precision is null
# NUMBER if the scale and precision are both null
# NUMBER(9,2) if the precision is 9 and the scale is 2
@@ -382,16 +413,13 @@ class OracleDialect(ansisql.ANSIDialect):
try:
coltype = ischema_names[coltype]
except KeyError:
- raise exceptions.AssertionError("Can't get coltype for type '%s' on colname '%s'" % (coltype, colname))
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, colname)))
+ coltype = sqltypes.NULLTYPE
colargs = []
if default is not None:
colargs.append(schema.PassiveDefault(sql.text(default)))
- # if name comes back as all upper, assume its case folded
- if (colname.upper() == colname):
- colname = colname.lower()
-
table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs))
if not len(table.columns):
@@ -458,16 +486,27 @@ class OracleDialect(ansisql.ANSIDialect):
OracleDialect.logger = logging.class_logger(OracleDialect)
+class _OuterJoinColumn(sql.ClauseElement):
+ __visit_name__ = 'outer_join_column'
+ def __init__(self, column):
+ self.column = column
+
class OracleCompiler(ansisql.ANSICompiler):
"""Oracle compiler modifies the lexical structure of Select
statements to work under non-ANSI configured Oracle databases, if
the use_ansi flag is False.
"""
+ operators = ansisql.ANSICompiler.operators.copy()
+ operators.update(
+ {
+ operator.mod : lambda x, y:"mod(%s, %s)" % (x, y)
+ }
+ )
+
def __init__(self, *args, **kwargs):
super(OracleCompiler, self).__init__(*args, **kwargs)
- # we have to modify SELECT objects a little bit, so store state here
- self._select_state = {}
+ self.__wheres = {}
def default_from(self):
"""Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
@@ -480,49 +519,46 @@ class OracleCompiler(ansisql.ANSICompiler):
def apply_function_parens(self, func):
return len(func.clauses) > 0
- def visit_join(self, join):
+ def visit_join(self, join, **kwargs):
if self.dialect.use_ansi:
- return ansisql.ANSICompiler.visit_join(self, join)
-
- self.froms[join] = self.get_from_text(join.left) + ", " + self.get_from_text(join.right)
- where = self.wheres.get(join.left, None)
+ return ansisql.ANSICompiler.visit_join(self, join, **kwargs)
+
+ (where, parentjoin) = self.__wheres.get(join, (None, None))
+
+ class VisitOn(sql.ClauseVisitor):
+ def visit_binary(s, binary):
+ if binary.operator == operator.eq:
+ if binary.left.table is join.right:
+ binary.left = _OuterJoinColumn(binary.left)
+ elif binary.right.table is join.right:
+ binary.right = _OuterJoinColumn(binary.right)
+
if where is not None:
- self.wheres[join] = sql.and_(where, join.onclause)
+ self.__wheres[join.left] = self.__wheres[parentjoin] = (sql.and_(VisitOn().traverse(join.onclause, clone=True), where), parentjoin)
else:
- self.wheres[join] = join.onclause
-# self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause)
- self.strings[join] = self.froms[join]
-
- if join.isouter:
- # if outer join, push on the right side table as the current "outertable"
- self._outertable = join.right
-
- # now re-visit the onclause, which will be used as a where clause
- # (the first visit occured via the Join object itself right before it called visit_join())
- self.traverse(join.onclause)
-
- self._outertable = None
-
- self.wheres[join].accept_visitor(self)
+ self.__wheres[join.left] = self.__wheres[join] = (VisitOn().traverse(join.onclause, clone=True), join)
- def visit_insert_sequence(self, column, sequence, parameters):
- """This is the `sequence` equivalent to ``ANSICompiler``'s
- `visit_insert_column_default` which ensures that the column is
- present in the generated column list.
- """
-
- parameters.setdefault(column.key, None)
+ return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
+
+ def get_whereclause(self, f):
+ if f in self.__wheres:
+ return self.__wheres[f][0]
+ else:
+ return None
+
+ def visit_outer_join_column(self, vc):
+ return self.process(vc.column) + "(+)"
+
+ def uses_sequences_for_inserts(self):
+ return True
- def visit_alias(self, alias):
+ def visit_alias(self, alias, asfrom=False, **kwargs):
"""Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??"""
-
- self.froms[alias] = self.get_from_text(alias.original) + " " + alias.name
- self.strings[alias] = self.get_str(alias.original)
-
- def visit_column(self, column):
- ansisql.ANSICompiler.visit_column(self, column)
- if not self.dialect.use_ansi and getattr(self, '_outertable', None) is not None and column.table is self._outertable:
- self.strings[column] = self.strings[column] + "(+)"
+
+ if asfrom:
+ return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + alias.name
+ else:
+ return self.process(alias.original, **kwargs)
def visit_insert(self, insert):
"""``INSERT`` s are required to have the primary keys be explicitly present.
@@ -539,76 +575,35 @@ class OracleCompiler(ansisql.ANSICompiler):
def _TODO_visit_compound_select(self, select):
"""Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
+ pass
- if getattr(select, '_oracle_visit', False):
- # cancel out the compiled order_by on the select
- if hasattr(select, "order_by_clause"):
- self.strings[select.order_by_clause] = ""
- ansisql.ANSICompiler.visit_compound_select(self, select)
- return
-
- if select.limit is not None or select.offset is not None:
- select._oracle_visit = True
- # to use ROW_NUMBER(), an ORDER BY is required.
- orderby = self.strings[select.order_by_clause]
- if not orderby:
- orderby = select.oid_column
- self.traverse(orderby)
- orderby = self.strings[orderby]
- class SelectVisitor(sql.NoColumnVisitor):
- def visit_select(self, select):
- select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn"))
- SelectVisitor().traverse(select)
- limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
- if select.offset is not None:
- limitselect.append_whereclause("ora_rn>%d" % select.offset)
- if select.limit is not None:
- limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset))
- else:
- limitselect.append_whereclause("ora_rn<=%d" % select.limit)
- self.traverse(limitselect)
- self.strings[select] = self.strings[limitselect]
- self.froms[select] = self.froms[limitselect]
- else:
- ansisql.ANSICompiler.visit_compound_select(self, select)
-
- def visit_select(self, select):
+ def visit_select(self, select, **kwargs):
"""Look for ``LIMIT`` and OFFSET in a select statement, and if
so tries to wrap it in a subquery with ``row_number()`` criterion.
"""
- # TODO: put a real copy-container on Select and copy, or somehow make this
- # not modify the Select statement
- if self._select_state.get((select, 'visit'), False):
- # cancel out the compiled order_by on the select
- if hasattr(select, "order_by_clause"):
- self.strings[select.order_by_clause] = ""
- ansisql.ANSICompiler.visit_select(self, select)
- return
-
- if select.limit is not None or select.offset is not None:
- self._select_state[(select, 'visit')] = True
+ if not getattr(select, '_oracle_visit', None) and (select._limit is not None or select._offset is not None):
# to use ROW_NUMBER(), an ORDER BY is required.
- orderby = self.strings[select.order_by_clause]
+ orderby = self.process(select._order_by_clause)
if not orderby:
orderby = select.oid_column
self.traverse(orderby)
- orderby = self.strings[orderby]
- if not hasattr(select, '_oracle_visit'):
- select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn"))
- select._oracle_visit = True
+ orderby = self.process(orderby)
+
+ oldselect = select
+ select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None)
+ select._oracle_visit = True
+
limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
- if select.offset is not None:
- limitselect.append_whereclause("ora_rn>%d" % select.offset)
- if select.limit is not None:
- limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset))
+ if select._offset is not None:
+ limitselect.append_whereclause("ora_rn>%d" % select._offset)
+ if select._limit is not None:
+ limitselect.append_whereclause("ora_rn<=%d" % (select._limit + select._offset))
else:
- limitselect.append_whereclause("ora_rn<=%d" % select.limit)
- self.traverse(limitselect)
- self.strings[select] = self.strings[limitselect]
- self.froms[select] = self.froms[limitselect]
+ limitselect.append_whereclause("ora_rn<=%d" % select._limit)
+ return self.process(limitselect)
else:
- ansisql.ANSICompiler.visit_select(self, select)
+ return ansisql.ANSICompiler.visit_select(self, select, **kwargs)
def limit_clause(self, select):
return ""
@@ -619,12 +614,6 @@ class OracleCompiler(ansisql.ANSICompiler):
else:
return super(OracleCompiler, self).for_update_clause(select)
- def visit_binary(self, binary):
- if binary.operator == '%':
- self.strings[binary] = ("MOD(%s,%s)"%(self.get_str(binary.left), self.get_str(binary.right)))
- else:
- return ansisql.ANSICompiler.visit_binary(self, binary)
-
class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
@@ -639,22 +628,22 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
return colspec
def visit_sequence(self, sequence):
- if not self.dialect.has_sequence(self.connection, sequence.name):
+ if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name):
self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
class OracleSchemaDropper(ansisql.ANSISchemaDropper):
def visit_sequence(self, sequence):
- if self.dialect.has_sequence(self.connection, sequence.name):
+ if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name):
self.append("DROP SEQUENCE %s" % sequence.name)
self.execute()
class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
def exec_default_sql(self, default):
- c = sql.select([default.arg], from_obj=["DUAL"]).compile(engine=self.connection)
- return self.connection.execute_compiled(c).scalar()
+ c = sql.select([default.arg], from_obj=["DUAL"]).compile(bind=self.connection)
+ return self.connection.execute(c).scalar()
def visit_sequence(self, seq):
- return self.connection.execute_text("SELECT " + seq.name + ".nextval FROM DUAL").scalar()
+ return self.connection.execute("SELECT " + seq.name + ".nextval FROM DUAL").scalar()
dialect = OracleDialect
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index d3726fc1f..b192c4778 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -4,12 +4,13 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import datetime, string, types, re, random, warnings
+import re, random, warnings, operator
-from sqlalchemy import util, sql, schema, ansisql, exceptions
+from sqlalchemy import sql, schema, ansisql, exceptions
from sqlalchemy.engine import base, default
import sqlalchemy.types as sqltypes
from sqlalchemy.databases import information_schema as ischema
+from decimal import Decimal
try:
import mx.DateTime.DateTime as mxDateTime
@@ -28,6 +29,15 @@ class PGNumeric(sqltypes.Numeric):
else:
return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
+ def convert_bind_param(self, value, dialect):
+ return value
+
+ def convert_result_value(self, value, dialect):
+ if not self.asdecimal and isinstance(value, Decimal):
+ return float(value)
+ else:
+ return value
+
class PGFloat(sqltypes.Float):
def get_col_spec(self):
if not self.precision:
@@ -35,6 +45,7 @@ class PGFloat(sqltypes.Float):
else:
return "FLOAT(%(precision)s)" % {'precision': self.precision}
+
class PGInteger(sqltypes.Integer):
def get_col_spec(self):
return "INTEGER"
@@ -47,74 +58,15 @@ class PGBigInteger(PGInteger):
def get_col_spec(self):
return "BIGINT"
-class PG2DateTime(sqltypes.DateTime):
+class PGDateTime(sqltypes.DateTime):
def get_col_spec(self):
return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
-class PG1DateTime(sqltypes.DateTime):
- def convert_bind_param(self, value, dialect):
- if value is not None:
- if isinstance(value, datetime.datetime):
- seconds = float(str(value.second) + "."
- + str(value.microsecond))
- mx_datetime = mxDateTime(value.year, value.month, value.day,
- value.hour, value.minute,
- seconds)
- return dialect.dbapi.TimestampFromMx(mx_datetime)
- return dialect.dbapi.TimestampFromMx(value)
- else:
- return None
-
- def convert_result_value(self, value, dialect):
- if value is None:
- return None
- second_parts = str(value.second).split(".")
- seconds = int(second_parts[0])
- microseconds = int(float(second_parts[1]))
- return datetime.datetime(value.year, value.month, value.day,
- value.hour, value.minute, seconds,
- microseconds)
-
- def get_col_spec(self):
- return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
-
-class PG2Date(sqltypes.Date):
- def get_col_spec(self):
- return "DATE"
-
-class PG1Date(sqltypes.Date):
- def convert_bind_param(self, value, dialect):
- # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
- # this one doesnt seem to work with the "emulation" mode
- if value is not None:
- return dialect.dbapi.DateFromMx(value)
- else:
- return None
-
- def convert_result_value(self, value, dialect):
- # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
- return value
-
+class PGDate(sqltypes.Date):
def get_col_spec(self):
return "DATE"
-class PG2Time(sqltypes.Time):
- def get_col_spec(self):
- return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
-
-class PG1Time(sqltypes.Time):
- def convert_bind_param(self, value, dialect):
- # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
- # this one doesnt seem to work with the "emulation" mode
- if value is not None:
- return psycopg.TimeFromMx(value)
- else:
- return None
-
- def convert_result_value(self, value, dialect):
- # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
- return value
-
+class PGTime(sqltypes.Time):
def get_col_spec(self):
return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
@@ -142,28 +94,55 @@ class PGBoolean(sqltypes.Boolean):
def get_col_spec(self):
return "BOOLEAN"
-pg2_colspecs = {
+class PGArray(sqltypes.TypeEngine, sqltypes.Concatenable):
+ def __init__(self, item_type):
+ if isinstance(item_type, type):
+ item_type = item_type()
+ self.item_type = item_type
+
+ def dialect_impl(self, dialect):
+ impl = self.__class__.__new__(self.__class__)
+ impl.__dict__.update(self.__dict__)
+ impl.item_type = self.item_type.dialect_impl(dialect)
+ return impl
+ def convert_bind_param(self, value, dialect):
+ if value is None:
+ return value
+ def convert_item(item):
+ if isinstance(item, (list,tuple)):
+ return [convert_item(child) for child in item]
+ else:
+ return self.item_type.convert_bind_param(item, dialect)
+ return [convert_item(item) for item in value]
+ def convert_result_value(self, value, dialect):
+ if value is None:
+ return value
+ def convert_item(item):
+ if isinstance(item, list):
+ return [convert_item(child) for child in item]
+ else:
+ return self.item_type.convert_result_value(item, dialect)
+ # Could specialcase when item_type.convert_result_value is the default identity func
+ return [convert_item(item) for item in value]
+ def get_col_spec(self):
+ return self.item_type.get_col_spec() + '[]'
+
+colspecs = {
sqltypes.Integer : PGInteger,
sqltypes.Smallinteger : PGSmallInteger,
sqltypes.Numeric : PGNumeric,
sqltypes.Float : PGFloat,
- sqltypes.DateTime : PG2DateTime,
- sqltypes.Date : PG2Date,
- sqltypes.Time : PG2Time,
+ sqltypes.DateTime : PGDateTime,
+ sqltypes.Date : PGDate,
+ sqltypes.Time : PGTime,
sqltypes.String : PGString,
sqltypes.Binary : PGBinary,
sqltypes.Boolean : PGBoolean,
sqltypes.TEXT : PGText,
sqltypes.CHAR: PGChar,
}
-pg1_colspecs = pg2_colspecs.copy()
-pg1_colspecs.update({
- sqltypes.DateTime : PG1DateTime,
- sqltypes.Date : PG1Date,
- sqltypes.Time : PG1Time
- })
-
-pg2_ischema_names = {
+
+ischema_names = {
'integer' : PGInteger,
'bigint' : PGBigInteger,
'smallint' : PGSmallInteger,
@@ -175,24 +154,17 @@ pg2_ischema_names = {
'real' : PGFloat,
'inet': PGInet,
'double precision' : PGFloat,
- 'timestamp' : PG2DateTime,
- 'timestamp with time zone' : PG2DateTime,
- 'timestamp without time zone' : PG2DateTime,
- 'time with time zone' : PG2Time,
- 'time without time zone' : PG2Time,
- 'date' : PG2Date,
- 'time': PG2Time,
+ 'timestamp' : PGDateTime,
+ 'timestamp with time zone' : PGDateTime,
+ 'timestamp without time zone' : PGDateTime,
+ 'time with time zone' : PGTime,
+ 'time without time zone' : PGTime,
+ 'date' : PGDate,
+ 'time': PGTime,
'bytea' : PGBinary,
'boolean' : PGBoolean,
'interval':PGInterval,
}
-pg1_ischema_names = pg2_ischema_names.copy()
-pg1_ischema_names.update({
- 'timestamp with time zone' : PG1DateTime,
- 'timestamp without time zone' : PG1DateTime,
- 'date' : PG1Date,
- 'time' : PG1Time
- })
def descriptor():
return {'name':'postgres',
@@ -206,11 +178,11 @@ def descriptor():
class PGExecutionContext(default.DefaultExecutionContext):
- def is_select(self):
- return re.match(r'SELECT', self.statement.lstrip(), re.I) and not re.search(r'FOR UPDATE\s*$', self.statement, re.I)
-
+ def _is_server_side(self):
+ return self.dialect.server_side_cursors and self.is_select() and not re.search(r'FOR UPDATE(?: NOWAIT)?\s*$', self.statement, re.I)
+
def create_cursor(self):
- if self.dialect.server_side_cursors and self.is_select():
+ if self._is_server_side():
# use server-side cursors:
# http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
ident = "c" + hex(random.randint(0, 65535))[2:]
@@ -219,7 +191,7 @@ class PGExecutionContext(default.DefaultExecutionContext):
return self.connection.connection.cursor()
def get_result_proxy(self):
- if self.dialect.server_side_cursors and self.is_select():
+ if self._is_server_side():
return base.BufferedRowResultProxy(self)
else:
return base.ResultProxy(self)
@@ -242,31 +214,18 @@ class PGDialect(ansisql.ANSIDialect):
ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs)
self.use_oids = use_oids
self.server_side_cursors = server_side_cursors
- if self.dbapi is None or not hasattr(self.dbapi, '__version__') or self.dbapi.__version__.startswith('2'):
- self.version = 2
- else:
- self.version = 1
self.use_information_schema = use_information_schema
self.paramstyle = 'pyformat'
def dbapi(cls):
- try:
- import psycopg2 as psycopg
- except ImportError, e:
- try:
- import psycopg
- except ImportError, e2:
- raise e
+ import psycopg2 as psycopg
return psycopg
dbapi = classmethod(dbapi)
def create_connect_args(self, url):
opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
if opts.has_key('port'):
- if self.version == 2:
- opts['port'] = int(opts['port'])
- else:
- opts['port'] = str(opts['port'])
+ opts['port'] = int(opts['port'])
opts.update(url.query)
return ([], opts)
@@ -278,10 +237,7 @@ class PGDialect(ansisql.ANSIDialect):
return 63
def type_descriptor(self, typeobj):
- if self.version == 2:
- return sqltypes.adapt_type(typeobj, pg2_colspecs)
- else:
- return sqltypes.adapt_type(typeobj, pg1_colspecs)
+ return sqltypes.adapt_type(typeobj, colspecs)
def compiler(self, statement, bindparams, **kwargs):
return PGCompiler(self, statement, bindparams, **kwargs)
@@ -292,8 +248,36 @@ class PGDialect(ansisql.ANSIDialect):
def schemadropper(self, *args, **kwargs):
return PGSchemaDropper(self, *args, **kwargs)
- def defaultrunner(self, connection, **kwargs):
- return PGDefaultRunner(connection, **kwargs)
+ def do_begin_twophase(self, connection, xid):
+ self.do_begin(connection.connection)
+
+ def do_prepare_twophase(self, connection, xid):
+ connection.execute(sql.text("PREPARE TRANSACTION %(tid)s", bindparams=[sql.bindparam('tid', xid)]))
+
+ def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+ if is_prepared:
+ if recover:
+            #FIXME: ugly hack to get out of transaction context when committing recoverable transactions
+ # Must find out a way how to make the dbapi not open a transaction.
+ connection.execute(sql.text("ROLLBACK"))
+ connection.execute(sql.text("ROLLBACK PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)]))
+ else:
+ self.do_rollback(connection.connection)
+
+ def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+ if is_prepared:
+ if recover:
+ connection.execute(sql.text("ROLLBACK"))
+ connection.execute(sql.text("COMMIT PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)]))
+ else:
+ self.do_commit(connection.connection)
+
+ def do_recover_twophase(self, connection):
+ resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts"))
+ return [row[0] for row in resultset]
+
+ def defaultrunner(self, context, **kwargs):
+ return PGDefaultRunner(context, **kwargs)
def preparer(self):
return PGIdentifierPreparer(self)
@@ -351,14 +335,9 @@ class PGDialect(ansisql.ANSIDialect):
else:
return False
- def reflecttable(self, connection, table):
- if self.version == 2:
- ischema_names = pg2_ischema_names
- else:
- ischema_names = pg1_ischema_names
-
+ def reflecttable(self, connection, table, include_columns):
if self.use_information_schema:
- ischema.reflecttable(connection, table, ischema_names)
+ ischema.reflecttable(connection, table, include_columns, ischema_names)
else:
preparer = self.identifier_preparer
if table.schema is not None:
@@ -387,7 +366,7 @@ class PGDialect(ansisql.ANSIDialect):
ORDER BY a.attnum
""" % schema_where_clause
- s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type=sqltypes.Unicode), sql.bindparam('schema', type=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode})
+ s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode})
c = connection.execute(s, table_name=table.name,
schema=table.schema)
rows = c.fetchall()
@@ -398,9 +377,13 @@ class PGDialect(ansisql.ANSIDialect):
domains = self._load_domains(connection)
for name, format_type, default, notnull, attnum, table_oid in rows:
+ if include_columns and name not in include_columns:
+ continue
+
## strip (30) from character varying(30)
- attype = re.search('([^\(]+)', format_type).group(1)
+ attype = re.search('([^\([]+)', format_type).group(1)
nullable = not notnull
+ is_array = format_type.endswith('[]')
try:
charlen = re.search('\(([\d,]+)\)', format_type).group(1)
@@ -453,6 +436,8 @@ class PGDialect(ansisql.ANSIDialect):
if coltype:
coltype = coltype(*args, **kwargs)
+ if is_array:
+ coltype = PGArray(coltype)
else:
warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (attype, name)))
coltype = sqltypes.NULLTYPE
@@ -517,7 +502,6 @@ class PGDialect(ansisql.ANSIDialect):
table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname))
def _load_domains(self, connection):
-
## Load data types for domains:
SQL_DOMAINS = """
SELECT t.typname as "name",
@@ -554,49 +538,46 @@ class PGDialect(ansisql.ANSIDialect):
-
class PGCompiler(ansisql.ANSICompiler):
- def visit_insert_column(self, column, parameters):
- # all column primary key inserts must be explicitly present
- if column.primary_key:
- parameters[column.key] = None
+ operators = ansisql.ANSICompiler.operators.copy()
+ operators.update(
+ {
+ operator.mod : '%%'
+ }
+ )
- def visit_insert_sequence(self, column, sequence, parameters):
- """this is the 'sequence' equivalent to ANSICompiler's 'visit_insert_column_default' which ensures
- that the column is present in the generated column list"""
- parameters.setdefault(column.key, None)
+ def uses_sequences_for_inserts(self):
+ return True
def limit_clause(self, select):
text = ""
- if select.limit is not None:
- text += " \n LIMIT " + str(select.limit)
- if select.offset is not None:
- if select.limit is None:
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
text += " \n LIMIT ALL"
- text += " OFFSET " + str(select.offset)
+ text += " OFFSET " + str(select._offset)
return text
- def visit_select_precolumns(self, select):
- if select.distinct:
- if type(select.distinct) == bool:
+ def get_select_precolumns(self, select):
+ if select._distinct:
+ if type(select._distinct) == bool:
return "DISTINCT "
- if type(select.distinct) == list:
+ if type(select._distinct) == list:
dist_set = "DISTINCT ON ("
- for col in select.distinct:
+ for col in select._distinct:
dist_set += self.strings[col] + ", "
dist_set = dist_set[:-2] + ") "
return dist_set
- return "DISTINCT ON (" + str(select.distinct) + ") "
+ return "DISTINCT ON (" + str(select._distinct) + ") "
else:
return ""
- def binary_operator_string(self, binary):
- if isinstance(binary.type, sqltypes.String) and binary.operator == '+':
- return '||'
- elif binary.operator == '%':
- return '%%'
+ def for_update_clause(self, select):
+ if select.for_update == 'nowait':
+ return " FOR UPDATE NOWAIT"
else:
- return ansisql.ANSICompiler.binary_operator_string(self, binary)
+ return super(PGCompiler, self).for_update_clause(select)
class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
@@ -617,13 +598,13 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
return colspec
def visit_sequence(self, sequence):
- if not sequence.optional and (not self.dialect.has_sequence(self.connection, sequence.name)):
+ if not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)):
self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
class PGSchemaDropper(ansisql.ANSISchemaDropper):
def visit_sequence(self, sequence):
- if not sequence.optional and (self.dialect.has_sequence(self.connection, sequence.name)):
+ if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)):
self.append("DROP SEQUENCE %s" % sequence.name)
self.execute()
@@ -632,7 +613,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
if column.primary_key:
# passive defaults on primary keys have to be overridden
if isinstance(column.default, schema.PassiveDefault):
- return self.connection.execute_text("select %s" % column.default.arg).scalar()
+ return self.connection.execute("select %s" % column.default.arg).scalar()
elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
sch = column.table.schema
# TODO: this has to build into the Sequence object so we can get the quoting
@@ -641,7 +622,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
else:
exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
- return self.connection.execute_text(exc).scalar()
+ return self.connection.execute(exc).scalar()
return super(ansisql.ANSIDefaultRunner, self).get_column_default(column)
diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py
index 816b1b76a..725ea23e2 100644
--- a/lib/sqlalchemy/databases/sqlite.py
+++ b/lib/sqlalchemy/databases/sqlite.py
@@ -5,9 +5,9 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import sys, StringIO, string, types, re
+import re
-from sqlalchemy import sql, engine, schema, ansisql, exceptions, pool, PassiveDefault
+from sqlalchemy import schema, ansisql, exceptions, pool, PassiveDefault
import sqlalchemy.engine.default as default
import sqlalchemy.types as sqltypes
import datetime,time, warnings
@@ -126,6 +126,7 @@ colspecs = {
pragma_names = {
'INTEGER' : SLInteger,
+ 'INT' : SLInteger,
'SMALLINT' : SLSmallInteger,
'VARCHAR' : SLString,
'CHAR' : SLChar,
@@ -150,8 +151,9 @@ class SQLiteExecutionContext(default.DefaultExecutionContext):
if self.compiled.isinsert:
if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
-
- super(SQLiteExecutionContext, self).post_exec()
+
+ def is_select(self):
+ return re.match(r'SELECT|PRAGMA', self.statement.lstrip(), re.I) is not None
class SQLiteDialect(ansisql.ANSIDialect):
@@ -233,7 +235,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
return (row is not None)
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
c = connection.execute("PRAGMA table_info(%s)" % self.preparer().format_table(table), {})
found_table = False
while True:
@@ -244,6 +246,8 @@ class SQLiteDialect(ansisql.ANSIDialect):
found_table = True
(name, type, nullable, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4] is not None, row[5])
name = re.sub(r'^\"|\"$', '', name)
+ if include_columns and name not in include_columns:
+ continue
match = re.match(r'(\w+)(\(.*?\))?', type)
if match:
coltype = match.group(1)
@@ -253,7 +257,12 @@ class SQLiteDialect(ansisql.ANSIDialect):
args = ''
#print "coltype: " + repr(coltype) + " args: " + repr(args)
- coltype = pragma_names.get(coltype, SLString)
+ try:
+ coltype = pragma_names[coltype]
+ except KeyError:
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, name)))
+ coltype = sqltypes.NULLTYPE
+
if args is not None:
args = re.findall(r'(\d+)', args)
#print "args! " +repr(args)
@@ -318,21 +327,21 @@ class SQLiteDialect(ansisql.ANSIDialect):
class SQLiteCompiler(ansisql.ANSICompiler):
def visit_cast(self, cast):
if self.dialect.supports_cast:
- super(SQLiteCompiler, self).visit_cast(cast)
+ return super(SQLiteCompiler, self).visit_cast(cast)
else:
if len(self.select_stack):
# not sure if we want to set the typemap here...
self.typemap.setdefault("CAST", cast.type)
- self.strings[cast] = self.strings[cast.clause]
+ return self.process(cast.clause)
def limit_clause(self, select):
text = ""
- if select.limit is not None:
- text += " \n LIMIT " + str(select.limit)
- if select.offset is not None:
- if select.limit is None:
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
text += " \n LIMIT -1"
- text += " OFFSET " + str(select.offset)
+ text += " OFFSET " + str(select._offset)
else:
text += " OFFSET 0"
return text
@@ -341,12 +350,6 @@ class SQLiteCompiler(ansisql.ANSICompiler):
# sqlite has no "FOR UPDATE" AFAICT
return ''
- def binary_operator_string(self, binary):
- if isinstance(binary.type, sqltypes.String) and binary.operator == '+':
- return '||'
- else:
- return ansisql.ANSICompiler.binary_operator_string(self, binary)
-
class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):