summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-08-18 21:37:48 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-08-18 21:37:48 +0000
commit7c6c1b99c2de00829b6f34ffba7e3bb689d34198 (patch)
treeacd6f8dc84cea86fc58b195a5f1068cbe020e955
parent534cf5fdbd05e2049ab9feceabf3926a5ab6380c (diff)
downloadsqlalchemy-7c6c1b99c2de00829b6f34ffba7e3bb689d34198.tar.gz
1. Module layout. sql.py and related move into a package called "sql".
2. compiler names changed to be less verbose, unused classes removed. 3. Methods on Dialect which return compilers, schema generators, identifier preparers have changed to direct class references, typically on the Dialect class itself or optionally as attributes on an individual Dialect instance if conditional behavior is needed. This takes away the need for Dialect subclasses to know how to instantiate these objects, and also reduces method overhead by one call for each one. 4. as a result of 3., some internal signatures have changed for things like compiler() (now statement_compiler()), preparer(), etc., mostly in that the dialect needs to be passed explicitly as the first argument (since they are just class references now). The compiler() method on Engine and Connection is now also named statement_compiler(), but as before does not take the dialect as an argument. 5. changed _process_row function on RowProxy to be a class reference, cuts out 50K method calls from insertspeed.py
-rw-r--r--lib/sqlalchemy/databases/access.py8
-rw-r--r--lib/sqlalchemy/databases/firebird.py42
-rw-r--r--lib/sqlalchemy/databases/informix.py41
-rw-r--r--lib/sqlalchemy/databases/mssql.py46
-rw-r--r--lib/sqlalchemy/databases/mysql.py36
-rw-r--r--lib/sqlalchemy/databases/oracle.py56
-rw-r--r--lib/sqlalchemy/databases/postgres.py48
-rw-r--r--lib/sqlalchemy/databases/sqlite.py41
-rw-r--r--lib/sqlalchemy/engine/base.py132
-rw-r--r--lib/sqlalchemy/engine/default.py18
-rw-r--r--lib/sqlalchemy/ext/sqlsoup.py6
-rw-r--r--lib/sqlalchemy/orm/interfaces.py3
-rw-r--r--lib/sqlalchemy/orm/mapper.py13
-rw-r--r--lib/sqlalchemy/orm/properties.py3
-rw-r--r--lib/sqlalchemy/orm/query.py8
-rw-r--r--lib/sqlalchemy/orm/strategies.py8
-rw-r--r--lib/sqlalchemy/orm/sync.py6
-rw-r--r--lib/sqlalchemy/orm/util.py8
-rw-r--r--lib/sqlalchemy/schema.py31
-rw-r--r--lib/sqlalchemy/sql/__init__.py3
-rw-r--r--lib/sqlalchemy/sql/compiler.py (renamed from lib/sqlalchemy/ansisql.py)110
-rw-r--r--lib/sqlalchemy/sql/expression.py (renamed from lib/sqlalchemy/sql.py)215
-rw-r--r--lib/sqlalchemy/sql/operators.py (renamed from lib/sqlalchemy/operators.py)0
-rw-r--r--lib/sqlalchemy/sql/util.py (renamed from lib/sqlalchemy/sql_util.py)107
-rw-r--r--lib/sqlalchemy/sql/visitors.py87
-rw-r--r--test/dialect/mysql.py6
-rw-r--r--test/engine/reflection.py3
-rw-r--r--test/orm/dynamic.py1
-rw-r--r--test/orm/query.py9
-rw-r--r--test/sql/constraints.py2
-rw-r--r--test/sql/generative.py3
-rw-r--r--test/sql/labels.py2
-rw-r--r--test/sql/query.py9
-rw-r--r--test/sql/quote.py8
-rw-r--r--test/sql/testtypes.py2
-rw-r--r--test/testlib/testing.py3
36 files changed, 517 insertions, 607 deletions
diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py
index 6bf8b96e9..4aa773239 100644
--- a/lib/sqlalchemy/databases/access.py
+++ b/lib/sqlalchemy/databases/access.py
@@ -347,7 +347,7 @@ class AccessDialect(ansisql.ANSIDialect):
return names
-class AccessCompiler(ansisql.ANSICompiler):
+class AccessCompiler(compiler.DefaultCompiler):
def visit_select_precolumns(self, select):
"""Access puts TOP, it's version of LIMIT here """
s = select.distinct and "DISTINCT " or ""
@@ -387,7 +387,7 @@ class AccessCompiler(ansisql.ANSICompiler):
return ''
-class AccessSchemaGenerator(ansisql.ANSISchemaGenerator):
+class AccessSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
@@ -410,7 +410,7 @@ class AccessSchemaGenerator(ansisql.ANSISchemaGenerator):
return colspec
-class AccessSchemaDropper(ansisql.ANSISchemaDropper):
+class AccessSchemaDropper(compiler.SchemaDropper):
def visit_index(self, index):
self.append("\nDROP INDEX [%s].[%s]" % (index.table.name, index.name))
self.execute()
@@ -418,7 +418,7 @@ class AccessSchemaDropper(ansisql.ANSISchemaDropper):
class AccessDefaultRunner(ansisql.ANSIDefaultRunner):
pass
-class AccessIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class AccessIdentifierPreparer(compiler.IdentifierPreparer):
def __init__(self, dialect):
super(AccessIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py
index 307fceb48..9cccb53e8 100644
--- a/lib/sqlalchemy/databases/firebird.py
+++ b/lib/sqlalchemy/databases/firebird.py
@@ -7,9 +7,10 @@
import warnings
-from sqlalchemy import util, sql, schema, ansisql, exceptions
-import sqlalchemy.engine.default as default
-import sqlalchemy.types as sqltypes
+from sqlalchemy import util, sql, schema, exceptions
+from sqlalchemy.sql import compiler
+from sqlalchemy.engine import default, base
+from sqlalchemy import types as sqltypes
_initialized_kb = False
@@ -99,9 +100,9 @@ class FBExecutionContext(default.DefaultExecutionContext):
return True
-class FBDialect(ansisql.ANSIDialect):
+class FBDialect(default.DefaultDialect):
def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
- ansisql.ANSIDialect.__init__(self, **kwargs)
+ default.DefaultDialect.__init__(self, **kwargs)
self.type_conv = type_conv
self.concurrency_level= concurrency_level
@@ -135,21 +136,6 @@ class FBDialect(ansisql.ANSIDialect):
def supports_sane_rowcount(self):
return False
- def compiler(self, statement, bindparams, **kwargs):
- return FBCompiler(self, statement, bindparams, **kwargs)
-
- def schemagenerator(self, *args, **kwargs):
- return FBSchemaGenerator(self, *args, **kwargs)
-
- def schemadropper(self, *args, **kwargs):
- return FBSchemaDropper(self, *args, **kwargs)
-
- def defaultrunner(self, connection):
- return FBDefaultRunner(connection)
-
- def preparer(self):
- return FBIdentifierPreparer(self)
-
def max_identifier_length(self):
return 31
@@ -307,7 +293,7 @@ class FBDialect(ansisql.ANSIDialect):
connection.commit(True)
-class FBCompiler(ansisql.ANSICompiler):
+class FBCompiler(compiler.DefaultCompiler):
"""Firebird specific idiosincrasies"""
def visit_alias(self, alias, asfrom=False, **kwargs):
@@ -346,7 +332,7 @@ class FBCompiler(ansisql.ANSICompiler):
return ""
-class FBSchemaGenerator(ansisql.ANSISchemaGenerator):
+class FBSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column)
colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
@@ -365,13 +351,13 @@ class FBSchemaGenerator(ansisql.ANSISchemaGenerator):
self.execute()
-class FBSchemaDropper(ansisql.ANSISchemaDropper):
+class FBSchemaDropper(compiler.SchemaDropper):
def visit_sequence(self, sequence):
self.append("DROP GENERATOR %s" % sequence.name)
self.execute()
-class FBDefaultRunner(ansisql.ANSIDefaultRunner):
+class FBDefaultRunner(base.DefaultRunner):
def exec_default_sql(self, default):
c = sql.select([default.arg], from_obj=["rdb$database"]).compile(bind=self.connection)
return self.connection.execute_compiled(c).scalar()
@@ -421,7 +407,7 @@ RESERVED_WORDS = util.Set(
"whenever", "where", "while", "with", "work", "write", "year", "yearday" ])
-class FBIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class FBIdentifierPreparer(compiler.IdentifierPreparer):
def __init__(self, dialect):
super(FBIdentifierPreparer,self).__init__(dialect, omit_schema=True)
@@ -430,3 +416,9 @@ class FBIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
dialect = FBDialect
+dialect.statement_compiler = FBCompiler
+dialect.schemagenerator = FBSchemaGenerator
+dialect.schemadropper = FBSchemaDropper
+dialect.defaultrunner = FBDefaultRunner
+dialect.preparer = FBIdentifierPreparer
+
diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py
index 21ecf1538..ceb56903a 100644
--- a/lib/sqlalchemy/databases/informix.py
+++ b/lib/sqlalchemy/databases/informix.py
@@ -7,9 +7,10 @@
import datetime, warnings
-from sqlalchemy import sql, schema, ansisql, exceptions, pool
-import sqlalchemy.engine.default as default
-import sqlalchemy.types as sqltypes
+from sqlalchemy import sql, schema, exceptions, pool
+from sqlalchemy.sql import compiler
+from sqlalchemy.engine import default
+from sqlalchemy import types as sqltypes
# for offset
@@ -203,11 +204,11 @@ class InfoExecutionContext(default.DefaultExecutionContext):
def create_cursor( self ):
return informix_cursor( self.connection.connection )
-class InfoDialect(ansisql.ANSIDialect):
+class InfoDialect(default.DefaultDialect):
def __init__(self, use_ansi=True,**kwargs):
self.use_ansi = use_ansi
- ansisql.ANSIDialect.__init__(self, **kwargs)
+ default.DefaultDialect.__init__(self, **kwargs)
self.paramstyle = 'qmark'
def dbapi(cls):
@@ -252,18 +253,6 @@ class InfoDialect(ansisql.ANSIDialect):
def oid_column_name(self,column):
return "rowid"
- def preparer(self):
- return InfoIdentifierPreparer(self)
-
- def compiler(self, statement, bindparams, **kwargs):
- return InfoCompiler(self, statement, bindparams, **kwargs)
-
- def schemagenerator(self, *args, **kwargs):
- return InfoSchemaGenerator( self , *args, **kwargs)
-
- def schemadropper(self, *args, **params):
- return InfoSchemaDroper( self , *args , **params)
-
def table_names(self, connection, schema):
s = "select tabname from systables"
return [row[0] for row in connection.execute(s)]
@@ -376,14 +365,14 @@ class InfoDialect(ansisql.ANSIDialect):
for cons_name, cons_type, local_column in rows:
table.primary_key.add( table.c[local_column] )
-class InfoCompiler(ansisql.ANSICompiler):
+class InfoCompiler(compiler.DefaultCompiler):
"""Info compiler modifies the lexical structure of Select statements to work under
non-ANSI configured Oracle databases, if the use_ansi flag is False."""
def __init__(self, dialect, statement, parameters=None, **kwargs):
self.limit = 0
self.offset = 0
- ansisql.ANSICompiler.__init__( self , dialect , statement , parameters , **kwargs )
+ compiler.DefaultCompiler.__init__( self , dialect , statement , parameters , **kwargs )
def default_from(self):
return " from systables where tabname = 'systables' "
@@ -416,7 +405,7 @@ class InfoCompiler(ansisql.ANSICompiler):
if ( __label(c) not in a ) and getattr( c , 'name' , '' ) != 'oid':
select.append_column( c )
- return ansisql.ANSICompiler.visit_select(self, select)
+ return compiler.DefaultCompiler.visit_select(self, select)
def limit_clause(self, select):
return ""
@@ -437,7 +426,7 @@ class InfoCompiler(ansisql.ANSICompiler):
elif func.name.lower() in ( 'current_timestamp' , 'now' ):
return "CURRENT YEAR TO SECOND"
else:
- return ansisql.ANSICompiler.visit_function( self , func )
+ return compiler.DefaultCompiler.visit_function( self , func )
def visit_clauselist(self, list):
try:
@@ -446,7 +435,7 @@ class InfoCompiler(ansisql.ANSICompiler):
li = [ c for c in list.clauses ]
return ', '.join([s for s in [self.process(c) for c in li] if s is not None])
-class InfoSchemaGenerator(ansisql.ANSISchemaGenerator):
+class InfoSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, first_pk=False):
colspec = self.preparer.format_column(column)
if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \
@@ -507,7 +496,7 @@ class InfoSchemaGenerator(ansisql.ANSISchemaGenerator):
return
super(InfoSchemaGenerator, self).visit_index(index)
-class InfoIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class InfoIdentifierPreparer(compiler.IdentifierPreparer):
def __init__(self, dialect):
super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'")
@@ -517,10 +506,14 @@ class InfoIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
def _requires_quotes(self, value):
return False
-class InfoSchemaDroper(ansisql.ANSISchemaDropper):
+class InfoSchemaDroper(compiler.SchemaDropper):
def drop_foreignkey(self, constraint):
if constraint.name is not None:
super( InfoSchemaDroper , self ).drop_foreignkey( constraint )
dialect = InfoDialect
poolclass = pool.SingletonThreadPool
+dialect.statement_compiler = InfoCompiler
+dialect.schemagenerator = InfoSchemaGenerator
+dialect.schemadropper = InfoSchemaDropper
+dialect.preparer = InfoIdentifierPreparer
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
index 619e072d9..0caccca95 100644
--- a/lib/sqlalchemy/databases/mssql.py
+++ b/lib/sqlalchemy/databases/mssql.py
@@ -39,10 +39,10 @@ Known issues / TODO:
import datetime, random, warnings, re
-from sqlalchemy import sql, schema, ansisql, exceptions
-import sqlalchemy.types as sqltypes
-from sqlalchemy.engine import default
-import operator, sys
+from sqlalchemy import util, sql, schema, exceptions
+from sqlalchemy.sql import compiler, expression
+from sqlalchemy.engine import default, base
+from sqlalchemy import types as sqltypes
class MSNumeric(sqltypes.Numeric):
def result_processor(self, dialect):
@@ -366,7 +366,7 @@ class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
super(MSSQLExecutionContext_pyodbc, self).post_exec()
-class MSSQLDialect(ansisql.ANSIDialect):
+class MSSQLDialect(default.DefaultDialect):
colspecs = {
sqltypes.Unicode : MSNVarchar,
sqltypes.Integer : MSInteger,
@@ -476,21 +476,6 @@ class MSSQLDialect(ansisql.ANSIDialect):
def supports_sane_rowcount(self):
raise NotImplementedError()
- def compiler(self, statement, bindparams, **kwargs):
- return MSSQLCompiler(self, statement, bindparams, **kwargs)
-
- def schemagenerator(self, *args, **kwargs):
- return MSSQLSchemaGenerator(self, *args, **kwargs)
-
- def schemadropper(self, *args, **kwargs):
- return MSSQLSchemaDropper(self, *args, **kwargs)
-
- def defaultrunner(self, connection, **kwargs):
- return MSSQLDefaultRunner(connection, **kwargs)
-
- def preparer(self):
- return MSSQLIdentifierPreparer(self)
-
def get_default_schema_name(self, connection):
return self.schema_name
@@ -878,7 +863,7 @@ dialect_mapping = {
}
-class MSSQLCompiler(ansisql.ANSICompiler):
+class MSSQLCompiler(compiler.DefaultCompiler):
def __init__(self, dialect, statement, parameters, **kwargs):
super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs)
self.tablealiases = {}
@@ -931,13 +916,13 @@ class MSSQLCompiler(ansisql.ANSICompiler):
def visit_binary(self, binary):
"""Move bind parameters to the right-hand side of an operator, where possible."""
- if isinstance(binary.left, sql._BindParamClause) and binary.operator == operator.eq:
- return self.process(sql._BinaryExpression(binary.right, binary.left, binary.operator))
+ if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq:
+ return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator))
else:
return super(MSSQLCompiler, self).visit_binary(binary)
def label_select_column(self, select, column):
- if isinstance(column, sql._Function):
+ if isinstance(column, expression._Function):
return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:])
else:
return super(MSSQLCompiler, self).label_select_column(select, column)
@@ -963,7 +948,7 @@ class MSSQLCompiler(ansisql.ANSICompiler):
return ""
-class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
+class MSSQLSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
@@ -986,7 +971,7 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
return colspec
-class MSSQLSchemaDropper(ansisql.ANSISchemaDropper):
+class MSSQLSchemaDropper(compiler.SchemaDropper):
def visit_index(self, index):
self.append("\nDROP INDEX %s.%s" % (
self.preparer.quote_identifier(index.table.name),
@@ -995,11 +980,11 @@ class MSSQLSchemaDropper(ansisql.ANSISchemaDropper):
self.execute()
-class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner):
+class MSSQLDefaultRunner(base.DefaultRunner):
# TODO: does ms-sql have standalone sequences ?
pass
-class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class MSSQLIdentifierPreparer(compiler.IdentifierPreparer):
def __init__(self, dialect):
super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
@@ -1012,6 +997,11 @@ class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
return value
dialect = MSSQLDialect
+dialect.statement_compiler = MSSQLCompiler
+dialect.schemagenerator = MSSQLSchemaGenerator
+dialect.schemadropper = MSSQLSchemaDropper
+dialect.preparer = MSSQLIdentifierPreparer
+dialect.defaultrunner = MSSQLDefaultRunner
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index d5fd3b6c5..41c6ec70f 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -126,11 +126,12 @@ information affecting MySQL in SQLAlchemy.
import re, datetime, inspect, warnings, sys
from array import array as _array
-from sqlalchemy import ansisql, exceptions, logging, schema, sql, util
-from sqlalchemy import operators as sql_operators
+from sqlalchemy import exceptions, logging, schema, sql, util
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy.sql import compiler
from sqlalchemy.engine import base as engine_base, default
-import sqlalchemy.types as sqltypes
+from sqlalchemy import types as sqltypes
__all__ = (
@@ -1328,13 +1329,17 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
return AUTOCOMMIT_RE.match(self.statement)
-class MySQLDialect(ansisql.ANSIDialect):
+class MySQLDialect(default.DefaultDialect):
"""Details of the MySQL dialect. Not used directly in application code."""
def __init__(self, use_ansiquotes=False, **kwargs):
self.use_ansiquotes = use_ansiquotes
kwargs.setdefault('default_paramstyle', 'format')
- ansisql.ANSIDialect.__init__(self, **kwargs)
+ if self.use_ansiquotes:
+ self.preparer = MySQLANSIIdentifierPreparer
+ else:
+ self.preparer = MySQLIdentifierPreparer
+ default.DefaultDialect.__init__(self, **kwargs)
def dbapi(cls):
import MySQLdb as mysql
@@ -1393,7 +1398,7 @@ class MySQLDialect(ansisql.ANSIDialect):
return True
def compiler(self, statement, bindparams, **kwargs):
- return MySQLCompiler(self, statement, bindparams, **kwargs)
+ return MySQLCompiler(statement, bindparams, dialect=self, **kwargs)
def schemagenerator(self, *args, **kwargs):
return MySQLSchemaGenerator(self, *args, **kwargs)
@@ -1401,12 +1406,6 @@ class MySQLDialect(ansisql.ANSIDialect):
def schemadropper(self, *args, **kwargs):
return MySQLSchemaDropper(self, *args, **kwargs)
- def preparer(self):
- if self.use_ansiquotes:
- return MySQLANSIIdentifierPreparer(self)
- else:
- return MySQLIdentifierPreparer(self)
-
def do_executemany(self, cursor, statement, parameters,
context=None, **kwargs):
rowcount = cursor.executemany(statement, parameters)
@@ -1733,8 +1732,8 @@ class _MySQLPythonRowProxy(object):
return item
-class MySQLCompiler(ansisql.ANSICompiler):
- operators = ansisql.ANSICompiler.operators.copy()
+class MySQLCompiler(compiler.DefaultCompiler):
+ operators = compiler.DefaultCompiler.operators.copy()
operators.update(
{
sql_operators.concat_op: \
@@ -1783,7 +1782,7 @@ class MySQLCompiler(ansisql.ANSICompiler):
# In older versions, the indexes must be created explicitly or the
# creation of foreign key constraints fails."
-class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
+class MySQLSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, override_pk=False,
first_pk=False):
"""Builds column DDL."""
@@ -1827,7 +1826,7 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
return ' '.join(table_opts)
-class MySQLSchemaDropper(ansisql.ANSISchemaDropper):
+class MySQLSchemaDropper(compiler.SchemaDropper):
def visit_index(self, index):
self.append("\nDROP INDEX %s ON %s" %
(self.preparer.format_index(index),
@@ -2368,7 +2367,7 @@ class MySQLSchemaReflector(object):
MySQLSchemaReflector.logger = logging.class_logger(MySQLSchemaReflector)
-class _MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class _MySQLIdentifierPreparer(compiler.IdentifierPreparer):
"""MySQL-specific schema identifier configuration."""
def __init__(self, dialect, **kw):
@@ -2433,3 +2432,6 @@ def _re_compile(regex):
return re.compile(regex, re.I | re.UNICODE)
dialect = MySQLDialect
+dialect.statement_compiler = MySQLCompiler
+dialect.schemagenerator = MySQLSchemaGenerator
+dialect.schemadropper = MySQLSchemaDropper
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py
index a35db1982..2d8f2940f 100644
--- a/lib/sqlalchemy/databases/oracle.py
+++ b/lib/sqlalchemy/databases/oracle.py
@@ -5,11 +5,13 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import re, warnings, operator, random
+import re, warnings, random
-from sqlalchemy import util, sql, schema, ansisql, exceptions, logging
+from sqlalchemy import util, sql, schema, exceptions, logging
from sqlalchemy.engine import default, base
-import sqlalchemy.types as sqltypes
+from sqlalchemy.sql import compiler, visitors
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy import types as sqltypes
import datetime
@@ -229,9 +231,9 @@ class OracleExecutionContext(default.DefaultExecutionContext):
return base.ResultProxy(self)
-class OracleDialect(ansisql.ANSIDialect):
+class OracleDialect(default.DefaultDialect):
def __init__(self, use_ansi=True, auto_setinputsizes=True, auto_convert_lobs=True, threaded=True, allow_twophase=True, **kwargs):
- ansisql.ANSIDialect.__init__(self, default_paramstyle='named', **kwargs)
+ default.DefaultDialect.__init__(self, default_paramstyle='named', **kwargs)
self.use_ansi = use_ansi
self.threaded = threaded
self.allow_twophase = allow_twophase
@@ -333,21 +335,6 @@ class OracleDialect(ansisql.ANSIDialect):
def create_execution_context(self, *args, **kwargs):
return OracleExecutionContext(self, *args, **kwargs)
- def compiler(self, statement, bindparams, **kwargs):
- return OracleCompiler(self, statement, bindparams, **kwargs)
-
- def preparer(self):
- return OracleIdentifierPreparer(self)
-
- def schemagenerator(self, *args, **kwargs):
- return OracleSchemaGenerator(self, *args, **kwargs)
-
- def schemadropper(self, *args, **kwargs):
- return OracleSchemaDropper(self, *args, **kwargs)
-
- def defaultrunner(self, connection, **kwargs):
- return OracleDefaultRunner(connection, **kwargs)
-
def has_table(self, connection, table_name, schema=None):
cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':self._denormalize_name(table_name)})
return bool( cursor.fetchone() is not None )
@@ -560,16 +547,16 @@ class _OuterJoinColumn(sql.ClauseElement):
def __init__(self, column):
self.column = column
-class OracleCompiler(ansisql.ANSICompiler):
+class OracleCompiler(compiler.DefaultCompiler):
"""Oracle compiler modifies the lexical structure of Select
statements to work under non-ANSI configured Oracle databases, if
the use_ansi flag is False.
"""
- operators = ansisql.ANSICompiler.operators.copy()
+ operators = compiler.DefaultCompiler.operators.copy()
operators.update(
{
- operator.mod : lambda x, y:"mod(%s, %s)" % (x, y)
+ sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y)
}
)
@@ -590,13 +577,13 @@ class OracleCompiler(ansisql.ANSICompiler):
def visit_join(self, join, **kwargs):
if self.dialect.use_ansi:
- return ansisql.ANSICompiler.visit_join(self, join, **kwargs)
+ return compiler.DefaultCompiler.visit_join(self, join, **kwargs)
(where, parentjoin) = self.__wheres.get(join, (None, None))
- class VisitOn(sql.ClauseVisitor):
+ class VisitOn(visitors.ClauseVisitor):
def visit_binary(s, binary):
- if binary.operator == operator.eq:
+ if binary.operator == sql_operators.eq:
if binary.left.table is join.right:
binary.left = _OuterJoinColumn(binary.left)
elif binary.right.table is join.right:
@@ -640,7 +627,7 @@ class OracleCompiler(ansisql.ANSICompiler):
for c in insert.table.primary_key:
if c.key not in self.parameters:
self.parameters[c.key] = None
- return ansisql.ANSICompiler.visit_insert(self, insert)
+ return compiler.DefaultCompiler.visit_insert(self, insert)
def _TODO_visit_compound_select(self, select):
"""Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
@@ -672,7 +659,7 @@ class OracleCompiler(ansisql.ANSICompiler):
limitselect.append_whereclause("ora_rn<=%d" % select._limit)
return self.process(limitselect, **kwargs)
else:
- return ansisql.ANSICompiler.visit_select(self, select, **kwargs)
+ return compiler.DefaultCompiler.visit_select(self, select, **kwargs)
def limit_clause(self, select):
return ""
@@ -684,7 +671,7 @@ class OracleCompiler(ansisql.ANSICompiler):
return super(OracleCompiler, self).for_update_clause(select)
-class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
+class OracleSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column)
colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
@@ -701,13 +688,13 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
-class OracleSchemaDropper(ansisql.ANSISchemaDropper):
+class OracleSchemaDropper(compiler.SchemaDropper):
def visit_sequence(self, sequence):
if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name):
self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
-class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
+class OracleDefaultRunner(base.DefaultRunner):
def exec_default_sql(self, default):
c = sql.select([default.arg], from_obj=["DUAL"]).compile(bind=self.connection)
return self.connection.execute(c).scalar()
@@ -715,10 +702,15 @@ class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
def visit_sequence(self, seq):
return self.connection.execute("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL").scalar()
-class OracleIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class OracleIdentifierPreparer(compiler.IdentifierPreparer):
def format_savepoint(self, savepoint):
name = re.sub(r'^_+', '', savepoint.ident)
return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
dialect = OracleDialect
+dialect.statement_compiler = OracleCompiler
+dialect.schemagenerator = OracleSchemaGenerator
+dialect.schemadropper = OracleSchemaDropper
+dialect.preparer = OracleIdentifierPreparer
+dialect.defaultrunner = OracleDefaultRunner
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index 74a3ef13f..29d84ad4d 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -4,11 +4,13 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import re, random, warnings, operator
+import re, random, warnings
-from sqlalchemy import sql, schema, ansisql, exceptions, util
+from sqlalchemy import sql, schema, exceptions, util
from sqlalchemy.engine import base, default
-import sqlalchemy.types as sqltypes
+from sqlalchemy.sql import compiler
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy import types as sqltypes
class PGInet(sqltypes.TypeEngine):
@@ -220,9 +222,9 @@ class PGExecutionContext(default.DefaultExecutionContext):
self._last_inserted_ids = [v for v in row]
super(PGExecutionContext, self).post_exec()
-class PGDialect(ansisql.ANSIDialect):
+class PGDialect(default.DefaultDialect):
def __init__(self, use_oids=False, server_side_cursors=False, **kwargs):
- ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs)
+ default.DefaultDialect.__init__(self, default_paramstyle='pyformat', **kwargs)
self.use_oids = use_oids
self.server_side_cursors = server_side_cursors
self.paramstyle = 'pyformat'
@@ -249,15 +251,6 @@ class PGDialect(ansisql.ANSIDialect):
def type_descriptor(self, typeobj):
return sqltypes.adapt_type(typeobj, colspecs)
- def compiler(self, statement, bindparams, **kwargs):
- return PGCompiler(self, statement, bindparams, **kwargs)
-
- def schemagenerator(self, *args, **kwargs):
- return PGSchemaGenerator(self, *args, **kwargs)
-
- def schemadropper(self, *args, **kwargs):
- return PGSchemaDropper(self, *args, **kwargs)
-
def do_begin_twophase(self, connection, xid):
self.do_begin(connection.connection)
@@ -286,12 +279,6 @@ class PGDialect(ansisql.ANSIDialect):
resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts"))
return [row[0] for row in resultset]
- def defaultrunner(self, context, **kwargs):
- return PGDefaultRunner(context, **kwargs)
-
- def preparer(self):
- return PGIdentifierPreparer(self)
-
def get_default_schema_name(self, connection):
if not hasattr(self, '_default_schema_name'):
self._default_schema_name = connection.scalar("select current_schema()", None)
@@ -556,11 +543,11 @@ class PGDialect(ansisql.ANSIDialect):
-class PGCompiler(ansisql.ANSICompiler):
- operators = ansisql.ANSICompiler.operators.copy()
+class PGCompiler(compiler.DefaultCompiler):
+ operators = compiler.DefaultCompiler.operators.copy()
operators.update(
{
- operator.mod : '%%'
+ sql_operators.mod : '%%'
}
)
@@ -597,7 +584,7 @@ class PGCompiler(ansisql.ANSICompiler):
else:
return super(PGCompiler, self).for_update_clause(select)
-class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
+class PGSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column)
if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
@@ -620,13 +607,13 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
-class PGSchemaDropper(ansisql.ANSISchemaDropper):
+class PGSchemaDropper(compiler.SchemaDropper):
def visit_sequence(self, sequence):
if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)):
self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
-class PGDefaultRunner(ansisql.ANSIDefaultRunner):
+class PGDefaultRunner(base.DefaultRunner):
def get_column_default(self, column, isinsert=True):
if column.primary_key:
# passive defaults on primary keys have to be overridden
@@ -642,7 +629,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
return self.connection.execute(exc).scalar()
- return super(ansisql.ANSIDefaultRunner, self).get_column_default(column)
+ return super(PGDefaultRunner, self).get_column_default(column)
def visit_sequence(self, seq):
if not seq.optional:
@@ -650,7 +637,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
else:
return None
-class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class PGIdentifierPreparer(compiler.IdentifierPreparer):
def _fold_identifier_case(self, value):
return value.lower()
@@ -660,3 +647,8 @@ class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
return value
dialect = PGDialect
+dialect.statement_compiler = PGCompiler
+dialect.schemagenerator = PGSchemaGenerator
+dialect.schemadropper = PGSchemaDropper
+dialect.preparer = PGIdentifierPreparer
+dialect.defaultrunner = PGDefaultRunner
diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py
index d96422236..c2aced4d0 100644
--- a/lib/sqlalchemy/databases/sqlite.py
+++ b/lib/sqlalchemy/databases/sqlite.py
@@ -7,11 +7,12 @@
import re
-from sqlalchemy import schema, ansisql, exceptions, pool, PassiveDefault
-import sqlalchemy.engine.default as default
+from sqlalchemy import schema, exceptions, pool, PassiveDefault
+from sqlalchemy.engine import default
import sqlalchemy.types as sqltypes
import datetime,time, warnings
import sqlalchemy.util as util
+from sqlalchemy.sql import compiler
SELECT_REGEXP = re.compile(r'\s*(?:SELECT|PRAGMA)', re.I | re.UNICODE)
@@ -172,10 +173,10 @@ class SQLiteExecutionContext(default.DefaultExecutionContext):
def is_select(self):
return SELECT_REGEXP.match(self.statement)
-class SQLiteDialect(ansisql.ANSIDialect):
+class SQLiteDialect(default.DefaultDialect):
def __init__(self, **kwargs):
- ansisql.ANSIDialect.__init__(self, default_paramstyle='qmark', **kwargs)
+ default.DefaultDialect.__init__(self, default_paramstyle='qmark', **kwargs)
def vers(num):
return tuple([int(x) for x in num.split('.')])
if self.dbapi is not None:
@@ -195,24 +196,12 @@ class SQLiteDialect(ansisql.ANSIDialect):
return sqlite
dbapi = classmethod(dbapi)
- def compiler(self, statement, bindparams, **kwargs):
- return SQLiteCompiler(self, statement, bindparams, **kwargs)
-
- def schemagenerator(self, *args, **kwargs):
- return SQLiteSchemaGenerator(self, *args, **kwargs)
-
- def schemadropper(self, *args, **kwargs):
- return SQLiteSchemaDropper(self, *args, **kwargs)
-
def server_version_info(self, connection):
return self.dbapi.sqlite_version_info
def supports_alter(self):
return False
- def preparer(self):
- return SQLiteIdentifierPreparer(self)
-
def create_connect_args(self, url):
filename = url.database or ':memory:'
@@ -255,7 +244,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
return (row is not None)
def reflecttable(self, connection, table, include_columns):
- c = connection.execute("PRAGMA table_info(%s)" % self.preparer().format_table(table), {})
+ c = connection.execute("PRAGMA table_info(%s)" % self.identifier_preparer.format_table(table), {})
found_table = False
while True:
row = c.fetchone()
@@ -295,7 +284,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
if not found_table:
raise exceptions.NoSuchTableError(table.name)
- c = connection.execute("PRAGMA foreign_key_list(%s)" % self.preparer().format_table(table), {})
+ c = connection.execute("PRAGMA foreign_key_list(%s)" % self.identifier_preparer.format_table(table), {})
fks = {}
while True:
row = c.fetchone()
@@ -324,7 +313,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
for name, value in fks.iteritems():
table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1]))
# check for UNIQUE indexes
- c = connection.execute("PRAGMA index_list(%s)" % self.preparer().format_table(table), {})
+ c = connection.execute("PRAGMA index_list(%s)" % self.identifier_preparer.format_table(table), {})
unique_indexes = []
while True:
row = c.fetchone()
@@ -343,7 +332,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
cols.append(row[2])
col = table.columns[row[2]]
-class SQLiteCompiler(ansisql.ANSICompiler):
+class SQLiteCompiler(compiler.DefaultCompiler):
def visit_cast(self, cast):
if self.dialect.supports_cast:
return super(SQLiteCompiler, self).visit_cast(cast)
@@ -369,7 +358,8 @@ class SQLiteCompiler(ansisql.ANSICompiler):
# sqlite has no "FOR UPDATE" AFAICT
return ''
-class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
+
+class SQLiteSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
@@ -391,12 +381,17 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
# else:
# super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint)
-class SQLiteSchemaDropper(ansisql.ANSISchemaDropper):
+class SQLiteSchemaDropper(compiler.SchemaDropper):
pass
-class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
def __init__(self, dialect):
super(SQLiteIdentifierPreparer, self).__init__(dialect, omit_schema=True)
dialect = SQLiteDialect
dialect.poolclass = pool.SingletonThreadPool
+dialect.statement_compiler = SQLiteCompiler
+dialect.schemagenerator = SQLiteSchemaGenerator
+dialect.schemadropper = SQLiteSchemaDropper
+dialect.preparer = SQLiteIdentifierPreparer
+
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index faeb00cc9..553c8df84 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -8,7 +8,8 @@
higher-level statement-construction, connection-management,
execution and result contexts."""
-from sqlalchemy import exceptions, sql, schema, util, types, logging
+from sqlalchemy import exceptions, schema, util, types, logging
+from sqlalchemy.sql import expression, visitors
import StringIO, sys
@@ -35,6 +36,22 @@ class Dialect(object):
encoding
type of encoding to use for unicode, usually defaults to 'utf-8'
+
+ schemagenerator
+ a [sqlalchemy.schema#SchemaVisitor] class which generates schemas.
+
+ schemadropper
+ a [sqlalchemy.schema#SchemaVisitor] class which drops schemas.
+
+ defaultrunner
+ a [sqlalchemy.schema#SchemaVisitor] class which executes defaults.
+
+ statement_compiler
+ a [sqlalchemy.engine.base#Compiled] class used to compile SQL statements
+
+ preparer
+ a [sqlalchemy.sql.compiler#IdentifierPreparer] class used to quote
+ identifiers.
"""
def create_connect_args(self, url):
@@ -105,48 +122,6 @@ class Dialect(object):
raise NotImplementedError()
- def schemagenerator(self, connection, **kwargs):
- """Return a [sqlalchemy.schema#SchemaVisitor] instance that can generate schemas.
-
- connection
- a [sqlalchemy.engine#Connection] to use for statement execution
-
- `schemagenerator()` is called via the `create()` method on Table,
- Index, and others.
- """
-
- raise NotImplementedError()
-
- def schemadropper(self, connection, **kwargs):
- """Return a [sqlalchemy.schema#SchemaVisitor] instance that can drop schemas.
-
- connection
- a [sqlalchemy.engine#Connection] to use for statement execution
-
- `schemadropper()` is called via the `drop()` method on Table,
- Index, and others.
- """
-
- raise NotImplementedError()
-
- def defaultrunner(self, execution_context):
- """Return a [sqlalchemy.schema#SchemaVisitor] instance that can execute defaults.
-
- execution_context
- a [sqlalchemy.engine#ExecutionContext] to use for statement execution
-
- """
-
- raise NotImplementedError()
-
- def compiler(self, statement, parameters):
- """Return a [sqlalchemy.sql#Compiled] object for the given statement/parameters.
-
- The returned object is usually a subclass of [sqlalchemy.ansisql#ANSICompiler].
-
- """
-
- raise NotImplementedError()
def server_version_info(self, connection):
"""Return a tuple of the database's version number."""
@@ -266,16 +241,6 @@ class Dialect(object):
raise NotImplementedError()
-
- def compile(self, clauseelement, parameters=None):
- """Compile the given [sqlalchemy.sql#ClauseElement] using this Dialect.
-
- Returns [sqlalchemy.sql#Compiled]. A convenience method which
- flips around the compile() call on ``ClauseElement``.
- """
-
- return clauseelement.compile(dialect=self, parameters=parameters)
-
def is_disconnect(self, e):
"""Return True if the given DBAPI error indicates an invalid connection"""
@@ -304,7 +269,7 @@ class ExecutionContext(object):
DBAPI cursor procured from the connection
compiled
- if passed to constructor, sql.Compiled object being executed
+ if passed to constructor, sqlalchemy.engine.base.Compiled object being executed
statement
string version of the statement to be executed. Is either
@@ -439,6 +404,9 @@ class Compiled(object):
def __init__(self, dialect, statement, parameters, bind=None):
"""Construct a new ``Compiled`` object.
+ dialect
+ ``Dialect`` to compile against.
+
statement
``ClauseElement`` to be compiled.
@@ -724,8 +692,8 @@ class Connection(Connectable):
def scalar(self, object, *multiparams, **params):
return self.execute(object, *multiparams, **params).scalar()
- def compiler(self, statement, parameters, **kwargs):
- return self.dialect.compiler(statement, parameters, bind=self, **kwargs)
+ def statement_compiler(self, statement, parameters, **kwargs):
+ return self.dialect.statement_compiler(self.dialect, statement, parameters, bind=self, **kwargs)
def execute(self, object, *multiparams, **params):
for c in type(object).__mro__:
@@ -822,9 +790,9 @@ class Connection(Connectable):
# poor man's multimethod/generic function thingy
executors = {
- sql._Function : _execute_function,
- sql.ClauseElement : _execute_clauseelement,
- sql.ClauseVisitor : _execute_compiled,
+ expression._Function : _execute_function,
+ expression.ClauseElement : _execute_clauseelement,
+ visitors.ClauseVisitor : _execute_compiled,
schema.SchemaItem:_execute_default,
str.__mro__[-2] : _execute_text
}
@@ -989,14 +957,14 @@ class Engine(Connectable):
connection.close()
def _func(self):
- return sql._FunctionGenerator(bind=self)
+ return expression._FunctionGenerator(bind=self)
func = property(_func)
def text(self, text, *args, **kwargs):
"""Return a sql.text() object for performing literal queries."""
- return sql.text(text, bind=self, *args, **kwargs)
+ return expression.text(text, bind=self, *args, **kwargs)
def _run_visitor(self, visitorcallable, element, connection=None, **kwargs):
if connection is None:
@@ -1004,7 +972,7 @@ class Engine(Connectable):
else:
conn = connection
try:
- visitorcallable(conn, **kwargs).traverse(element)
+ visitorcallable(self.dialect, conn, **kwargs).traverse(element)
finally:
if connection is None:
conn.close()
@@ -1057,8 +1025,8 @@ class Engine(Connectable):
connection = self.contextual_connect(close_with_result=True)
return connection._execute_compiled(compiled, multiparams, params)
- def compiler(self, statement, parameters, **kwargs):
- return self.dialect.compiler(statement, parameters, bind=self, **kwargs)
+ def statement_compiler(self, statement, parameters, **kwargs):
+ return self.dialect.statement_compiler(self.dialect, statement, parameters, bind=self, **kwargs)
def connect(self, **kwargs):
"""Return a newly allocated Connection object."""
@@ -1159,6 +1127,7 @@ class ResultProxy(object):
self.closed = False
self.cursor = context.cursor
self.__echo = logging.is_debug_enabled(context.engine.logger)
+ self._process_row = self._row_processor()
if context.is_select():
self._init_metadata()
self._rowcount = None
@@ -1222,7 +1191,7 @@ class ResultProxy(object):
rec = props[key]
elif isinstance(key, basestring) and key.lower() in props:
rec = props[key.lower()]
- elif isinstance(key, sql.ColumnElement):
+ elif isinstance(key, expression.ColumnElement):
label = context.column_labels.get(key._label, key.name).lower()
if label in props:
rec = props[label]
@@ -1320,21 +1289,21 @@ class ResultProxy(object):
return self.cursor.fetchmany(size)
def _fetchall_impl(self):
return self.cursor.fetchall()
+
+ def _row_processor(self):
+ return RowProxy
- def _process_row(self, row):
- return RowProxy(self, row)
-
def fetchall(self):
"""Fetch all rows, just like DBAPI ``cursor.fetchall()``."""
- l = [self._process_row(row) for row in self._fetchall_impl()]
+ l = [self._process_row(self, row) for row in self._fetchall_impl()]
self.close()
return l
def fetchmany(self, size=None):
"""Fetch many rows, just like DBAPI ``cursor.fetchmany(size=cursor.arraysize)``."""
- l = [self._process_row(row) for row in self._fetchmany_impl(size)]
+ l = [self._process_row(self, row) for row in self._fetchmany_impl(size)]
if len(l) == 0:
self.close()
return l
@@ -1343,7 +1312,7 @@ class ResultProxy(object):
"""Fetch one row, just like DBAPI ``cursor.fetchone()``."""
row = self._fetchone_impl()
if row is not None:
- return self._process_row(row)
+ return self._process_row(self, row)
else:
self.close()
return None
@@ -1353,7 +1322,7 @@ class ResultProxy(object):
row = self._fetchone_impl()
try:
if row is not None:
- return self._process_row(row)[0]
+ return self._process_row(self, row)[0]
else:
return None
finally:
@@ -1425,11 +1394,9 @@ class BufferedColumnResultProxy(ResultProxy):
def _get_col(self, row, key):
rec = self._key_cache[key]
return row[rec[2]]
-
- def _process_row(self, row):
- sup = super(BufferedColumnResultProxy, self)
- row = [sup._get_col(row, i) for i in xrange(len(row))]
- return RowProxy(self, row)
+
+ def _row_processor(self):
+ return BufferedColumnRow
def fetchall(self):
l = []
@@ -1523,6 +1490,11 @@ class RowProxy(object):
def __len__(self):
return len(self.__row)
+class BufferedColumnRow(RowProxy):
+ def __init__(self, parent, row):
+ row = [ResultProxy._get_col(parent, row, i) for i in xrange(len(row))]
+ super(BufferedColumnRow, self).__init__(parent, row)
+
class SchemaIterator(schema.SchemaVisitor):
"""A visitor that can gather text into a buffer and execute the contents of the buffer."""
@@ -1590,11 +1562,11 @@ class DefaultRunner(schema.SchemaVisitor):
return None
def exec_default_sql(self, default):
- c = sql.select([default.arg]).compile(bind=self.connection)
+ c = expression.select([default.arg]).compile(bind=self.connection)
return self.connection._execute_compiled(c).scalar()
def visit_column_onupdate(self, onupdate):
- if isinstance(onupdate.arg, sql.ClauseElement):
+ if isinstance(onupdate.arg, expression.ClauseElement):
return self.exec_default_sql(onupdate)
elif callable(onupdate.arg):
return onupdate.arg(self.context)
@@ -1602,7 +1574,7 @@ class DefaultRunner(schema.SchemaVisitor):
return onupdate.arg
def visit_column_default(self, default):
- if isinstance(default.arg, sql.ClauseElement):
+ if isinstance(default.arg, expression.ClauseElement):
return self.exec_default_sql(default)
elif callable(default.arg):
return default.arg(self.context)
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index ccaf080e7..059395921 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -9,7 +9,7 @@
from sqlalchemy import schema, exceptions, sql, util
import re, random
from sqlalchemy.engine import base
-
+from sqlalchemy.sql import compiler, expression
AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
re.I | re.UNICODE)
@@ -18,6 +18,12 @@ SELECT_REGEXP = re.compile(r'\s*SELECT', re.I | re.UNICODE)
class DefaultDialect(base.Dialect):
"""Default implementation of Dialect"""
+ schemagenerator = compiler.SchemaGenerator
+ schemadropper = compiler.SchemaDropper
+ statement_compiler = compiler.DefaultCompiler
+ preparer = compiler.IdentifierPreparer
+ defaultrunner = base.DefaultRunner
+
def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', paramstyle=None, dbapi=None, **kwargs):
self.convert_unicode = convert_unicode
self.encoding = encoding
@@ -25,6 +31,7 @@ class DefaultDialect(base.Dialect):
self._ischema = None
self.dbapi = dbapi
self._figure_paramstyle(paramstyle=paramstyle, default=default_paramstyle)
+ self.identifier_preparer = self.preparer(self)
def dbapi_type_map(self):
# most DBAPIs have problems with this (such as, psycocpg2 types
@@ -46,6 +53,7 @@ class DefaultDialect(base.Dialect):
typeobj = typeobj()
return typeobj
+
def supports_unicode_statements(self):
"""indicate whether the DBAPI can receive SQL statements as Python unicode strings"""
return False
@@ -96,13 +104,13 @@ class DefaultDialect(base.Dialect):
return "_sa_%032x" % random.randint(0,2**128)
def do_savepoint(self, connection, name):
- connection.execute(sql.SavepointClause(name))
+ connection.execute(expression.SavepointClause(name))
def do_rollback_to_savepoint(self, connection, name):
- connection.execute(sql.RollbackToSavepointClause(name))
+ connection.execute(expression.RollbackToSavepointClause(name))
def do_release_savepoint(self, connection, name):
- connection.execute(sql.ReleaseSavepointClause(name))
+ connection.execute(expression.ReleaseSavepointClause(name))
def do_executemany(self, cursor, statement, parameters, **kwargs):
cursor.executemany(statement, parameters)
@@ -110,8 +118,6 @@ class DefaultDialect(base.Dialect):
def do_execute(self, cursor, statement, parameters, **kwargs):
cursor.execute(statement, parameters)
- def defaultrunner(self, context):
- return base.DefaultRunner(context)
def is_disconnect(self, e):
return False
diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py
index 65d2ab3d2..258bddb4a 100644
--- a/lib/sqlalchemy/ext/sqlsoup.py
+++ b/lib/sqlalchemy/ext/sqlsoup.py
@@ -294,7 +294,7 @@ from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.ext.sessioncontext import SessionContext
from sqlalchemy.exceptions import *
-
+from sqlalchemy.sql import expression
_testsql = """
CREATE TABLE books (
@@ -415,7 +415,7 @@ def _selectable_name(selectable):
return x
def class_for_table(selectable, **mapper_kwargs):
- selectable = sql._selectable(selectable)
+ selectable = expression._selectable(selectable)
mapname = 'Mapped' + _selectable_name(selectable)
if isinstance(selectable, Table):
klass = TableClassType(mapname, (object,), {})
@@ -499,7 +499,7 @@ class SqlSoup:
def with_labels(self, item):
# TODO give meaningful aliases
- return self.map(sql._selectable(item).select(use_labels=True).alias('foo'))
+ return self.map(expression._selectable(item).select(use_labels=True).alias('foo'))
def join(self, *args, **kwargs):
j = join(*args, **kwargs)
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index c54eee438..9000a8df5 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -6,6 +6,7 @@
from sqlalchemy import util, logging, sql
+from sqlalchemy.sql import expression
__all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension',
'MapperProperty', 'PropComparator', 'StrategizedProperty',
@@ -363,7 +364,7 @@ class MapperProperty(object):
return operator(self.comparator, value)
-class PropComparator(sql.ColumnOperators):
+class PropComparator(expression.ColumnOperators):
"""defines comparison operations for MapperProperty objects"""
def expression_element(self):
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index b48368417..60d4526ec 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -6,7 +6,8 @@
import weakref, warnings, operator
from sqlalchemy import sql, util, exceptions, logging
-from sqlalchemy import sql_util as sqlutil
+from sqlalchemy.sql import expression
+from sqlalchemy.sql import util as sqlutil
from sqlalchemy.orm import util as mapperutil
from sqlalchemy.orm.util import ExtensionCarrier
from sqlalchemy.orm import sync
@@ -77,7 +78,7 @@ class Mapper(object):
raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__)
for table in (local_table, select_table):
- if table is not None and isinstance(table, sql._SelectBaseMixin):
+ if table is not None and isinstance(table, expression._SelectBaseMixin):
# some db's, noteably postgres, dont want to select from a select
# without an alias. also if we make our own alias internally, then
# the configured properties on the mapper are not matched against the alias
@@ -438,7 +439,7 @@ class Mapper(object):
# against the "mapped_table" of this mapper.
equivalent_columns = self._get_equivalent_columns()
- primary_key = sql.ColumnSet()
+ primary_key = expression.ColumnSet()
for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]):
c = self.mapped_table.corresponding_column(col, raiseerr=False)
@@ -644,9 +645,9 @@ class Mapper(object):
props = {}
if self.properties is not None:
for key, prop in self.properties.iteritems():
- if sql.is_column(prop):
+ if expression.is_column(prop):
props[key] = self.select_table.corresponding_column(prop)
- elif (isinstance(prop, list) and sql.is_column(prop[0])):
+ elif (isinstance(prop, list) and expression.is_column(prop[0])):
props[key] = [self.select_table.corresponding_column(c) for c in prop]
self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, properties=props, _polymorphic_map=self.polymorphic_map, polymorphic_on=self.select_table.corresponding_column(self.polymorphic_on), primary_key=self.primary_key_argument)
@@ -768,7 +769,7 @@ class Mapper(object):
def _create_prop_from_column(self, column):
column = util.to_list(column)
- if not sql.is_column(column[0]):
+ if not expression.is_column(column[0]):
return None
mapped_column = []
for c in column:
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 670fcccc9..20cbcb235 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -11,7 +11,8 @@ operations. PropertyLoader also relies upon the dependency.py module
to handle flush-time dependency sorting and processing.
"""
-from sqlalchemy import sql, schema, util, exceptions, sql_util, logging
+from sqlalchemy import sql, schema, util, exceptions, logging
+from sqlalchemy.sql import util as sql_util
from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency
from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 44329468f..5cbe19ce2 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -4,7 +4,9 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import sql, util, exceptions, sql_util, logging
+from sqlalchemy import sql, util, exceptions, logging
+from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql import expression, visitors
from sqlalchemy.orm import mapper, object_mapper
from sqlalchemy.orm import util as mapperutil
from sqlalchemy.orm.interfaces import OperationContext, LoaderStack
@@ -312,7 +314,7 @@ class Query(object):
clause = self._from_obj[-1]
currenttables = [clause]
- class FindJoinedTables(sql.NoColumnVisitor):
+ class FindJoinedTables(visitors.NoColumnVisitor):
def visit_join(self, join):
currenttables.append(join.left)
currenttables.append(join.right)
@@ -836,7 +838,7 @@ class Query(object):
# if theres an order by, add those columns to the column list
# of the "rowcount" query we're going to make
if order_by:
- order_by = [sql._literal_as_text(o) for o in util.to_list(order_by) or []]
+ order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []]
cf = sql_util.ColumnFinder()
for o in order_by:
cf.traverse(o)
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 6565c8d77..bdb17e1d6 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -6,7 +6,9 @@
"""sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions."""
-from sqlalchemy import sql, util, exceptions, sql_util, logging
+from sqlalchemy import sql, util, exceptions, logging
+from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql import visitors
from sqlalchemy.orm import mapper, attributes
from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption
from sqlalchemy.orm import session as sessionlib
@@ -292,7 +294,7 @@ class LazyLoader(AbstractRelationLoader):
(criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
- class Visitor(sql.ClauseVisitor):
+ class Visitor(visitors.ClauseVisitor):
def visit_bindparam(s, bindparam):
mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
if bindparam.key in bind_to_col:
@@ -396,7 +398,7 @@ class LazyLoader(AbstractRelationLoader):
if not isinstance(expr, sql.ColumnElement):
return None
columns = []
- class FindColumnInColumnClause(sql.ClauseVisitor):
+ class FindColumnInColumnClause(visitors.ClauseVisitor):
def visit_column(self, c):
columns.append(c)
FindColumnInColumnClause().traverse(expr)
diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py
index cf48202b0..49661a95e 100644
--- a/lib/sqlalchemy/orm/sync.py
+++ b/lib/sqlalchemy/orm/sync.py
@@ -10,9 +10,9 @@ clause that compares column values.
"""
from sqlalchemy import sql, schema, exceptions
+from sqlalchemy.sql import visitors, operators
from sqlalchemy import logging
from sqlalchemy.orm import util as mapperutil
-import operator
ONETOMANY = 0
MANYTOONE = 1
@@ -43,7 +43,7 @@ class ClauseSynchronizer(object):
def compile_binary(binary):
"""Assemble a SyncRule given a single binary condition."""
- if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+ if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
source_column = None
@@ -144,7 +144,7 @@ class SyncRule(object):
SyncRule.logger = logging.class_logger(SyncRule)
-class BinaryVisitor(sql.ClauseVisitor):
+class BinaryVisitor(visitors.ClauseVisitor):
def __init__(self, func):
self.func = func
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 6b0956dc4..b3f58c954 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -4,7 +4,9 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import sql, util, exceptions, sql_util
+from sqlalchemy import sql, util, exceptions
+from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql import visitors
from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE
all_cascades = util.Set(["delete", "delete-orphan", "all", "merge",
@@ -161,7 +163,7 @@ class ExtensionCarrier(MapperExtension):
before_delete = _create_do('before_delete')
after_delete = _create_do('after_delete')
-class BinaryVisitor(sql.ClauseVisitor):
+class BinaryVisitor(visitors.ClauseVisitor):
def __init__(self, func):
self.func = func
@@ -196,7 +198,7 @@ class AliasedClauses(object):
# for column-level subqueries, swap out its selectable with our
# eager version as appropriate, and manually build the
# "correlation" list of the subquery.
- class ModifySubquery(sql.ClauseVisitor):
+ class ModifySubquery(visitors.ClauseVisitor):
def visit_select(s, select):
select._should_correlate = False
select.append_correlation(self.alias)
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 99803d665..99ca2389b 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -18,8 +18,11 @@ objects as well as the visitor interface, so that the schema package
"""
import re, inspect
-from sqlalchemy import sql, types, exceptions, util, databases
+from sqlalchemy import types, exceptions, util, databases
+from sqlalchemy.sql import expression, visitors
import sqlalchemy
+
+
URL = None
__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index',
@@ -31,7 +34,7 @@ __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index',
class SchemaItem(object):
"""Base class for items that define a database schema."""
- __metaclass__ = sql._FigureVisitName
+ __metaclass__ = expression._FigureVisitName
def _init_items(self, *args):
"""Initialize the list of child items for this SchemaItem."""
@@ -84,7 +87,7 @@ def _get_table_key(name, schema):
else:
return schema + "." + name
-class _TableSingleton(sql._FigureVisitName):
+class _TableSingleton(expression._FigureVisitName):
"""A metaclass used by the ``Table`` object to provide singleton behavior."""
def __call__(self, name, metadata, *args, **kwargs):
@@ -124,10 +127,10 @@ class _TableSingleton(sql._FigureVisitName):
return table
-class Table(SchemaItem, sql.TableClause):
+class Table(SchemaItem, expression.TableClause):
"""Represent a relational database table.
- This subclasses ``sql.TableClause`` to provide a table that is
+ This subclasses ``expression.TableClause`` to provide a table that is
associated with an instance of ``MetaData``, which in turn
may be associated with an instance of ``Engine``.
@@ -229,7 +232,7 @@ class Table(SchemaItem, sql.TableClause):
self.schema = kwargs.pop('schema', None)
self.indexes = util.Set()
self.constraints = util.Set()
- self._columns = sql.ColumnCollection()
+ self._columns = expression.ColumnCollection()
self.primary_key = PrimaryKeyConstraint()
self._foreign_keys = util.OrderedSet()
self.quote = kwargs.pop('quote', False)
@@ -291,7 +294,7 @@ class Table(SchemaItem, sql.TableClause):
def get_children(self, column_collections=True, schema_visitor=False, **kwargs):
if not schema_visitor:
- return sql.TableClause.get_children(self, column_collections=column_collections, **kwargs)
+ return expression.TableClause.get_children(self, column_collections=column_collections, **kwargs)
else:
if column_collections:
return [c for c in self.columns]
@@ -338,10 +341,10 @@ class Table(SchemaItem, sql.TableClause):
args.append(c.copy())
return Table(self.name, metadata, schema=schema, *args)
-class Column(SchemaItem, sql._ColumnClause):
+class Column(SchemaItem, expression._ColumnClause):
"""Represent a column in a database table.
- This is a subclass of ``sql.ColumnClause`` and represents an
+ This is a subclass of ``expression.ColumnClause`` and represents an
actual existing table in the database, in a similar fashion as
``TableClause``/``Table``.
"""
@@ -575,7 +578,7 @@ class Column(SchemaItem, sql._ColumnClause):
return [x for x in (self.default, self.onupdate) if x is not None] + \
list(self.foreign_keys) + list(self.constraints)
else:
- return sql._ColumnClause.get_children(self, **kwargs)
+ return expression._ColumnClause.get_children(self, **kwargs)
class ForeignKey(SchemaItem):
@@ -806,7 +809,7 @@ class Constraint(SchemaItem):
def __init__(self, name=None):
self.name = name
- self.columns = sql.ColumnCollection()
+ self.columns = expression.ColumnCollection()
def __contains__(self, x):
return self.columns.contains_column(x)
@@ -1124,12 +1127,12 @@ class MetaData(SchemaItem):
del self.tables[table.key]
def table_iterator(self, reverse=True, tables=None):
- import sqlalchemy.sql_util
+ from sqlalchemy.sql import util as sql_util
if tables is None:
tables = self.tables.values()
else:
tables = util.Set(tables).intersection(self.tables.values())
- sorter = sqlalchemy.sql_util.TableCollection(list(tables))
+ sorter = sql_util.TableCollection(list(tables))
return iter(sorter.sort(reverse=reverse))
def _get_parent(self):
@@ -1356,7 +1359,7 @@ class ThreadLocalMetaData(MetaData):
e.dispose()
-class SchemaVisitor(sql.ClauseVisitor):
+class SchemaVisitor(visitors.ClauseVisitor):
"""Define the visiting for ``SchemaItem`` objects."""
__traverse_options__ = {'schema_visitor':True}
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
new file mode 100644
index 000000000..06b9f1f75
--- /dev/null
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -0,0 +1,3 @@
+from sqlalchemy.sql.expression import *
+from sqlalchemy.sql.visitors import ClauseVisitor, NoColumnVisitor
+
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/sql/compiler.py
index 5f5e1c171..6053c72be 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1,22 +1,18 @@
-# ansisql.py
+# compiler.py
# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""Defines ANSI SQL operations.
+"""SQL expression compilation routines and DDL implementations."""
-Contains default implementations for the abstract objects in the sql
-module.
-"""
+import string, re
+from sqlalchemy import schema, engine, util, exceptions
+from sqlalchemy.sql import operators, visitors
+from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql import expression as sql
-import string, re, sets, operator
-
-from sqlalchemy import schema, sql, engine, util, exceptions, operators
-from sqlalchemy.engine import default
-
-
-ANSI_FUNCS = sets.ImmutableSet([
+ANSI_FUNCS = util.Set([
'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP',
'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP',
'SESSION_USER', 'USER'])
@@ -77,7 +73,6 @@ OPERATORS = {
operators.comma_op : ', ',
operators.desc_op : 'DESC',
operators.asc_op : 'ASC',
-
operators.from_ : 'FROM',
operators.as_ : 'AS',
operators.exists : 'EXISTS',
@@ -85,36 +80,10 @@ OPERATORS = {
operators.isnot : 'IS NOT'
}
-class ANSIDialect(default.DefaultDialect):
- def __init__(self, cache_identifiers=True, **kwargs):
- super(ANSIDialect,self).__init__(**kwargs)
- self.identifier_preparer = self.preparer()
- self.cache_identifiers = cache_identifiers
-
- def create_connect_args(self):
- return ([],{})
-
- def schemagenerator(self, *args, **kwargs):
- return ANSISchemaGenerator(self, *args, **kwargs)
-
- def schemadropper(self, *args, **kwargs):
- return ANSISchemaDropper(self, *args, **kwargs)
-
- def compiler(self, statement, parameters, **kwargs):
- return ANSICompiler(self, statement, parameters, **kwargs)
-
- def preparer(self):
- """Return an IdentifierPreparer.
-
- This object is used to format table and column names including
- proper quoting and case conventions.
- """
- return ANSIIdentifierPreparer(self)
-
-class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
+class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
"""Default implementation of Compiled.
- Compiles ClauseElements into ANSI-compliant SQL strings.
+ Compiles ClauseElements into SQL strings.
"""
__traverse_options__ = {'column_collections':False, 'entry':True}
@@ -122,7 +91,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
operators = OPERATORS
def __init__(self, dialect, statement, parameters=None, **kwargs):
- """Construct a new ``ANSICompiler`` object.
+ """Construct a new ``DefaultCompiler`` object.
dialect
Dialect to be used
@@ -139,7 +108,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
correspond to the keys present in the parameters.
"""
- super(ANSICompiler, self).__init__(dialect, statement, parameters, **kwargs)
+ super(DefaultCompiler, self).__init__(dialect, statement, parameters, **kwargs)
# if we are insert/update. set to true when we visit an INSERT or UPDATE
self.isinsert = self.isupdate = False
@@ -170,17 +139,17 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
self.bindtemplate = ":%s"
# paramstyle from the dialect (comes from DBAPI)
- self.paramstyle = dialect.paramstyle
+ self.paramstyle = self.dialect.paramstyle
# true if the paramstyle is positional
- self.positional = dialect.positional
+ self.positional = self.dialect.positional
# a list of the compiled's bind parameter names, used to help
# formulate a positional argument list
self.positiontup = []
- # an ANSIIdentifierPreparer that formats the quoting of identifiers
- self.preparer = dialect.identifier_preparer
+ # an IdentifierPreparer that formats the quoting of identifiers
+ self.preparer = self.dialect.identifier_preparer
# for UPDATE and INSERT statements, a set of columns whos values are being set
# from a SQL expression (i.e., not one of the bind parameter values). if present,
@@ -244,7 +213,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
return None
def construct_params(self, params):
- """Return a sql.ClauseParameters object.
+ """Return a sql.util.ClauseParameters object.
Combines the given bind parameter dictionary (string keys to object values)
with the _BindParamClause objects stored within this Compiled object
@@ -252,7 +221,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
for a single statement execution, or one element of an executemany execution.
"""
- d = sql.ClauseParameters(self.dialect, self.positiontup)
+ d = sql_util.ClauseParameters(self.dialect, self.positiontup)
pd = self.parameters or {}
pd.update(params)
@@ -781,7 +750,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
def __str__(self):
return self.string
-class ANSISchemaBase(engine.SchemaIterator):
+class DDLBase(engine.SchemaIterator):
def find_alterables(self, tables):
alterables = []
class FindAlterables(schema.SchemaVisitor):
@@ -794,12 +763,12 @@ class ANSISchemaBase(engine.SchemaIterator):
findalterables.traverse(c)
return alterables
-class ANSISchemaGenerator(ANSISchemaBase):
+class SchemaGenerator(DDLBase):
def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
- super(ANSISchemaGenerator, self).__init__(connection, **kwargs)
+ super(SchemaGenerator, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables and util.Set(tables) or None
- self.preparer = dialect.preparer()
+ self.preparer = dialect.identifier_preparer
self.dialect = dialect
def get_column_specification(self, column, first_pk=False):
@@ -860,7 +829,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
def _compile(self, tocompile, parameters):
"""compile the given string/parameters using this SchemaGenerator's dialect."""
- compiler = self.dialect.compiler(tocompile, parameters)
+ compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters)
compiler.compile()
return compiler
@@ -930,12 +899,12 @@ class ANSISchemaGenerator(ANSISchemaBase):
string.join([preparer.format_column(c) for c in index.columns], ', ')))
self.execute()
-class ANSISchemaDropper(ANSISchemaBase):
+class SchemaDropper(DDLBase):
def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
- super(ANSISchemaDropper, self).__init__(connection, **kwargs)
+ super(SchemaDropper, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables
- self.preparer = dialect.preparer()
+ self.preparer = dialect.identifier_preparer
self.dialect = dialect
def visit_metadata(self, metadata):
@@ -964,14 +933,11 @@ class ANSISchemaDropper(ANSISchemaBase):
self.append("\nDROP TABLE " + self.preparer.format_table(table))
self.execute()
-class ANSIDefaultRunner(engine.DefaultRunner):
- pass
-
-class ANSIIdentifierPreparer(object):
+class IdentifierPreparer(object):
"""Handle quoting and case-folding of identifiers based on options."""
def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False):
- """Construct a new ``ANSIIdentifierPreparer`` object.
+ """Construct a new ``IdentifierPreparer`` object.
initial_quote
Character that begins a delimited identifier.
@@ -1049,20 +1015,14 @@ class ANSIIdentifierPreparer(object):
def __generic_obj_format(self, obj, ident):
if getattr(obj, 'quote', False):
return self.quote_identifier(ident)
- if self.dialect.cache_identifiers:
- try:
- return self.__strings[ident]
- except KeyError:
- if self._requires_quotes(ident):
- self.__strings[ident] = self.quote_identifier(ident)
- else:
- self.__strings[ident] = ident
- return self.__strings[ident]
- else:
+ try:
+ return self.__strings[ident]
+ except KeyError:
if self._requires_quotes(ident):
- return self.quote_identifier(ident)
+ self.__strings[ident] = self.quote_identifier(ident)
else:
- return ident
+ self.__strings[ident] = ident
+ return self.__strings[ident]
def should_quote(self, object):
return object.quote or self._requires_quotes(object.name)
@@ -1152,5 +1112,3 @@ class ANSIIdentifierPreparer(object):
return [self._unescape_identifier(i)
for i in [a or b for a, b in r.findall(identifiers)]]
-
-dialect = ANSIDialect
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql/expression.py
index 7c73f7cb7..e117e3f47 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -26,13 +26,14 @@ classes usually have few or no public methods and are less guaranteed
to stay the same in future releases.
"""
-from sqlalchemy import util, exceptions, operators
+from sqlalchemy import util, exceptions
+from sqlalchemy.sql import operators, visitors
from sqlalchemy import types as sqltypes
import re
__all__ = [
- 'Alias', 'ClauseElement', 'ClauseParameters',
- 'ClauseVisitor', 'ColumnCollection', 'ColumnElement',
+ 'Alias', 'ClauseElement',
+ 'ColumnCollection', 'ColumnElement',
'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join',
'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc',
'between', 'bindparam', 'case', 'cast', 'column', 'delete',
@@ -810,187 +811,6 @@ def is_column(col):
"""True if ``col`` is an instance of ``ColumnElement``."""
return isinstance(col, ColumnElement)
-class ClauseParameters(object):
- """Represent a dictionary/iterator of bind parameter key names/values.
-
- Tracks the original [sqlalchemy.sql#_BindParamClause] objects
- as well as the keys/position of each parameter, and can return
- parameters as a dictionary or a list. Will process parameter
- values according to the ``TypeEngine`` objects present in the
- ``_BindParamClause`` instances.
- """
-
- def __init__(self, dialect, positional=None):
- self.dialect = dialect
- self.__binds = {}
- self.positional = positional or []
-
- def get_parameter(self, key):
- return self.__binds[key]
-
- def set_parameter(self, bindparam, value, name):
- self.__binds[name] = [bindparam, name, value]
-
- def get_original(self, key):
- return self.__binds[key][2]
-
- def get_type(self, key):
- return self.__binds[key][0].type
-
- def get_processors(self):
- """return a dictionary of bind 'processing' functions"""
- return dict([
- (key, value) for key, value in
- [(
- key,
- self.__binds[key][0].bind_processor(self.dialect)
- ) for key in self.__binds]
- if value is not None
- ])
-
- def get_processed(self, key, processors):
- return key in processors and processors[key](self.__binds[key][2]) or self.__binds[key][2]
-
- def keys(self):
- return self.__binds.keys()
-
- def __iter__(self):
- return iter(self.keys())
-
- def __getitem__(self, key):
- (bind, name, value) = self.__binds[key]
- processor = bind.bind_processor(self.dialect)
- return processor is not None and processor(value) or value
-
- def __contains__(self, key):
- return key in self.__binds
-
- def set_value(self, key, value):
- self.__binds[key][2] = value
-
- def get_original_dict(self):
- return dict([(name, value) for (b, name, value) in self.__binds.values()])
-
- def __get_processed(self, key, processors):
- if key in processors:
- return processors[key](self.__binds[key][2])
- else:
- return self.__binds[key][2]
-
- def get_raw_list(self, processors):
- return [self.__get_processed(key, processors) for key in self.positional]
-
- def get_raw_dict(self, processors, encode_keys=False):
- if encode_keys:
- return dict([
- (
- key.encode(self.dialect.encoding),
- self.__get_processed(key, processors)
- )
- for key in self.keys()
- ])
- else:
- return dict([
- (
- key,
- self.__get_processed(key, processors)
- )
- for key in self.keys()
- ])
-
- def __repr__(self):
- return self.__class__.__name__ + ":" + repr(self.get_original_dict())
-
-class ClauseVisitor(object):
- """A class that knows how to traverse and visit ``ClauseElements``.
-
- Calls visit_XXX() methods dynamically generated for each
- particualr ``ClauseElement`` subclass encountered. Traversal of a
- hierarchy of ``ClauseElements`` is achieved via the ``traverse()``
- method, which is passed the lead ``ClauseElement``.
-
- By default, ``ClauseVisitor`` traverses all elements fully.
- Options can be specified at the class level via the
- ``__traverse_options__`` dictionary which will be passed to the
- ``get_children()`` method of each ``ClauseElement``; these options
- can indicate modifications to the set of elements returned, such
- as to not return column collections (column_collections=False) or
- to return Schema-level items (schema_visitor=True).
-
- ``ClauseVisitor`` also supports a simultaneous copy-and-traverse
- operation, which will produce a copy of a given ``ClauseElement``
- structure while at the same time allowing ``ClauseVisitor``
- subclasses to modify the new structure in-place.
- """
-
- __traverse_options__ = {}
-
- def traverse_single(self, obj, **kwargs):
- meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
- if meth:
- return meth(obj, **kwargs)
-
- def iterate(self, obj, stop_on=None):
- stack = [obj]
- traversal = []
- while len(stack) > 0:
- t = stack.pop()
- if stop_on is None or t not in stop_on:
- yield t
- traversal.insert(0, t)
- for c in t.get_children(**self.__traverse_options__):
- stack.append(c)
-
- def traverse(self, obj, stop_on=None, clone=False):
- if clone:
- obj = obj._clone()
-
- stack = [obj]
- traversal = []
- while len(stack) > 0:
- t = stack.pop()
- if stop_on is None or t not in stop_on:
- traversal.insert(0, t)
- if clone:
- t._copy_internals()
- for c in t.get_children(**self.__traverse_options__):
- stack.append(c)
- for target in traversal:
- v = self
- while v is not None:
- meth = getattr(v, "visit_%s" % target.__visit_name__, None)
- if meth:
- meth(target)
- v = getattr(v, '_next', None)
- return obj
-
- def chain(self, visitor):
- """'chain' an additional ClauseVisitor onto this ClauseVisitor.
-
- The chained visitor will receive all visit events after this one.
- """
-
- tail = self
- while getattr(tail, '_next', None) is not None:
- tail = tail._next
- tail._next = visitor
- return self
-
-class NoColumnVisitor(ClauseVisitor):
- """A ClauseVisitor that will not traverse exported column collections.
-
- Will not traverse the exported Column collections on Table, Alias,
- Select, and CompoundSelect objects (i.e. their 'columns' or 'c'
- attribute).
-
- This is useful because most traversals don't need those columns,
- or in the case of ANSICompiler it traverses them explicitly; so
- skipping their traversal here greatly cuts down on method call
- overhead.
- """
-
- __traverse_options__ = {'column_collections': False}
-
class _FigureVisitName(type):
def __init__(cls, clsname, bases, dict):
@@ -1061,7 +881,7 @@ class ClauseElement(object):
elif len(optionaldict) > 1:
raise exceptions.ArgumentError("params() takes zero or one positional dictionary argument")
- class Vis(ClauseVisitor):
+ class Vis(visitors.ClauseVisitor):
def visit_bindparam(self, bind):
if bind.key in kwargs:
bind.value = kwargs[bind.key]
@@ -1156,7 +976,7 @@ class ClauseElement(object):
if any.
Finally, if there is no bound ``Engine``, uses an
- ``ANSIDialect`` to create a default ``Compiler``.
+ ``DefaultDialect`` to create a default ``Compiler``.
`parameters` is a dictionary representing the default bind
parameters to be used with the statement. If `parameters` is
@@ -1175,15 +995,16 @@ class ClauseElement(object):
if compiler is None:
if dialect is not None:
- compiler = dialect.compiler(self, parameters)
+ compiler = dialect.statement_compiler(dialect, self, parameters)
elif bind is not None:
- compiler = bind.compiler(self, parameters)
+ compiler = bind.statement_compiler(self, parameters)
elif self.bind is not None:
- compiler = self.bind.compiler(self, parameters)
+ compiler = self.bind.statement_compiler(self, parameters)
if compiler is None:
- import sqlalchemy.ansisql as ansisql
- compiler = ansisql.ANSIDialect().compiler(self, parameters=parameters)
+ from sqlalchemy.engine.default import DefaultDialect
+ dialect = DefaultDialect()
+ compiler = dialect.statement_compiler(dialect, self, parameters=parameters)
compiler.compile()
return compiler
@@ -1727,7 +1548,7 @@ class FromClause(Selectable):
def _get_all_embedded_columns(self):
ret = []
- class FindCols(ClauseVisitor):
+ class FindCols(visitors.ClauseVisitor):
def visit_column(self, col):
ret.append(col)
FindCols().traverse(self)
@@ -1744,8 +1565,8 @@ class FromClause(Selectable):
def replace_selectable(self, old, alias):
"""replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``."""
- from sqlalchemy import sql_util
- return sql_util.ClauseAdapter(alias).traverse(self, clone=True)
+ from sqlalchemy.sql import util
+ return util.ClauseAdapter(alias).traverse(self, clone=True)
def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False):
"""Given a ``ColumnElement``, return the exported ``ColumnElement``
@@ -2376,7 +2197,7 @@ class Join(FromClause):
else:
equivs[x] = util.Set([y])
- class BinaryVisitor(ClauseVisitor):
+ class BinaryVisitor(visitors.ClauseVisitor):
def visit_binary(self, binary):
if binary.operator == operators.eq:
add_equiv(binary.left, binary.right)
@@ -2460,7 +2281,7 @@ class Join(FromClause):
return self.__folded_equivalents
if equivs is None:
equivs = util.Set()
- class LocateEquivs(NoColumnVisitor):
+ class LocateEquivs(visitors.NoColumnVisitor):
def visit_binary(self, binary):
if binary.operator == operators.eq and binary.left.name == binary.right.name:
equivs.add(binary.right)
@@ -3331,7 +3152,7 @@ class Select(_SelectBaseMixin, FromClause):
return intersect_all(self, other, **kwargs)
def _table_iterator(self):
- for t in NoColumnVisitor().iterate(self):
+ for t in visitors.NoColumnVisitor().iterate(self):
if isinstance(t, TableClause):
yield t
diff --git a/lib/sqlalchemy/operators.py b/lib/sqlalchemy/sql/operators.py
index b8aca3d26..b8aca3d26 100644
--- a/lib/sqlalchemy/operators.py
+++ b/lib/sqlalchemy/sql/operators.py
diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql/util.py
index cc6325822..2c7294e66 100644
--- a/lib/sqlalchemy/sql_util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -1,7 +1,100 @@
-from sqlalchemy import sql, util, schema, topological
+from sqlalchemy import util, schema, topological
+from sqlalchemy.sql import expression, visitors
"""Utility functions that build upon SQL and Schema constructs."""
+class ClauseParameters(object):
+ """Represent a dictionary/iterator of bind parameter key names/values.
+
+ Tracks the original [sqlalchemy.sql#_BindParamClause] objects as well as the
+ keys/position of each parameter, and can return parameters as a
+ dictionary or a list. Will process parameter values according to
+ the ``TypeEngine`` objects present in the ``_BindParamClause`` instances.
+ """
+
+ def __init__(self, dialect, positional=None):
+ self.dialect = dialect
+ self.__binds = {}
+ self.positional = positional or []
+
+ def get_parameter(self, key):
+ return self.__binds[key]
+
+ def set_parameter(self, bindparam, value, name):
+ self.__binds[name] = [bindparam, name, value]
+
+ def get_original(self, key):
+ return self.__binds[key][2]
+
+ def get_type(self, key):
+ return self.__binds[key][0].type
+
+ def get_processors(self):
+ """return a dictionary of bind 'processing' functions"""
+ return dict([
+ (key, value) for key, value in
+ [(
+ key,
+ self.__binds[key][0].bind_processor(self.dialect)
+ ) for key in self.__binds]
+ if value is not None
+ ])
+
+ def get_processed(self, key, processors):
+ return key in processors and processors[key](self.__binds[key][2]) or self.__binds[key][2]
+
+ def keys(self):
+ return self.__binds.keys()
+
+ def __iter__(self):
+ return iter(self.keys())
+
+ def __getitem__(self, key):
+ (bind, name, value) = self.__binds[key]
+ processor = bind.bind_processor(self.dialect)
+ return processor is not None and processor(value) or value
+
+ def __contains__(self, key):
+ return key in self.__binds
+
+ def set_value(self, key, value):
+ self.__binds[key][2] = value
+
+ def get_original_dict(self):
+ return dict([(name, value) for (b, name, value) in self.__binds.values()])
+
+ def __get_processed(self, key, processors):
+ if key in processors:
+ return processors[key](self.__binds[key][2])
+ else:
+ return self.__binds[key][2]
+
+ def get_raw_list(self, processors):
+ return [self.__get_processed(key, processors) for key in self.positional]
+
+ def get_raw_dict(self, processors, encode_keys=False):
+ if encode_keys:
+ return dict([
+ (
+ key.encode(self.dialect.encoding),
+ self.__get_processed(key, processors)
+ )
+ for key in self.keys()
+ ])
+ else:
+ return dict([
+ (
+ key,
+ self.__get_processed(key, processors)
+ )
+ for key in self.keys()
+ ])
+
+ def __repr__(self):
+ return self.__class__.__name__ + ":" + repr(self.get_original_dict())
+
+
+
class TableCollection(object):
def __init__(self, tables=None):
self.tables = tables or []
@@ -64,7 +157,7 @@ class TableCollection(object):
return sequence
-class TableFinder(TableCollection, sql.NoColumnVisitor):
+class TableFinder(TableCollection, visitors.NoColumnVisitor):
"""locate all Tables within a clause."""
def __init__(self, clause, check_columns=False, include_aliases=False):
@@ -85,7 +178,7 @@ class TableFinder(TableCollection, sql.NoColumnVisitor):
if self.check_columns:
self.tables.append(column.table)
-class ColumnFinder(sql.ClauseVisitor):
+class ColumnFinder(visitors.ClauseVisitor):
def __init__(self):
self.columns = util.Set()
@@ -95,7 +188,7 @@ class ColumnFinder(sql.ClauseVisitor):
def __iter__(self):
return iter(self.columns)
-class ColumnsInClause(sql.ClauseVisitor):
+class ColumnsInClause(visitors.ClauseVisitor):
"""Given a selectable, visit clauses and determine if any columns
from the clause are in the selectable.
"""
@@ -108,7 +201,7 @@ class ColumnsInClause(sql.ClauseVisitor):
if self.selectable.c.get(column.key) is column:
self.result = True
-class AbstractClauseProcessor(sql.NoColumnVisitor):
+class AbstractClauseProcessor(visitors.NoColumnVisitor):
"""Traverse a clause and attempt to convert the contents of container elements
to a converted element.
@@ -224,10 +317,10 @@ class ClauseAdapter(AbstractClauseProcessor):
self.equivalents = equivalents
def convert_element(self, col):
- if isinstance(col, sql.FromClause):
+ if isinstance(col, expression.FromClause):
if self.selectable.is_derived_from(col):
return self.selectable
- if not isinstance(col, sql.ColumnElement):
+ if not isinstance(col, expression.ColumnElement):
return None
if self.include is not None:
if col not in self.include:
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
new file mode 100644
index 000000000..98e4de6c3
--- /dev/null
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -0,0 +1,87 @@
+class ClauseVisitor(object):
+ """A class that knows how to traverse and visit
+ ``ClauseElements``.
+
+ Calls visit_XXX() methods dynamically generated for each particualr
+ ``ClauseElement`` subclass encountered. Traversal of a
+ hierarchy of ``ClauseElements`` is achieved via the
+ ``traverse()`` method, which is passed the lead
+ ``ClauseElement``.
+
+ By default, ``ClauseVisitor`` traverses all elements
+ fully. Options can be specified at the class level via the
+ ``__traverse_options__`` dictionary which will be passed
+ to the ``get_children()`` method of each ``ClauseElement``;
+ these options can indicate modifications to the set of
+ elements returned, such as to not return column collections
+ (column_collections=False) or to return Schema-level items
+ (schema_visitor=True).
+
+ ``ClauseVisitor`` also supports a simultaneous copy-and-traverse
+ operation, which will produce a copy of a given ``ClauseElement``
+ structure while at the same time allowing ``ClauseVisitor`` subclasses
+ to modify the new structure in-place.
+
+ """
+ __traverse_options__ = {}
+
+ def traverse_single(self, obj, **kwargs):
+ meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
+ if meth:
+ return meth(obj, **kwargs)
+
+ def iterate(self, obj, stop_on=None):
+ stack = [obj]
+ traversal = []
+ while len(stack) > 0:
+ t = stack.pop()
+ if stop_on is None or t not in stop_on:
+ yield t
+ traversal.insert(0, t)
+ for c in t.get_children(**self.__traverse_options__):
+ stack.append(c)
+
+ def traverse(self, obj, stop_on=None, clone=False):
+ if clone:
+ obj = obj._clone()
+
+ stack = [obj]
+ traversal = []
+ while len(stack) > 0:
+ t = stack.pop()
+ if stop_on is None or t not in stop_on:
+ traversal.insert(0, t)
+ if clone:
+ t._copy_internals()
+ for c in t.get_children(**self.__traverse_options__):
+ stack.append(c)
+ for target in traversal:
+ v = self
+ while v is not None:
+ meth = getattr(v, "visit_%s" % target.__visit_name__, None)
+ if meth:
+ meth(target)
+ v = getattr(v, '_next', None)
+ return obj
+
+ def chain(self, visitor):
+ """'chain' an additional ClauseVisitor onto this ClauseVisitor.
+
+ the chained visitor will receive all visit events after this one."""
+ tail = self
+ while getattr(tail, '_next', None) is not None:
+ tail = tail._next
+ tail._next = visitor
+ return self
+
+class NoColumnVisitor(ClauseVisitor):
+ """a ClauseVisitor that will not traverse the exported Column
+ collections on Table, Alias, Select, and CompoundSelect objects
+ (i.e. their 'columns' or 'c' attribute).
+
+ this is useful because most traversals don't need those columns, or
+ in the case of DefaultCompiler it traverses them explicitly; so
+ skipping their traversal here greatly cuts down on method call overhead.
+ """
+
+ __traverse_options__ = {'column_collections':False}
diff --git a/test/dialect/mysql.py b/test/dialect/mysql.py
index 46e0a7137..294842854 100644
--- a/test/dialect/mysql.py
+++ b/test/dialect/mysql.py
@@ -154,7 +154,7 @@ class TypesTest(AssertMixin):
table_args.append(Column('c%s' % index, type_(*args, **kw)))
numeric_table = Table(*table_args)
- gen = testbase.db.dialect.schemagenerator(testbase.db, None, None)
+ gen = testbase.db.dialect.schemagenerator(testbase.db.dialect, testbase.db, None, None)
for col in numeric_table.c:
index = int(col.name[1:])
@@ -238,7 +238,7 @@ class TypesTest(AssertMixin):
table_args.append(Column('c%s' % index, type_(*args, **kw)))
charset_table = Table(*table_args)
- gen = testbase.db.dialect.schemagenerator(testbase.db, None, None)
+ gen = testbase.db.dialect.schemagenerator(testbase.db.dialect, testbase.db, None, None)
for col in charset_table.c:
index = int(col.name[1:])
@@ -707,7 +707,7 @@ class SQLTest(AssertMixin):
def colspec(c):
- return testbase.db.dialect.schemagenerator(
+ return testbase.db.dialect.schemagenerator(testbase.db.dialect,
testbase.db, None, None).get_column_specification(c)
if __name__ == "__main__":
diff --git a/test/engine/reflection.py b/test/engine/reflection.py
index da6f75149..2345a328a 100644
--- a/test/engine/reflection.py
+++ b/test/engine/reflection.py
@@ -2,7 +2,6 @@ import testbase
import pickle, StringIO, unicodedata
from sqlalchemy import *
-import sqlalchemy.ansisql as ansisql
from sqlalchemy.exceptions import NoSuchTableError
from testlib import *
from testlib import engines
@@ -686,7 +685,7 @@ class SchemaTest(PersistTest):
def foo(s, p=None):
buf.write(s)
gen = create_engine(testbase.db.name + "://", strategy="mock", executor=foo)
- gen = gen.dialect.schemagenerator(gen)
+ gen = gen.dialect.schemagenerator(gen.dialect, gen)
gen.traverse(table1)
gen.traverse(table2)
buf = buf.getvalue()
diff --git a/test/orm/dynamic.py b/test/orm/dynamic.py
index 0c824d372..2e84c83d1 100644
--- a/test/orm/dynamic.py
+++ b/test/orm/dynamic.py
@@ -1,7 +1,6 @@
import testbase
import operator
from sqlalchemy import *
-from sqlalchemy import ansisql
from sqlalchemy.orm import *
from testlib import *
from testlib.fixtures import *
diff --git a/test/orm/query.py b/test/orm/query.py
index e0b8bf4f3..e3f6ed42c 100644
--- a/test/orm/query.py
+++ b/test/orm/query.py
@@ -1,7 +1,8 @@
import testbase
import operator
+from sqlalchemy.sql import compiler
from sqlalchemy import *
-from sqlalchemy import ansisql
+from sqlalchemy.engine import default
from sqlalchemy.orm import *
from testlib import *
from testlib.fixtures import *
@@ -141,7 +142,7 @@ class OperatorTest(QueryTest):
"""test sql.Comparator implementation for MapperProperties"""
def _test(self, clause, expected):
- c = str(clause.compile(dialect=ansisql.ANSIDialect()))
+ c = str(clause.compile(dialect = default.DefaultDialect()))
assert c == expected, "%s != %s" % (c, expected)
def test_arithmetic(self):
@@ -182,7 +183,7 @@ class OperatorTest(QueryTest):
# the compiled clause should match either (e.g.):
# 'a' < 'b' -or- 'b' > 'a'.
- compiled = str(py_op(lhs, rhs).compile(dialect=ansisql.ANSIDialect()))
+ compiled = str(py_op(lhs, rhs).compile(dialect=default.DefaultDialect()))
fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql)
rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql)
@@ -201,7 +202,7 @@ class OperatorTest(QueryTest):
# this one would require adding compile() to InstrumentedScalarAttribute. do we want this ?
#(User.id, "users.id")
):
- c = expr.compile(dialect=ansisql.ANSIDialect())
+ c = expr.compile(dialect=default.DefaultDialect())
assert str(c) == compare, "%s != %s" % (str(c), compare)
diff --git a/test/sql/constraints.py b/test/sql/constraints.py
index 3120185d5..a8b642b9b 100644
--- a/test/sql/constraints.py
+++ b/test/sql/constraints.py
@@ -179,7 +179,7 @@ class ConstraintTest(AssertMixin):
capt.append(repr(context.parameters))
ex(context)
connection._Connection__execute = proxy
- schemagen = testbase.db.dialect.schemagenerator(connection)
+ schemagen = testbase.db.dialect.schemagenerator(testbase.db.dialect, connection)
schemagen.traverse(events)
assert capt[0].strip().startswith('CREATE TABLE events')
diff --git a/test/sql/generative.py b/test/sql/generative.py
index d79af577f..9bd97b305 100644
--- a/test/sql/generative.py
+++ b/test/sql/generative.py
@@ -1,6 +1,7 @@
import testbase
from sqlalchemy import *
from testlib import *
+from sqlalchemy.sql.visitors import *
class TraversalTest(AssertMixin):
"""test ClauseVisitor's traversal, particularly its ability to copy and modify
@@ -213,7 +214,7 @@ class ClauseTest(SQLCompileTest):
self.assert_compile(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :table1_col2")
def test_clause_adapter(self):
- from sqlalchemy import sql_util
+ from sqlalchemy.sql import util as sql_util
t1alias = t1.alias('t1alias')
diff --git a/test/sql/labels.py b/test/sql/labels.py
index 553a3a3bc..dee76428d 100644
--- a/test/sql/labels.py
+++ b/test/sql/labels.py
@@ -78,7 +78,7 @@ class LongLabelsTest(PersistTest):
# this is the test that fails if the "max identifier length" is shorter than the
# length of the actual columns created, because the column names get truncated.
# if you try to separate "physical columns" from "labels", and only truncate the labels,
- # the ansisql.visit_select() logic which auto-labels columns in a subquery (for the purposes of sqlite compat) breaks the code,
+ # the compiler.DefaultCompiler.visit_select() logic which auto-labels columns in a subquery (for the purposes of sqlite compat) breaks the code,
# since it is creating "labels" on the fly but not affecting derived columns, which think they are
# still "physical"
q = table1.select(table1.c.this_is_the_primarykey_column == 4).alias('foo')
diff --git a/test/sql/query.py b/test/sql/query.py
index 4f569f1c0..8ec3190b4 100644
--- a/test/sql/query.py
+++ b/test/sql/query.py
@@ -1,7 +1,8 @@
import testbase
import datetime
from sqlalchemy import *
-from sqlalchemy import exceptions, ansisql
+from sqlalchemy import exceptions
+from sqlalchemy.engine import default
from testlib import *
@@ -166,14 +167,14 @@ class QueryTest(PersistTest):
assert len(r) == 1
def test_bindparam_detection(self):
- dialect = ansisql.ANSIDialect(default_paramstyle='qmark')
- prep = lambda q: dialect.compile(sql.text(q)).string
+ dialect = default.DefaultDialect(default_paramstyle='qmark')
+ prep = lambda q: str(sql.text(q).compile(dialect=dialect))
def a_eq(got, wanted):
if got != wanted:
print "Wanted %s" % wanted
print "Received %s" % got
- self.assert_(got == wanted)
+ self.assert_(got == wanted, got)
a_eq(prep('select foo'), 'select foo')
a_eq(prep("time='12:30:00'"), "time='12:30:00'")
diff --git a/test/sql/quote.py b/test/sql/quote.py
index ad25619df..0c414af3a 100644
--- a/test/sql/quote.py
+++ b/test/sql/quote.py
@@ -1,7 +1,7 @@
import testbase
from sqlalchemy import *
from testlib import *
-
+from sqlalchemy.sql import compiler
class QuoteTest(PersistTest):
def setUpAll(self):
@@ -98,10 +98,10 @@ class QuoteTest(PersistTest):
class PreparerTest(PersistTest):
- """Test the db-agnostic quoting services of ANSIIdentifierPreparer."""
+ """Test the db-agnostic quoting services of IdentifierPreparer."""
def test_unformat(self):
- prep = ansisql.ANSIIdentifierPreparer(None)
+ prep = compiler.IdentifierPreparer(None)
unformat = prep.unformat_identifiers
def a_eq(have, want):
@@ -120,7 +120,7 @@ class PreparerTest(PersistTest):
a_eq(unformat('"foo"."b""a""r"."baz"'), ['foo', 'b"a"r', 'baz'])
def test_unformat_custom(self):
- class Custom(ansisql.ANSIIdentifierPreparer):
+ class Custom(compiler.IdentifierPreparer):
def __init__(self, dialect):
super(Custom, self).__init__(dialect, initial_quote='`',
final_quote='`')
diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py
index c497fbcbd..4ffb4c591 100644
--- a/test/sql/testtypes.py
+++ b/test/sql/testtypes.py
@@ -195,7 +195,7 @@ class ColumnsTest(AssertMixin):
)
for aCol in testTable.c:
- self.assertEquals(expectedResults[aCol.name], db.dialect.schemagenerator(db, None, None).get_column_specification(aCol))
+ self.assertEquals(expectedResults[aCol.name], db.dialect.schemagenerator(db.dialect, db, None, None).get_column_specification(aCol))
class UnicodeTest(AssertMixin):
"""tests the Unicode type. also tests the TypeDecorator with instances in the types package."""
diff --git a/test/testlib/testing.py b/test/testlib/testing.py
index 6830fb63c..58fd7c0d1 100644
--- a/test/testlib/testing.py
+++ b/test/testlib/testing.py
@@ -161,7 +161,8 @@ class ExecutionContextWrapper(object):
if params is not None and isinstance(params, list) and len(params) == 1:
params = params[0]
- if isinstance(ctx.compiled_parameters, sql.ClauseParameters):
+ from sqlalchemy.sql.util import ClauseParameters
+ if isinstance(ctx.compiled_parameters, ClauseParameters):
parameters = ctx.compiled_parameters.get_original_dict()
elif isinstance(ctx.compiled_parameters, list):
parameters = [p.get_original_dict() for p in ctx.compiled_parameters]