diff options
36 files changed, 517 insertions, 607 deletions
diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py index 6bf8b96e9..4aa773239 100644 --- a/lib/sqlalchemy/databases/access.py +++ b/lib/sqlalchemy/databases/access.py @@ -347,7 +347,7 @@ class AccessDialect(ansisql.ANSIDialect): return names -class AccessCompiler(ansisql.ANSICompiler): +class AccessCompiler(compiler.DefaultCompiler): def visit_select_precolumns(self, select): """Access puts TOP, it's version of LIMIT here """ s = select.distinct and "DISTINCT " or "" @@ -387,7 +387,7 @@ class AccessCompiler(ansisql.ANSICompiler): return '' -class AccessSchemaGenerator(ansisql.ANSISchemaGenerator): +class AccessSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() @@ -410,7 +410,7 @@ class AccessSchemaGenerator(ansisql.ANSISchemaGenerator): return colspec -class AccessSchemaDropper(ansisql.ANSISchemaDropper): +class AccessSchemaDropper(compiler.SchemaDropper): def visit_index(self, index): self.append("\nDROP INDEX [%s].[%s]" % (index.table.name, index.name)) self.execute() @@ -418,7 +418,7 @@ class AccessSchemaDropper(ansisql.ANSISchemaDropper): class AccessDefaultRunner(ansisql.ANSIDefaultRunner): pass -class AccessIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class AccessIdentifierPreparer(compiler.IdentifierPreparer): def __init__(self, dialect): super(AccessIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 307fceb48..9cccb53e8 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -7,9 +7,10 @@ import warnings -from sqlalchemy import util, sql, schema, ansisql, exceptions -import sqlalchemy.engine.default as default -import sqlalchemy.types as sqltypes +from sqlalchemy import util, sql, schema, exceptions +from sqlalchemy.sql import compiler +from sqlalchemy.engine import default, base +from sqlalchemy import types as sqltypes _initialized_kb = False @@ -99,9 +100,9 @@ class FBExecutionContext(default.DefaultExecutionContext): return True -class FBDialect(ansisql.ANSIDialect): +class FBDialect(default.DefaultDialect): def __init__(self, type_conv=200, concurrency_level=1, **kwargs): - ansisql.ANSIDialect.__init__(self, **kwargs) + default.DefaultDialect.__init__(self, **kwargs) self.type_conv = type_conv self.concurrency_level= concurrency_level @@ -135,21 +136,6 @@ class FBDialect(ansisql.ANSIDialect): def supports_sane_rowcount(self): return False - def compiler(self, statement, bindparams, **kwargs): - return FBCompiler(self, statement, bindparams, **kwargs) - - def schemagenerator(self, *args, **kwargs): - return FBSchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return FBSchemaDropper(self, *args, **kwargs) - - def defaultrunner(self, connection): - return FBDefaultRunner(connection) - - def preparer(self): - return FBIdentifierPreparer(self) - def max_identifier_length(self): return 31 @@ -307,7 +293,7 @@ class FBDialect(ansisql.ANSIDialect): connection.commit(True) -class FBCompiler(ansisql.ANSICompiler): +class FBCompiler(compiler.DefaultCompiler): """Firebird specific idiosincrasies""" def visit_alias(self, alias, asfrom=False, **kwargs): @@ -346,7 +332,7 @@ class FBCompiler(ansisql.ANSICompiler): return "" -class FBSchemaGenerator(ansisql.ANSISchemaGenerator): +class FBSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() @@ -365,13 +351,13 @@ class FBSchemaGenerator(ansisql.ANSISchemaGenerator): self.execute() -class FBSchemaDropper(ansisql.ANSISchemaDropper): +class FBSchemaDropper(compiler.SchemaDropper): def visit_sequence(self, sequence): self.append("DROP GENERATOR %s" % sequence.name) self.execute() -class FBDefaultRunner(ansisql.ANSIDefaultRunner): +class FBDefaultRunner(base.DefaultRunner): def exec_default_sql(self, default): c = sql.select([default.arg], from_obj=["rdb$database"]).compile(bind=self.connection) return self.connection.execute_compiled(c).scalar() @@ -421,7 +407,7 @@ RESERVED_WORDS = util.Set( "whenever", "where", "while", "with", "work", "write", "year", "yearday" ]) -class FBIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class FBIdentifierPreparer(compiler.IdentifierPreparer): def __init__(self, dialect): super(FBIdentifierPreparer,self).__init__(dialect, omit_schema=True) @@ -430,3 +416,9 @@ class FBIdentifierPreparer(ansisql.ANSIIdentifierPreparer): dialect = FBDialect +dialect.statement_compiler = FBCompiler +dialect.schemagenerator = FBSchemaGenerator +dialect.schemadropper = FBSchemaDropper +dialect.defaultrunner = FBDefaultRunner +dialect.preparer = FBIdentifierPreparer + diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py index 21ecf1538..ceb56903a 100644 --- a/lib/sqlalchemy/databases/informix.py +++ b/lib/sqlalchemy/databases/informix.py @@ -7,9 +7,10 @@ import datetime, warnings -from sqlalchemy import sql, schema, ansisql, exceptions, pool -import sqlalchemy.engine.default as default -import sqlalchemy.types as sqltypes +from sqlalchemy import sql, schema, exceptions, pool +from sqlalchemy.sql import compiler +from sqlalchemy.engine import default +from sqlalchemy import types as sqltypes # for offset @@ -203,11 +204,11 @@ class InfoExecutionContext(default.DefaultExecutionContext): def create_cursor( self ): return informix_cursor( self.connection.connection ) -class InfoDialect(ansisql.ANSIDialect): +class InfoDialect(default.DefaultDialect): def __init__(self, use_ansi=True,**kwargs): self.use_ansi = use_ansi - ansisql.ANSIDialect.__init__(self, **kwargs) + default.DefaultDialect.__init__(self, **kwargs) self.paramstyle = 'qmark' def dbapi(cls): @@ -252,18 +253,6 @@ class InfoDialect(ansisql.ANSIDialect): def oid_column_name(self,column): return "rowid" - def preparer(self): - return InfoIdentifierPreparer(self) - - def compiler(self, statement, bindparams, **kwargs): - return InfoCompiler(self, statement, bindparams, **kwargs) - - def schemagenerator(self, *args, **kwargs): - return InfoSchemaGenerator( self , *args, **kwargs) - - def schemadropper(self, *args, **params): - return InfoSchemaDroper( self , *args , **params) - def table_names(self, connection, schema): s = "select tabname from systables" return [row[0] for row in connection.execute(s)] @@ -376,14 +365,14 @@ class InfoDialect(ansisql.ANSIDialect): for cons_name, cons_type, local_column in rows: table.primary_key.add( table.c[local_column] ) -class InfoCompiler(ansisql.ANSICompiler): +class InfoCompiler(compiler.DefaultCompiler): """Info compiler modifies the lexical structure of Select statements to work under non-ANSI configured Oracle databases, if the use_ansi flag is False.""" def __init__(self, dialect, statement, parameters=None, **kwargs): self.limit = 0 self.offset = 0 - ansisql.ANSICompiler.__init__( self , dialect , statement , parameters , **kwargs ) + compiler.DefaultCompiler.__init__( self , dialect , statement , parameters , **kwargs ) def default_from(self): return " from systables where tabname = 'systables' " @@ -416,7 +405,7 @@ class InfoCompiler(ansisql.ANSICompiler): if ( __label(c) not in a ) and getattr( c , 'name' , '' ) != 'oid': select.append_column( c ) - return ansisql.ANSICompiler.visit_select(self, select) + return compiler.DefaultCompiler.visit_select(self, select) def limit_clause(self, select): return "" @@ -437,7 +426,7 @@ class InfoCompiler(ansisql.ANSICompiler): elif func.name.lower() in ( 'current_timestamp' , 'now' ): return "CURRENT YEAR TO SECOND" else: - return ansisql.ANSICompiler.visit_function( self , func ) + return compiler.DefaultCompiler.visit_function( self , func ) def visit_clauselist(self, list): try: @@ -446,7 +435,7 @@ class InfoCompiler(ansisql.ANSICompiler): li = [ c for c in list.clauses ] return ', '.join([s for s in [self.process(c) for c in li] if s is not None]) -class InfoSchemaGenerator(ansisql.ANSISchemaGenerator): +class InfoSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, first_pk=False): colspec = self.preparer.format_column(column) if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \ @@ -507,7 +496,7 @@ class InfoSchemaGenerator(ansisql.ANSISchemaGenerator): return super(InfoSchemaGenerator, self).visit_index(index) -class InfoIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class InfoIdentifierPreparer(compiler.IdentifierPreparer): def __init__(self, dialect): super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'") @@ -517,10 +506,14 @@ class InfoIdentifierPreparer(ansisql.ANSIIdentifierPreparer): def _requires_quotes(self, value): return False -class InfoSchemaDroper(ansisql.ANSISchemaDropper): +class InfoSchemaDroper(compiler.SchemaDropper): def drop_foreignkey(self, constraint): if constraint.name is not None: super( InfoSchemaDroper , self ).drop_foreignkey( constraint ) dialect = InfoDialect poolclass = pool.SingletonThreadPool +dialect.statement_compiler = InfoCompiler +dialect.schemagenerator = InfoSchemaGenerator +dialect.schemadropper = InfoSchemaDropper +dialect.preparer = InfoIdentifierPreparer diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 619e072d9..0caccca95 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -39,10 +39,10 @@ Known issues / TODO: import datetime, random, warnings, re -from sqlalchemy import sql, schema, ansisql, exceptions -import sqlalchemy.types as sqltypes -from sqlalchemy.engine import default -import operator, sys +from sqlalchemy import util, sql, schema, exceptions +from sqlalchemy.sql import compiler, expression +from sqlalchemy.engine import default, base +from sqlalchemy import types as sqltypes class MSNumeric(sqltypes.Numeric): def result_processor(self, dialect): @@ -366,7 +366,7 @@ class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext): super(MSSQLExecutionContext_pyodbc, self).post_exec() -class MSSQLDialect(ansisql.ANSIDialect): +class MSSQLDialect(default.DefaultDialect): colspecs = { sqltypes.Unicode : MSNVarchar, sqltypes.Integer : MSInteger, @@ -476,21 +476,6 @@ class MSSQLDialect(ansisql.ANSIDialect): def supports_sane_rowcount(self): raise NotImplementedError() - def compiler(self, statement, bindparams, **kwargs): - return MSSQLCompiler(self, statement, bindparams, **kwargs) - - def schemagenerator(self, *args, **kwargs): - return MSSQLSchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return MSSQLSchemaDropper(self, *args, **kwargs) - - def defaultrunner(self, connection, **kwargs): - return MSSQLDefaultRunner(connection, **kwargs) - - def preparer(self): - return MSSQLIdentifierPreparer(self) - def get_default_schema_name(self, connection): return self.schema_name @@ -878,7 +863,7 @@ dialect_mapping = { } -class MSSQLCompiler(ansisql.ANSICompiler): +class MSSQLCompiler(compiler.DefaultCompiler): def __init__(self, dialect, statement, parameters, **kwargs): super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs) self.tablealiases = {} @@ -931,13 +916,13 @@ class MSSQLCompiler(ansisql.ANSICompiler): def visit_binary(self, binary): """Move bind parameters to the right-hand side of an operator, where possible.""" - if isinstance(binary.left, sql._BindParamClause) and binary.operator == operator.eq: - return self.process(sql._BinaryExpression(binary.right, binary.left, binary.operator)) + if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq: + return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator)) else: return super(MSSQLCompiler, self).visit_binary(binary) def label_select_column(self, select, column): - if isinstance(column, sql._Function): + if isinstance(column, expression._Function): return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:]) else: return super(MSSQLCompiler, self).label_select_column(select, column) @@ -963,7 +948,7 @@ class MSSQLCompiler(ansisql.ANSICompiler): return "" -class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): +class MSSQLSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() @@ -986,7 +971,7 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): return colspec -class MSSQLSchemaDropper(ansisql.ANSISchemaDropper): +class MSSQLSchemaDropper(compiler.SchemaDropper): def visit_index(self, index): self.append("\nDROP INDEX %s.%s" % ( self.preparer.quote_identifier(index.table.name), @@ -995,11 +980,11 @@ class MSSQLSchemaDropper(ansisql.ANSISchemaDropper): self.execute() -class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner): +class MSSQLDefaultRunner(base.DefaultRunner): # TODO: does ms-sql have standalone sequences ? pass -class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class MSSQLIdentifierPreparer(compiler.IdentifierPreparer): def __init__(self, dialect): super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') @@ -1012,6 +997,11 @@ class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): return value dialect = MSSQLDialect +dialect.statement_compiler = MSSQLCompiler +dialect.schemagenerator = MSSQLSchemaGenerator +dialect.schemadropper = MSSQLSchemaDropper +dialect.preparer = MSSQLIdentifierPreparer +dialect.defaultrunner = MSSQLDefaultRunner diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index d5fd3b6c5..41c6ec70f 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -126,11 +126,12 @@ information affecting MySQL in SQLAlchemy. import re, datetime, inspect, warnings, sys from array import array as _array -from sqlalchemy import ansisql, exceptions, logging, schema, sql, util -from sqlalchemy import operators as sql_operators +from sqlalchemy import exceptions, logging, schema, sql, util +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy.sql import compiler from sqlalchemy.engine import base as engine_base, default -import sqlalchemy.types as sqltypes +from sqlalchemy import types as sqltypes __all__ = ( @@ -1328,13 +1329,17 @@ class MySQLExecutionContext(default.DefaultExecutionContext): return AUTOCOMMIT_RE.match(self.statement) -class MySQLDialect(ansisql.ANSIDialect): +class MySQLDialect(default.DefaultDialect): """Details of the MySQL dialect. Not used directly in application code.""" def __init__(self, use_ansiquotes=False, **kwargs): self.use_ansiquotes = use_ansiquotes kwargs.setdefault('default_paramstyle', 'format') - ansisql.ANSIDialect.__init__(self, **kwargs) + if self.use_ansiquotes: + self.preparer = MySQLANSIIdentifierPreparer + else: + self.preparer = MySQLIdentifierPreparer + default.DefaultDialect.__init__(self, **kwargs) def dbapi(cls): import MySQLdb as mysql @@ -1393,7 +1398,7 @@ class MySQLDialect(ansisql.ANSIDialect): return True def compiler(self, statement, bindparams, **kwargs): - return MySQLCompiler(self, statement, bindparams, **kwargs) + return MySQLCompiler(statement, bindparams, dialect=self, **kwargs) def schemagenerator(self, *args, **kwargs): return MySQLSchemaGenerator(self, *args, **kwargs) @@ -1401,12 +1406,6 @@ class MySQLDialect(ansisql.ANSIDialect): def schemadropper(self, *args, **kwargs): return MySQLSchemaDropper(self, *args, **kwargs) - def preparer(self): - if self.use_ansiquotes: - return MySQLANSIIdentifierPreparer(self) - else: - return MySQLIdentifierPreparer(self) - def do_executemany(self, cursor, statement, parameters, context=None, **kwargs): rowcount = cursor.executemany(statement, parameters) @@ -1733,8 +1732,8 @@ class _MySQLPythonRowProxy(object): return item -class MySQLCompiler(ansisql.ANSICompiler): - operators = ansisql.ANSICompiler.operators.copy() +class MySQLCompiler(compiler.DefaultCompiler): + operators = compiler.DefaultCompiler.operators.copy() operators.update( { sql_operators.concat_op: \ @@ -1783,7 +1782,7 @@ class MySQLCompiler(ansisql.ANSICompiler): # In older versions, the indexes must be created explicitly or the # creation of foreign key constraints fails." -class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): +class MySQLSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, override_pk=False, first_pk=False): """Builds column DDL.""" @@ -1827,7 +1826,7 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): return ' '.join(table_opts) -class MySQLSchemaDropper(ansisql.ANSISchemaDropper): +class MySQLSchemaDropper(compiler.SchemaDropper): def visit_index(self, index): self.append("\nDROP INDEX %s ON %s" % (self.preparer.format_index(index), @@ -2368,7 +2367,7 @@ class MySQLSchemaReflector(object): MySQLSchemaReflector.logger = logging.class_logger(MySQLSchemaReflector) -class _MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class _MySQLIdentifierPreparer(compiler.IdentifierPreparer): """MySQL-specific schema identifier configuration.""" def __init__(self, dialect, **kw): @@ -2433,3 +2432,6 @@ def _re_compile(regex): return re.compile(regex, re.I | re.UNICODE) dialect = MySQLDialect +dialect.statement_compiler = MySQLCompiler +dialect.schemagenerator = MySQLSchemaGenerator +dialect.schemadropper = MySQLSchemaDropper diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index a35db1982..2d8f2940f 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -5,11 +5,13 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php -import re, warnings, operator, random +import re, warnings, random -from sqlalchemy import util, sql, schema, ansisql, exceptions, logging +from sqlalchemy import util, sql, schema, exceptions, logging from sqlalchemy.engine import default, base -import sqlalchemy.types as sqltypes +from sqlalchemy.sql import compiler, visitors +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import types as sqltypes import datetime @@ -229,9 +231,9 @@ class OracleExecutionContext(default.DefaultExecutionContext): return base.ResultProxy(self) -class OracleDialect(ansisql.ANSIDialect): +class OracleDialect(default.DefaultDialect): def __init__(self, use_ansi=True, auto_setinputsizes=True, auto_convert_lobs=True, threaded=True, allow_twophase=True, **kwargs): - ansisql.ANSIDialect.__init__(self, default_paramstyle='named', **kwargs) + default.DefaultDialect.__init__(self, default_paramstyle='named', **kwargs) self.use_ansi = use_ansi self.threaded = threaded self.allow_twophase = allow_twophase @@ -333,21 +335,6 @@ class OracleDialect(ansisql.ANSIDialect): def create_execution_context(self, *args, **kwargs): return OracleExecutionContext(self, *args, **kwargs) - def compiler(self, statement, bindparams, **kwargs): - return OracleCompiler(self, statement, bindparams, **kwargs) - - def preparer(self): - return OracleIdentifierPreparer(self) - - def schemagenerator(self, *args, **kwargs): - return OracleSchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return OracleSchemaDropper(self, *args, **kwargs) - - def defaultrunner(self, connection, **kwargs): - return OracleDefaultRunner(connection, **kwargs) - def has_table(self, connection, table_name, schema=None): cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':self._denormalize_name(table_name)}) return bool( cursor.fetchone() is not None ) @@ -560,16 +547,16 @@ class _OuterJoinColumn(sql.ClauseElement): def __init__(self, column): self.column = column -class OracleCompiler(ansisql.ANSICompiler): +class OracleCompiler(compiler.DefaultCompiler): """Oracle compiler modifies the lexical structure of Select statements to work under non-ANSI configured Oracle databases, if the use_ansi flag is False. """ - operators = ansisql.ANSICompiler.operators.copy() + operators = compiler.DefaultCompiler.operators.copy() operators.update( { - operator.mod : lambda x, y:"mod(%s, %s)" % (x, y) + sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y) } ) @@ -590,13 +577,13 @@ class OracleCompiler(ansisql.ANSICompiler): def visit_join(self, join, **kwargs): if self.dialect.use_ansi: - return ansisql.ANSICompiler.visit_join(self, join, **kwargs) + return compiler.DefaultCompiler.visit_join(self, join, **kwargs) (where, parentjoin) = self.__wheres.get(join, (None, None)) - class VisitOn(sql.ClauseVisitor): + class VisitOn(visitors.ClauseVisitor): def visit_binary(s, binary): - if binary.operator == operator.eq: + if binary.operator == sql_operators.eq: if binary.left.table is join.right: binary.left = _OuterJoinColumn(binary.left) elif binary.right.table is join.right: @@ -640,7 +627,7 @@ class OracleCompiler(ansisql.ANSICompiler): for c in insert.table.primary_key: if c.key not in self.parameters: self.parameters[c.key] = None - return ansisql.ANSICompiler.visit_insert(self, insert) + return compiler.DefaultCompiler.visit_insert(self, insert) def _TODO_visit_compound_select(self, select): """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" @@ -672,7 +659,7 @@ class OracleCompiler(ansisql.ANSICompiler): limitselect.append_whereclause("ora_rn<=%d" % select._limit) return self.process(limitselect, **kwargs) else: - return ansisql.ANSICompiler.visit_select(self, select, **kwargs) + return compiler.DefaultCompiler.visit_select(self, select, **kwargs) def limit_clause(self, select): return "" @@ -684,7 +671,7 @@ class OracleCompiler(ansisql.ANSICompiler): return super(OracleCompiler, self).for_update_clause(select) -class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): +class OracleSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() @@ -701,13 +688,13 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() -class OracleSchemaDropper(ansisql.ANSISchemaDropper): +class OracleSchemaDropper(compiler.SchemaDropper): def visit_sequence(self, sequence): if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name): self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() -class OracleDefaultRunner(ansisql.ANSIDefaultRunner): +class OracleDefaultRunner(base.DefaultRunner): def exec_default_sql(self, default): c = sql.select([default.arg], from_obj=["DUAL"]).compile(bind=self.connection) return self.connection.execute(c).scalar() @@ -715,10 +702,15 @@ class OracleDefaultRunner(ansisql.ANSIDefaultRunner): def visit_sequence(self, seq): return self.connection.execute("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL").scalar() -class OracleIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class OracleIdentifierPreparer(compiler.IdentifierPreparer): def format_savepoint(self, savepoint): name = re.sub(r'^_+', '', savepoint.ident) return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name) dialect = OracleDialect +dialect.statement_compiler = OracleCompiler +dialect.schemagenerator = OracleSchemaGenerator +dialect.schemadropper = OracleSchemaDropper +dialect.preparer = OracleIdentifierPreparer +dialect.defaultrunner = OracleDefaultRunner diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 74a3ef13f..29d84ad4d 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -4,11 +4,13 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import re, random, warnings, operator +import re, random, warnings -from sqlalchemy import sql, schema, ansisql, exceptions, util +from sqlalchemy import sql, schema, exceptions, util from sqlalchemy.engine import base, default -import sqlalchemy.types as sqltypes +from sqlalchemy.sql import compiler +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import types as sqltypes class PGInet(sqltypes.TypeEngine): @@ -220,9 +222,9 @@ class PGExecutionContext(default.DefaultExecutionContext): self._last_inserted_ids = [v for v in row] super(PGExecutionContext, self).post_exec() -class PGDialect(ansisql.ANSIDialect): +class PGDialect(default.DefaultDialect): def __init__(self, use_oids=False, server_side_cursors=False, **kwargs): - ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs) + default.DefaultDialect.__init__(self, default_paramstyle='pyformat', **kwargs) self.use_oids = use_oids self.server_side_cursors = server_side_cursors self.paramstyle = 'pyformat' @@ -249,15 +251,6 @@ class PGDialect(ansisql.ANSIDialect): def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - def compiler(self, statement, bindparams, **kwargs): - return PGCompiler(self, statement, bindparams, **kwargs) - - def schemagenerator(self, *args, **kwargs): - return PGSchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return PGSchemaDropper(self, *args, **kwargs) - def do_begin_twophase(self, connection, xid): self.do_begin(connection.connection) @@ -286,12 +279,6 @@ class PGDialect(ansisql.ANSIDialect): resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts")) return [row[0] for row in resultset] - def defaultrunner(self, context, **kwargs): - return PGDefaultRunner(context, **kwargs) - - def preparer(self): - return PGIdentifierPreparer(self) - def get_default_schema_name(self, connection): if not hasattr(self, '_default_schema_name'): self._default_schema_name = connection.scalar("select current_schema()", None) @@ -556,11 +543,11 @@ class PGDialect(ansisql.ANSIDialect): -class PGCompiler(ansisql.ANSICompiler): - operators = ansisql.ANSICompiler.operators.copy() +class PGCompiler(compiler.DefaultCompiler): + operators = compiler.DefaultCompiler.operators.copy() operators.update( { - operator.mod : '%%' + sql_operators.mod : '%%' } ) @@ -597,7 +584,7 @@ class PGCompiler(ansisql.ANSICompiler): else: return super(PGCompiler, self).for_update_clause(select) -class PGSchemaGenerator(ansisql.ANSISchemaGenerator): +class PGSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): @@ -620,13 +607,13 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() -class PGSchemaDropper(ansisql.ANSISchemaDropper): +class PGSchemaDropper(compiler.SchemaDropper): def visit_sequence(self, sequence): if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)): self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() -class PGDefaultRunner(ansisql.ANSIDefaultRunner): +class PGDefaultRunner(base.DefaultRunner): def get_column_default(self, column, isinsert=True): if column.primary_key: # passive defaults on primary keys have to be overridden @@ -642,7 +629,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) return self.connection.execute(exc).scalar() - return super(ansisql.ANSIDefaultRunner, self).get_column_default(column) + return super(PGDefaultRunner, self).get_column_default(column) def visit_sequence(self, seq): if not seq.optional: @@ -650,7 +637,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): else: return None -class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class PGIdentifierPreparer(compiler.IdentifierPreparer): def _fold_identifier_case(self, value): return value.lower() @@ -660,3 +647,8 @@ class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer): return value dialect = PGDialect +dialect.statement_compiler = PGCompiler +dialect.schemagenerator = PGSchemaGenerator +dialect.schemadropper = PGSchemaDropper +dialect.preparer = PGIdentifierPreparer +dialect.defaultrunner = PGDefaultRunner diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index d96422236..c2aced4d0 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -7,11 +7,12 @@ import re -from sqlalchemy import schema, ansisql, exceptions, pool, PassiveDefault -import sqlalchemy.engine.default as default +from sqlalchemy import schema, exceptions, pool, PassiveDefault +from sqlalchemy.engine import default import sqlalchemy.types as sqltypes import datetime,time, warnings import sqlalchemy.util as util +from sqlalchemy.sql import compiler SELECT_REGEXP = re.compile(r'\s*(?:SELECT|PRAGMA)', re.I | re.UNICODE) @@ -172,10 +173,10 @@ class SQLiteExecutionContext(default.DefaultExecutionContext): def is_select(self): return SELECT_REGEXP.match(self.statement) -class SQLiteDialect(ansisql.ANSIDialect): +class SQLiteDialect(default.DefaultDialect): def __init__(self, **kwargs): - ansisql.ANSIDialect.__init__(self, default_paramstyle='qmark', **kwargs) + default.DefaultDialect.__init__(self, default_paramstyle='qmark', **kwargs) def vers(num): return tuple([int(x) for x in num.split('.')]) if self.dbapi is not None: @@ -195,24 +196,12 @@ class SQLiteDialect(ansisql.ANSIDialect): return sqlite dbapi = classmethod(dbapi) - def compiler(self, statement, bindparams, **kwargs): - return SQLiteCompiler(self, statement, bindparams, **kwargs) - - def schemagenerator(self, *args, **kwargs): - return SQLiteSchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return SQLiteSchemaDropper(self, *args, **kwargs) - def server_version_info(self, connection): return self.dbapi.sqlite_version_info def supports_alter(self): return False - def preparer(self): - return SQLiteIdentifierPreparer(self) - def create_connect_args(self, url): filename = url.database or ':memory:' @@ -255,7 +244,7 @@ class SQLiteDialect(ansisql.ANSIDialect): return (row is not None) def reflecttable(self, connection, table, include_columns): - c = connection.execute("PRAGMA table_info(%s)" % self.preparer().format_table(table), {}) + c = connection.execute("PRAGMA table_info(%s)" % self.identifier_preparer.format_table(table), {}) found_table = False while True: row = c.fetchone() @@ -295,7 +284,7 @@ class SQLiteDialect(ansisql.ANSIDialect): if not found_table: raise exceptions.NoSuchTableError(table.name) - c = connection.execute("PRAGMA foreign_key_list(%s)" % self.preparer().format_table(table), {}) + c = connection.execute("PRAGMA foreign_key_list(%s)" % self.identifier_preparer.format_table(table), {}) fks = {} while True: row = c.fetchone() @@ -324,7 +313,7 @@ class SQLiteDialect(ansisql.ANSIDialect): for name, value in fks.iteritems(): table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1])) # check for UNIQUE indexes - c = connection.execute("PRAGMA index_list(%s)" % self.preparer().format_table(table), {}) + c = connection.execute("PRAGMA index_list(%s)" % self.identifier_preparer.format_table(table), {}) unique_indexes = [] while True: row = c.fetchone() @@ -343,7 +332,7 @@ class SQLiteDialect(ansisql.ANSIDialect): cols.append(row[2]) col = table.columns[row[2]] -class SQLiteCompiler(ansisql.ANSICompiler): +class SQLiteCompiler(compiler.DefaultCompiler): def visit_cast(self, cast): if self.dialect.supports_cast: return super(SQLiteCompiler, self).visit_cast(cast) @@ -369,7 +358,8 @@ class SQLiteCompiler(ansisql.ANSICompiler): # sqlite has no "FOR UPDATE" AFAICT return '' -class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): + +class SQLiteSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() @@ -391,12 +381,17 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): # else: # super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint) -class SQLiteSchemaDropper(ansisql.ANSISchemaDropper): +class SQLiteSchemaDropper(compiler.SchemaDropper): pass -class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): def __init__(self, dialect): super(SQLiteIdentifierPreparer, self).__init__(dialect, omit_schema=True) dialect = SQLiteDialect dialect.poolclass = pool.SingletonThreadPool +dialect.statement_compiler = SQLiteCompiler +dialect.schemagenerator = SQLiteSchemaGenerator +dialect.schemadropper = SQLiteSchemaDropper +dialect.preparer = SQLiteIdentifierPreparer + diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index faeb00cc9..553c8df84 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -8,7 +8,8 @@ higher-level statement-construction, connection-management, execution and result contexts.""" -from sqlalchemy import exceptions, sql, schema, util, types, logging +from sqlalchemy import exceptions, schema, util, types, logging +from sqlalchemy.sql import expression, visitors import StringIO, sys @@ -35,6 +36,22 @@ class Dialect(object): encoding type of encoding to use for unicode, usually defaults to 'utf-8' + + schemagenerator + a [sqlalchemy.schema#SchemaVisitor] class which generates schemas. + + schemadropper + a [sqlalchemy.schema#SchemaVisitor] class which drops schemas. + + defaultrunner + a [sqlalchemy.schema#SchemaVisitor] class which executes defaults. + + statement_compiler + a [sqlalchemy.engine.base#Compiled] class used to compile SQL statements + + preparer + a [sqlalchemy.sql.compiler#IdentifierPreparer] class used to quote + identifiers. """ def create_connect_args(self, url): @@ -105,48 +122,6 @@ class Dialect(object): raise NotImplementedError() - def schemagenerator(self, connection, **kwargs): - """Return a [sqlalchemy.schema#SchemaVisitor] instance that can generate schemas. - - connection - a [sqlalchemy.engine#Connection] to use for statement execution - - `schemagenerator()` is called via the `create()` method on Table, - Index, and others. - """ - - raise NotImplementedError() - - def schemadropper(self, connection, **kwargs): - """Return a [sqlalchemy.schema#SchemaVisitor] instance that can drop schemas. - - connection - a [sqlalchemy.engine#Connection] to use for statement execution - - `schemadropper()` is called via the `drop()` method on Table, - Index, and others. - """ - - raise NotImplementedError() - - def defaultrunner(self, execution_context): - """Return a [sqlalchemy.schema#SchemaVisitor] instance that can execute defaults. - - execution_context - a [sqlalchemy.engine#ExecutionContext] to use for statement execution - - """ - - raise NotImplementedError() - - def compiler(self, statement, parameters): - """Return a [sqlalchemy.sql#Compiled] object for the given statement/parameters. - - The returned object is usually a subclass of [sqlalchemy.ansisql#ANSICompiler]. - - """ - - raise NotImplementedError() def server_version_info(self, connection): """Return a tuple of the database's version number.""" @@ -266,16 +241,6 @@ class Dialect(object): raise NotImplementedError() - - def compile(self, clauseelement, parameters=None): - """Compile the given [sqlalchemy.sql#ClauseElement] using this Dialect. - - Returns [sqlalchemy.sql#Compiled]. A convenience method which - flips around the compile() call on ``ClauseElement``. - """ - - return clauseelement.compile(dialect=self, parameters=parameters) - def is_disconnect(self, e): """Return True if the given DBAPI error indicates an invalid connection""" @@ -304,7 +269,7 @@ class ExecutionContext(object): DBAPI cursor procured from the connection compiled - if passed to constructor, sql.Compiled object being executed + if passed to constructor, sqlalchemy.engine.base.Compiled object being executed statement string version of the statement to be executed. Is either @@ -439,6 +404,9 @@ class Compiled(object): def __init__(self, dialect, statement, parameters, bind=None): """Construct a new ``Compiled`` object. + dialect + ``Dialect`` to compile against. + statement ``ClauseElement`` to be compiled. @@ -724,8 +692,8 @@ class Connection(Connectable): def scalar(self, object, *multiparams, **params): return self.execute(object, *multiparams, **params).scalar() - def compiler(self, statement, parameters, **kwargs): - return self.dialect.compiler(statement, parameters, bind=self, **kwargs) + def statement_compiler(self, statement, parameters, **kwargs): + return self.dialect.statement_compiler(self.dialect, statement, parameters, bind=self, **kwargs) def execute(self, object, *multiparams, **params): for c in type(object).__mro__: @@ -822,9 +790,9 @@ class Connection(Connectable): # poor man's multimethod/generic function thingy executors = { - sql._Function : _execute_function, - sql.ClauseElement : _execute_clauseelement, - sql.ClauseVisitor : _execute_compiled, + expression._Function : _execute_function, + expression.ClauseElement : _execute_clauseelement, + visitors.ClauseVisitor : _execute_compiled, schema.SchemaItem:_execute_default, str.__mro__[-2] : _execute_text } @@ -989,14 +957,14 @@ class Engine(Connectable): connection.close() def _func(self): - return sql._FunctionGenerator(bind=self) + return expression._FunctionGenerator(bind=self) func = property(_func) def text(self, text, *args, **kwargs): """Return a sql.text() object for performing literal queries.""" - return sql.text(text, bind=self, *args, **kwargs) + return expression.text(text, bind=self, *args, **kwargs) def _run_visitor(self, visitorcallable, element, connection=None, **kwargs): if connection is None: @@ -1004,7 +972,7 @@ class Engine(Connectable): else: conn = connection try: - visitorcallable(conn, **kwargs).traverse(element) + visitorcallable(self.dialect, conn, **kwargs).traverse(element) finally: if connection is None: conn.close() @@ -1057,8 +1025,8 @@ class Engine(Connectable): connection = self.contextual_connect(close_with_result=True) return connection._execute_compiled(compiled, multiparams, params) - def compiler(self, statement, parameters, **kwargs): - return self.dialect.compiler(statement, parameters, bind=self, **kwargs) + def statement_compiler(self, statement, parameters, **kwargs): + return self.dialect.statement_compiler(self.dialect, statement, parameters, bind=self, **kwargs) def connect(self, **kwargs): """Return a newly allocated Connection object.""" @@ -1159,6 +1127,7 @@ class ResultProxy(object): self.closed = False self.cursor = context.cursor self.__echo = logging.is_debug_enabled(context.engine.logger) + self._process_row = self._row_processor() if context.is_select(): self._init_metadata() self._rowcount = None @@ -1222,7 +1191,7 @@ class ResultProxy(object): rec = props[key] elif isinstance(key, basestring) and key.lower() in props: rec = props[key.lower()] - elif isinstance(key, sql.ColumnElement): + elif isinstance(key, expression.ColumnElement): label = context.column_labels.get(key._label, key.name).lower() if label in props: rec = props[label] @@ -1320,21 +1289,21 @@ class ResultProxy(object): return self.cursor.fetchmany(size) def _fetchall_impl(self): return self.cursor.fetchall() + + def _row_processor(self): + return RowProxy - def _process_row(self, row): - return RowProxy(self, row) - def fetchall(self): """Fetch all rows, just like DBAPI ``cursor.fetchall()``.""" - l = [self._process_row(row) for row in self._fetchall_impl()] + l = [self._process_row(self, row) for row in self._fetchall_impl()] self.close() return l def fetchmany(self, size=None): """Fetch many rows, just like DBAPI ``cursor.fetchmany(size=cursor.arraysize)``.""" - l = [self._process_row(row) for row in self._fetchmany_impl(size)] + l = [self._process_row(self, row) for row in self._fetchmany_impl(size)] if len(l) == 0: self.close() return l @@ -1343,7 +1312,7 @@ class ResultProxy(object): """Fetch one row, just like DBAPI ``cursor.fetchone()``.""" row = self._fetchone_impl() if row is not None: - return self._process_row(row) + return self._process_row(self, row) else: self.close() return None @@ -1353,7 +1322,7 @@ class ResultProxy(object): row = self._fetchone_impl() try: if row is not None: - return self._process_row(row)[0] + return self._process_row(self, row)[0] else: return None finally: @@ -1425,11 +1394,9 @@ class BufferedColumnResultProxy(ResultProxy): def _get_col(self, row, key): rec = self._key_cache[key] return row[rec[2]] - - def _process_row(self, row): - sup = super(BufferedColumnResultProxy, self) - row = [sup._get_col(row, i) for i in xrange(len(row))] - return RowProxy(self, row) + + def _row_processor(self): + return BufferedColumnRow def fetchall(self): l = [] @@ -1523,6 +1490,11 @@ class RowProxy(object): def __len__(self): return len(self.__row) +class BufferedColumnRow(RowProxy): + def __init__(self, parent, row): + row = [ResultProxy._get_col(parent, row, i) for i in xrange(len(row))] + super(BufferedColumnRow, self).__init__(parent, row) + class SchemaIterator(schema.SchemaVisitor): """A visitor that can gather text into a buffer and execute the contents of the buffer.""" @@ -1590,11 +1562,11 @@ class DefaultRunner(schema.SchemaVisitor): return None def exec_default_sql(self, default): - c = sql.select([default.arg]).compile(bind=self.connection) + c = expression.select([default.arg]).compile(bind=self.connection) return self.connection._execute_compiled(c).scalar() def visit_column_onupdate(self, onupdate): - if isinstance(onupdate.arg, sql.ClauseElement): + if isinstance(onupdate.arg, expression.ClauseElement): return self.exec_default_sql(onupdate) elif callable(onupdate.arg): return onupdate.arg(self.context) @@ -1602,7 +1574,7 @@ class DefaultRunner(schema.SchemaVisitor): return onupdate.arg def visit_column_default(self, default): - if isinstance(default.arg, sql.ClauseElement): + if isinstance(default.arg, expression.ClauseElement): return self.exec_default_sql(default) elif callable(default.arg): return default.arg(self.context) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index ccaf080e7..059395921 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -9,7 +9,7 @@ from sqlalchemy import schema, exceptions, sql, util import re, random from sqlalchemy.engine import base - +from sqlalchemy.sql import compiler, expression AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', re.I | re.UNICODE) @@ -18,6 +18,12 @@ SELECT_REGEXP = re.compile(r'\s*SELECT', re.I | re.UNICODE) class DefaultDialect(base.Dialect): """Default implementation of Dialect""" + schemagenerator = compiler.SchemaGenerator + schemadropper = compiler.SchemaDropper + statement_compiler = compiler.DefaultCompiler + preparer = compiler.IdentifierPreparer + defaultrunner = base.DefaultRunner + def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', paramstyle=None, dbapi=None, **kwargs): self.convert_unicode = convert_unicode self.encoding = encoding @@ -25,6 +31,7 @@ class DefaultDialect(base.Dialect): self._ischema = None self.dbapi = dbapi self._figure_paramstyle(paramstyle=paramstyle, default=default_paramstyle) + self.identifier_preparer = self.preparer(self) def dbapi_type_map(self): # most DBAPIs have problems with this (such as, psycocpg2 types @@ -46,6 +53,7 @@ class DefaultDialect(base.Dialect): typeobj = typeobj() return typeobj + def supports_unicode_statements(self): """indicate whether the DBAPI can receive SQL statements as Python unicode strings""" return False @@ -96,13 +104,13 @@ class DefaultDialect(base.Dialect): return "_sa_%032x" % random.randint(0,2**128) def do_savepoint(self, connection, name): - connection.execute(sql.SavepointClause(name)) + connection.execute(expression.SavepointClause(name)) def do_rollback_to_savepoint(self, connection, name): - connection.execute(sql.RollbackToSavepointClause(name)) + connection.execute(expression.RollbackToSavepointClause(name)) def do_release_savepoint(self, connection, name): - connection.execute(sql.ReleaseSavepointClause(name)) + connection.execute(expression.ReleaseSavepointClause(name)) def do_executemany(self, cursor, statement, parameters, **kwargs): cursor.executemany(statement, parameters) @@ -110,8 +118,6 @@ class DefaultDialect(base.Dialect): def do_execute(self, cursor, statement, parameters, **kwargs): cursor.execute(statement, parameters) - def defaultrunner(self, context): - return base.DefaultRunner(context) def is_disconnect(self, e): return False diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py index 65d2ab3d2..258bddb4a 100644 --- a/lib/sqlalchemy/ext/sqlsoup.py +++ b/lib/sqlalchemy/ext/sqlsoup.py @@ -294,7 +294,7 @@ from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.ext.sessioncontext import SessionContext from sqlalchemy.exceptions import * - +from sqlalchemy.sql import expression _testsql = """ CREATE TABLE books ( @@ -415,7 +415,7 @@ def _selectable_name(selectable): return x def class_for_table(selectable, **mapper_kwargs): - selectable = sql._selectable(selectable) + selectable = expression._selectable(selectable) mapname = 'Mapped' + _selectable_name(selectable) if isinstance(selectable, Table): klass = TableClassType(mapname, (object,), {}) @@ -499,7 +499,7 @@ class SqlSoup: def with_labels(self, item): # TODO give meaningful aliases - return self.map(sql._selectable(item).select(use_labels=True).alias('foo')) + return self.map(expression._selectable(item).select(use_labels=True).alias('foo')) def join(self, *args, **kwargs): j = join(*args, **kwargs) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index c54eee438..9000a8df5 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -6,6 +6,7 @@ from sqlalchemy import util, logging, sql +from sqlalchemy.sql import expression __all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension', 'MapperProperty', 'PropComparator', 'StrategizedProperty', @@ -363,7 +364,7 @@ class MapperProperty(object): return operator(self.comparator, value) -class PropComparator(sql.ColumnOperators): +class PropComparator(expression.ColumnOperators): """defines comparison operations for MapperProperty objects""" def expression_element(self): diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index b48368417..60d4526ec 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -6,7 +6,8 @@ import weakref, warnings, operator from sqlalchemy import sql, util, exceptions, logging -from sqlalchemy import sql_util as sqlutil +from sqlalchemy.sql import expression +from sqlalchemy.sql import util as sqlutil from sqlalchemy.orm import util as mapperutil from sqlalchemy.orm.util import ExtensionCarrier from sqlalchemy.orm import sync @@ -77,7 +78,7 @@ class Mapper(object): raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__) for table in (local_table, select_table): - if table is not None and isinstance(table, sql._SelectBaseMixin): + if table is not None and isinstance(table, expression._SelectBaseMixin): # some db's, noteably postgres, dont want to select from a select # without an alias. also if we make our own alias internally, then # the configured properties on the mapper are not matched against the alias @@ -438,7 +439,7 @@ class Mapper(object): # against the "mapped_table" of this mapper. equivalent_columns = self._get_equivalent_columns() - primary_key = sql.ColumnSet() + primary_key = expression.ColumnSet() for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]): c = self.mapped_table.corresponding_column(col, raiseerr=False) @@ -644,9 +645,9 @@ class Mapper(object): props = {} if self.properties is not None: for key, prop in self.properties.iteritems(): - if sql.is_column(prop): + if expression.is_column(prop): props[key] = self.select_table.corresponding_column(prop) - elif (isinstance(prop, list) and sql.is_column(prop[0])): + elif (isinstance(prop, list) and expression.is_column(prop[0])): props[key] = [self.select_table.corresponding_column(c) for c in prop] self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, properties=props, _polymorphic_map=self.polymorphic_map, polymorphic_on=self.select_table.corresponding_column(self.polymorphic_on), primary_key=self.primary_key_argument) @@ -768,7 +769,7 @@ class Mapper(object): def _create_prop_from_column(self, column): column = util.to_list(column) - if not sql.is_column(column[0]): + if not expression.is_column(column[0]): return None mapped_column = [] for c in column: diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 670fcccc9..20cbcb235 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -11,7 +11,8 @@ operations. PropertyLoader also relies upon the dependency.py module to handle flush-time dependency sorting and processing. """ -from sqlalchemy import sql, schema, util, exceptions, sql_util, logging +from sqlalchemy import sql, schema, util, exceptions, logging +from sqlalchemy.sql import util as sql_util from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency from sqlalchemy.orm import session as sessionlib from sqlalchemy.orm import util as mapperutil diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 44329468f..5cbe19ce2 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -4,7 +4,9 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import sql, util, exceptions, sql_util, logging +from sqlalchemy import sql, util, exceptions, logging +from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import expression, visitors from sqlalchemy.orm import mapper, object_mapper from sqlalchemy.orm import util as mapperutil from sqlalchemy.orm.interfaces import OperationContext, LoaderStack @@ -312,7 +314,7 @@ class Query(object): clause = self._from_obj[-1] currenttables = [clause] - class FindJoinedTables(sql.NoColumnVisitor): + class FindJoinedTables(visitors.NoColumnVisitor): def visit_join(self, join): currenttables.append(join.left) currenttables.append(join.right) @@ -836,7 +838,7 @@ class Query(object): # if theres an order by, add those columns to the column list # of the "rowcount" query we're going to make if order_by: - order_by = [sql._literal_as_text(o) for o in util.to_list(order_by) or []] + order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []] cf = sql_util.ColumnFinder() for o in order_by: cf.traverse(o) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 6565c8d77..bdb17e1d6 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -6,7 +6,9 @@ """sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions.""" -from sqlalchemy import sql, util, exceptions, sql_util, logging +from sqlalchemy import sql, util, exceptions, logging +from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import visitors from sqlalchemy.orm import mapper, attributes from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption from sqlalchemy.orm import session as sessionlib @@ -292,7 +294,7 @@ class LazyLoader(AbstractRelationLoader): (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction) bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds]) - class Visitor(sql.ClauseVisitor): + class Visitor(visitors.ClauseVisitor): def visit_bindparam(s, bindparam): mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent if bindparam.key in bind_to_col: @@ -396,7 +398,7 @@ class LazyLoader(AbstractRelationLoader): if not isinstance(expr, sql.ColumnElement): return None columns = [] - class FindColumnInColumnClause(sql.ClauseVisitor): + class FindColumnInColumnClause(visitors.ClauseVisitor): def visit_column(self, c): columns.append(c) FindColumnInColumnClause().traverse(expr) diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index cf48202b0..49661a95e 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -10,9 +10,9 @@ clause that compares column values. """ from sqlalchemy import sql, schema, exceptions +from sqlalchemy.sql import visitors, operators from sqlalchemy import logging from sqlalchemy.orm import util as mapperutil -import operator ONETOMANY = 0 MANYTOONE = 1 @@ -43,7 +43,7 @@ class ClauseSynchronizer(object): def compile_binary(binary): """Assemble a SyncRule given a single binary condition.""" - if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): + if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): return source_column = None @@ -144,7 +144,7 @@ class SyncRule(object): SyncRule.logger = logging.class_logger(SyncRule) -class BinaryVisitor(sql.ClauseVisitor): +class BinaryVisitor(visitors.ClauseVisitor): def __init__(self, func): self.func = func diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 6b0956dc4..b3f58c954 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -4,7 +4,9 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import sql, util, exceptions, sql_util +from sqlalchemy import sql, util, exceptions +from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import visitors from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE all_cascades = util.Set(["delete", "delete-orphan", "all", "merge", @@ -161,7 +163,7 @@ class ExtensionCarrier(MapperExtension): before_delete = _create_do('before_delete') after_delete = _create_do('after_delete') -class BinaryVisitor(sql.ClauseVisitor): +class BinaryVisitor(visitors.ClauseVisitor): def __init__(self, func): self.func = func @@ -196,7 +198,7 @@ class AliasedClauses(object): # for column-level subqueries, swap out its selectable with our # eager version as appropriate, and manually build the # "correlation" list of the subquery. - class ModifySubquery(sql.ClauseVisitor): + class ModifySubquery(visitors.ClauseVisitor): def visit_select(s, select): select._should_correlate = False select.append_correlation(self.alias) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 99803d665..99ca2389b 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -18,8 +18,11 @@ objects as well as the visitor interface, so that the schema package """ import re, inspect -from sqlalchemy import sql, types, exceptions, util, databases +from sqlalchemy import types, exceptions, util, databases +from sqlalchemy.sql import expression, visitors import sqlalchemy + + URL = None __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', @@ -31,7 +34,7 @@ __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', class SchemaItem(object): """Base class for items that define a database schema.""" - __metaclass__ = sql._FigureVisitName + __metaclass__ = expression._FigureVisitName def _init_items(self, *args): """Initialize the list of child items for this SchemaItem.""" @@ -84,7 +87,7 @@ def _get_table_key(name, schema): else: return schema + "." + name -class _TableSingleton(sql._FigureVisitName): +class _TableSingleton(expression._FigureVisitName): """A metaclass used by the ``Table`` object to provide singleton behavior.""" def __call__(self, name, metadata, *args, **kwargs): @@ -124,10 +127,10 @@ class _TableSingleton(sql._FigureVisitName): return table -class Table(SchemaItem, sql.TableClause): +class Table(SchemaItem, expression.TableClause): """Represent a relational database table. - This subclasses ``sql.TableClause`` to provide a table that is + This subclasses ``expression.TableClause`` to provide a table that is associated with an instance of ``MetaData``, which in turn may be associated with an instance of ``Engine``. @@ -229,7 +232,7 @@ class Table(SchemaItem, sql.TableClause): self.schema = kwargs.pop('schema', None) self.indexes = util.Set() self.constraints = util.Set() - self._columns = sql.ColumnCollection() + self._columns = expression.ColumnCollection() self.primary_key = PrimaryKeyConstraint() self._foreign_keys = util.OrderedSet() self.quote = kwargs.pop('quote', False) @@ -291,7 +294,7 @@ class Table(SchemaItem, sql.TableClause): def get_children(self, column_collections=True, schema_visitor=False, **kwargs): if not schema_visitor: - return sql.TableClause.get_children(self, column_collections=column_collections, **kwargs) + return expression.TableClause.get_children(self, column_collections=column_collections, **kwargs) else: if column_collections: return [c for c in self.columns] @@ -338,10 +341,10 @@ class Table(SchemaItem, sql.TableClause): args.append(c.copy()) return Table(self.name, metadata, schema=schema, *args) -class Column(SchemaItem, sql._ColumnClause): +class Column(SchemaItem, expression._ColumnClause): """Represent a column in a database table. - This is a subclass of ``sql.ColumnClause`` and represents an + This is a subclass of ``expression.ColumnClause`` and represents an actual existing table in the database, in a similar fashion as ``TableClause``/``Table``. """ @@ -575,7 +578,7 @@ class Column(SchemaItem, sql._ColumnClause): return [x for x in (self.default, self.onupdate) if x is not None] + \ list(self.foreign_keys) + list(self.constraints) else: - return sql._ColumnClause.get_children(self, **kwargs) + return expression._ColumnClause.get_children(self, **kwargs) class ForeignKey(SchemaItem): @@ -806,7 +809,7 @@ class Constraint(SchemaItem): def __init__(self, name=None): self.name = name - self.columns = sql.ColumnCollection() + self.columns = expression.ColumnCollection() def __contains__(self, x): return self.columns.contains_column(x) @@ -1124,12 +1127,12 @@ class MetaData(SchemaItem): del self.tables[table.key] def table_iterator(self, reverse=True, tables=None): - import sqlalchemy.sql_util + from sqlalchemy.sql import util as sql_util if tables is None: tables = self.tables.values() else: tables = util.Set(tables).intersection(self.tables.values()) - sorter = sqlalchemy.sql_util.TableCollection(list(tables)) + sorter = sql_util.TableCollection(list(tables)) return iter(sorter.sort(reverse=reverse)) def _get_parent(self): @@ -1356,7 +1359,7 @@ class ThreadLocalMetaData(MetaData): e.dispose() -class SchemaVisitor(sql.ClauseVisitor): +class SchemaVisitor(visitors.ClauseVisitor): """Define the visiting for ``SchemaItem`` objects.""" __traverse_options__ = {'schema_visitor':True} diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py new file mode 100644 index 000000000..06b9f1f75 --- /dev/null +++ b/lib/sqlalchemy/sql/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.sql.expression import * +from sqlalchemy.sql.visitors import ClauseVisitor, NoColumnVisitor + diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/sql/compiler.py index 5f5e1c171..6053c72be 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1,22 +1,18 @@ -# ansisql.py +# compiler.py # Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Defines ANSI SQL operations. +"""SQL expression compilation routines and DDL implementations.""" -Contains default implementations for the abstract objects in the sql -module. -""" +import string, re +from sqlalchemy import schema, engine, util, exceptions +from sqlalchemy.sql import operators, visitors +from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import expression as sql -import string, re, sets, operator - -from sqlalchemy import schema, sql, engine, util, exceptions, operators -from sqlalchemy.engine import default - - -ANSI_FUNCS = sets.ImmutableSet([ +ANSI_FUNCS = util.Set([ 'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP', 'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP', 'SESSION_USER', 'USER']) @@ -77,7 +73,6 @@ OPERATORS = { operators.comma_op : ', ', operators.desc_op : 'DESC', operators.asc_op : 'ASC', - operators.from_ : 'FROM', operators.as_ : 'AS', operators.exists : 'EXISTS', @@ -85,36 +80,10 @@ OPERATORS = { operators.isnot : 'IS NOT' } -class ANSIDialect(default.DefaultDialect): - def __init__(self, cache_identifiers=True, **kwargs): - super(ANSIDialect,self).__init__(**kwargs) - self.identifier_preparer = self.preparer() - self.cache_identifiers = cache_identifiers - - def create_connect_args(self): - return ([],{}) - - def schemagenerator(self, *args, **kwargs): - return ANSISchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return ANSISchemaDropper(self, *args, **kwargs) - - def compiler(self, statement, parameters, **kwargs): - return ANSICompiler(self, statement, parameters, **kwargs) - - def preparer(self): - """Return an IdentifierPreparer. - - This object is used to format table and column names including - proper quoting and case conventions. - """ - return ANSIIdentifierPreparer(self) - -class ANSICompiler(engine.Compiled, sql.ClauseVisitor): +class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): """Default implementation of Compiled. - Compiles ClauseElements into ANSI-compliant SQL strings. + Compiles ClauseElements into SQL strings. """ __traverse_options__ = {'column_collections':False, 'entry':True} @@ -122,7 +91,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): operators = OPERATORS def __init__(self, dialect, statement, parameters=None, **kwargs): - """Construct a new ``ANSICompiler`` object. + """Construct a new ``DefaultCompiler`` object. dialect Dialect to be used @@ -139,7 +108,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): correspond to the keys present in the parameters. """ - super(ANSICompiler, self).__init__(dialect, statement, parameters, **kwargs) + super(DefaultCompiler, self).__init__(dialect, statement, parameters, **kwargs) # if we are insert/update. set to true when we visit an INSERT or UPDATE self.isinsert = self.isupdate = False @@ -170,17 +139,17 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): self.bindtemplate = ":%s" # paramstyle from the dialect (comes from DBAPI) - self.paramstyle = dialect.paramstyle + self.paramstyle = self.dialect.paramstyle # true if the paramstyle is positional - self.positional = dialect.positional + self.positional = self.dialect.positional # a list of the compiled's bind parameter names, used to help # formulate a positional argument list self.positiontup = [] - # an ANSIIdentifierPreparer that formats the quoting of identifiers - self.preparer = dialect.identifier_preparer + # an IdentifierPreparer that formats the quoting of identifiers + self.preparer = self.dialect.identifier_preparer # for UPDATE and INSERT statements, a set of columns whos values are being set # from a SQL expression (i.e., not one of the bind parameter values). if present, @@ -244,7 +213,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): return None def construct_params(self, params): - """Return a sql.ClauseParameters object. + """Return a sql.util.ClauseParameters object. Combines the given bind parameter dictionary (string keys to object values) with the _BindParamClause objects stored within this Compiled object @@ -252,7 +221,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): for a single statement execution, or one element of an executemany execution. """ - d = sql.ClauseParameters(self.dialect, self.positiontup) + d = sql_util.ClauseParameters(self.dialect, self.positiontup) pd = self.parameters or {} pd.update(params) @@ -781,7 +750,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): def __str__(self): return self.string -class ANSISchemaBase(engine.SchemaIterator): +class DDLBase(engine.SchemaIterator): def find_alterables(self, tables): alterables = [] class FindAlterables(schema.SchemaVisitor): @@ -794,12 +763,12 @@ class ANSISchemaBase(engine.SchemaIterator): findalterables.traverse(c) return alterables -class ANSISchemaGenerator(ANSISchemaBase): +class SchemaGenerator(DDLBase): def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): - super(ANSISchemaGenerator, self).__init__(connection, **kwargs) + super(SchemaGenerator, self).__init__(connection, **kwargs) self.checkfirst = checkfirst self.tables = tables and util.Set(tables) or None - self.preparer = dialect.preparer() + self.preparer = dialect.identifier_preparer self.dialect = dialect def get_column_specification(self, column, first_pk=False): @@ -860,7 +829,7 @@ class ANSISchemaGenerator(ANSISchemaBase): def _compile(self, tocompile, parameters): """compile the given string/parameters using this SchemaGenerator's dialect.""" - compiler = self.dialect.compiler(tocompile, parameters) + compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters) compiler.compile() return compiler @@ -930,12 +899,12 @@ class ANSISchemaGenerator(ANSISchemaBase): string.join([preparer.format_column(c) for c in index.columns], ', '))) self.execute() -class ANSISchemaDropper(ANSISchemaBase): +class SchemaDropper(DDLBase): def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): - super(ANSISchemaDropper, self).__init__(connection, **kwargs) + super(SchemaDropper, self).__init__(connection, **kwargs) self.checkfirst = checkfirst self.tables = tables - self.preparer = dialect.preparer() + self.preparer = dialect.identifier_preparer self.dialect = dialect def visit_metadata(self, metadata): @@ -964,14 +933,11 @@ class ANSISchemaDropper(ANSISchemaBase): self.append("\nDROP TABLE " + self.preparer.format_table(table)) self.execute() -class ANSIDefaultRunner(engine.DefaultRunner): - pass - -class ANSIIdentifierPreparer(object): +class IdentifierPreparer(object): """Handle quoting and case-folding of identifiers based on options.""" def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False): - """Construct a new ``ANSIIdentifierPreparer`` object. + """Construct a new ``IdentifierPreparer`` object. initial_quote Character that begins a delimited identifier. @@ -1049,20 +1015,14 @@ class ANSIIdentifierPreparer(object): def __generic_obj_format(self, obj, ident): if getattr(obj, 'quote', False): return self.quote_identifier(ident) - if self.dialect.cache_identifiers: - try: - return self.__strings[ident] - except KeyError: - if self._requires_quotes(ident): - self.__strings[ident] = self.quote_identifier(ident) - else: - self.__strings[ident] = ident - return self.__strings[ident] - else: + try: + return self.__strings[ident] + except KeyError: if self._requires_quotes(ident): - return self.quote_identifier(ident) + self.__strings[ident] = self.quote_identifier(ident) else: - return ident + self.__strings[ident] = ident + return self.__strings[ident] def should_quote(self, object): return object.quote or self._requires_quotes(object.name) @@ -1152,5 +1112,3 @@ class ANSIIdentifierPreparer(object): return [self._unescape_identifier(i) for i in [a or b for a, b in r.findall(identifiers)]] - -dialect = ANSIDialect diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql/expression.py index 7c73f7cb7..e117e3f47 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql/expression.py @@ -26,13 +26,14 @@ classes usually have few or no public methods and are less guaranteed to stay the same in future releases. """ -from sqlalchemy import util, exceptions, operators +from sqlalchemy import util, exceptions +from sqlalchemy.sql import operators, visitors from sqlalchemy import types as sqltypes import re __all__ = [ - 'Alias', 'ClauseElement', 'ClauseParameters', - 'ClauseVisitor', 'ColumnCollection', 'ColumnElement', + 'Alias', 'ClauseElement', + 'ColumnCollection', 'ColumnElement', 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', 'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc', 'between', 'bindparam', 'case', 'cast', 'column', 'delete', @@ -810,187 +811,6 @@ def is_column(col): """True if ``col`` is an instance of ``ColumnElement``.""" return isinstance(col, ColumnElement) -class ClauseParameters(object): - """Represent a dictionary/iterator of bind parameter key names/values. - - Tracks the original [sqlalchemy.sql#_BindParamClause] objects - as well as the keys/position of each parameter, and can return - parameters as a dictionary or a list. Will process parameter - values according to the ``TypeEngine`` objects present in the - ``_BindParamClause`` instances. - """ - - def __init__(self, dialect, positional=None): - self.dialect = dialect - self.__binds = {} - self.positional = positional or [] - - def get_parameter(self, key): - return self.__binds[key] - - def set_parameter(self, bindparam, value, name): - self.__binds[name] = [bindparam, name, value] - - def get_original(self, key): - return self.__binds[key][2] - - def get_type(self, key): - return self.__binds[key][0].type - - def get_processors(self): - """return a dictionary of bind 'processing' functions""" - return dict([ - (key, value) for key, value in - [( - key, - self.__binds[key][0].bind_processor(self.dialect) - ) for key in self.__binds] - if value is not None - ]) - - def get_processed(self, key, processors): - return key in processors and processors[key](self.__binds[key][2]) or self.__binds[key][2] - - def keys(self): - return self.__binds.keys() - - def __iter__(self): - return iter(self.keys()) - - def __getitem__(self, key): - (bind, name, value) = self.__binds[key] - processor = bind.bind_processor(self.dialect) - return processor is not None and processor(value) or value - - def __contains__(self, key): - return key in self.__binds - - def set_value(self, key, value): - self.__binds[key][2] = value - - def get_original_dict(self): - return dict([(name, value) for (b, name, value) in self.__binds.values()]) - - def __get_processed(self, key, processors): - if key in processors: - return processors[key](self.__binds[key][2]) - else: - return self.__binds[key][2] - - def get_raw_list(self, processors): - return [self.__get_processed(key, processors) for key in self.positional] - - def get_raw_dict(self, processors, encode_keys=False): - if encode_keys: - return dict([ - ( - key.encode(self.dialect.encoding), - self.__get_processed(key, processors) - ) - for key in self.keys() - ]) - else: - return dict([ - ( - key, - self.__get_processed(key, processors) - ) - for key in self.keys() - ]) - - def __repr__(self): - return self.__class__.__name__ + ":" + repr(self.get_original_dict()) - -class ClauseVisitor(object): - """A class that knows how to traverse and visit ``ClauseElements``. - - Calls visit_XXX() methods dynamically generated for each - particualr ``ClauseElement`` subclass encountered. Traversal of a - hierarchy of ``ClauseElements`` is achieved via the ``traverse()`` - method, which is passed the lead ``ClauseElement``. - - By default, ``ClauseVisitor`` traverses all elements fully. - Options can be specified at the class level via the - ``__traverse_options__`` dictionary which will be passed to the - ``get_children()`` method of each ``ClauseElement``; these options - can indicate modifications to the set of elements returned, such - as to not return column collections (column_collections=False) or - to return Schema-level items (schema_visitor=True). - - ``ClauseVisitor`` also supports a simultaneous copy-and-traverse - operation, which will produce a copy of a given ``ClauseElement`` - structure while at the same time allowing ``ClauseVisitor`` - subclasses to modify the new structure in-place. - """ - - __traverse_options__ = {} - - def traverse_single(self, obj, **kwargs): - meth = getattr(self, "visit_%s" % obj.__visit_name__, None) - if meth: - return meth(obj, **kwargs) - - def iterate(self, obj, stop_on=None): - stack = [obj] - traversal = [] - while len(stack) > 0: - t = stack.pop() - if stop_on is None or t not in stop_on: - yield t - traversal.insert(0, t) - for c in t.get_children(**self.__traverse_options__): - stack.append(c) - - def traverse(self, obj, stop_on=None, clone=False): - if clone: - obj = obj._clone() - - stack = [obj] - traversal = [] - while len(stack) > 0: - t = stack.pop() - if stop_on is None or t not in stop_on: - traversal.insert(0, t) - if clone: - t._copy_internals() - for c in t.get_children(**self.__traverse_options__): - stack.append(c) - for target in traversal: - v = self - while v is not None: - meth = getattr(v, "visit_%s" % target.__visit_name__, None) - if meth: - meth(target) - v = getattr(v, '_next', None) - return obj - - def chain(self, visitor): - """'chain' an additional ClauseVisitor onto this ClauseVisitor. - - The chained visitor will receive all visit events after this one. - """ - - tail = self - while getattr(tail, '_next', None) is not None: - tail = tail._next - tail._next = visitor - return self - -class NoColumnVisitor(ClauseVisitor): - """A ClauseVisitor that will not traverse exported column collections. - - Will not traverse the exported Column collections on Table, Alias, - Select, and CompoundSelect objects (i.e. their 'columns' or 'c' - attribute). - - This is useful because most traversals don't need those columns, - or in the case of ANSICompiler it traverses them explicitly; so - skipping their traversal here greatly cuts down on method call - overhead. - """ - - __traverse_options__ = {'column_collections': False} - class _FigureVisitName(type): def __init__(cls, clsname, bases, dict): @@ -1061,7 +881,7 @@ class ClauseElement(object): elif len(optionaldict) > 1: raise exceptions.ArgumentError("params() takes zero or one positional dictionary argument") - class Vis(ClauseVisitor): + class Vis(visitors.ClauseVisitor): def visit_bindparam(self, bind): if bind.key in kwargs: bind.value = kwargs[bind.key] @@ -1156,7 +976,7 @@ class ClauseElement(object): if any. Finally, if there is no bound ``Engine``, uses an - ``ANSIDialect`` to create a default ``Compiler``. + ``DefaultDialect`` to create a default ``Compiler``. `parameters` is a dictionary representing the default bind parameters to be used with the statement. If `parameters` is @@ -1175,15 +995,16 @@ class ClauseElement(object): if compiler is None: if dialect is not None: - compiler = dialect.compiler(self, parameters) + compiler = dialect.statement_compiler(dialect, self, parameters) elif bind is not None: - compiler = bind.compiler(self, parameters) + compiler = bind.statement_compiler(self, parameters) elif self.bind is not None: - compiler = self.bind.compiler(self, parameters) + compiler = self.bind.statement_compiler(self, parameters) if compiler is None: - import sqlalchemy.ansisql as ansisql - compiler = ansisql.ANSIDialect().compiler(self, parameters=parameters) + from sqlalchemy.engine.default import DefaultDialect + dialect = DefaultDialect() + compiler = dialect.statement_compiler(dialect, self, parameters=parameters) compiler.compile() return compiler @@ -1727,7 +1548,7 @@ class FromClause(Selectable): def _get_all_embedded_columns(self): ret = [] - class FindCols(ClauseVisitor): + class FindCols(visitors.ClauseVisitor): def visit_column(self, col): ret.append(col) FindCols().traverse(self) @@ -1744,8 +1565,8 @@ class FromClause(Selectable): def replace_selectable(self, old, alias): """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``.""" - from sqlalchemy import sql_util - return sql_util.ClauseAdapter(alias).traverse(self, clone=True) + from sqlalchemy.sql import util + return util.ClauseAdapter(alias).traverse(self, clone=True) def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False): """Given a ``ColumnElement``, return the exported ``ColumnElement`` @@ -2376,7 +2197,7 @@ class Join(FromClause): else: equivs[x] = util.Set([y]) - class BinaryVisitor(ClauseVisitor): + class BinaryVisitor(visitors.ClauseVisitor): def visit_binary(self, binary): if binary.operator == operators.eq: add_equiv(binary.left, binary.right) @@ -2460,7 +2281,7 @@ class Join(FromClause): return self.__folded_equivalents if equivs is None: equivs = util.Set() - class LocateEquivs(NoColumnVisitor): + class LocateEquivs(visitors.NoColumnVisitor): def visit_binary(self, binary): if binary.operator == operators.eq and binary.left.name == binary.right.name: equivs.add(binary.right) @@ -3331,7 +3152,7 @@ class Select(_SelectBaseMixin, FromClause): return intersect_all(self, other, **kwargs) def _table_iterator(self): - for t in NoColumnVisitor().iterate(self): + for t in visitors.NoColumnVisitor().iterate(self): if isinstance(t, TableClause): yield t diff --git a/lib/sqlalchemy/operators.py b/lib/sqlalchemy/sql/operators.py index b8aca3d26..b8aca3d26 100644 --- a/lib/sqlalchemy/operators.py +++ b/lib/sqlalchemy/sql/operators.py diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql/util.py index cc6325822..2c7294e66 100644 --- a/lib/sqlalchemy/sql_util.py +++ b/lib/sqlalchemy/sql/util.py @@ -1,7 +1,100 @@ -from sqlalchemy import sql, util, schema, topological +from sqlalchemy import util, schema, topological +from sqlalchemy.sql import expression, visitors """Utility functions that build upon SQL and Schema constructs.""" +class ClauseParameters(object): + """Represent a dictionary/iterator of bind parameter key names/values. + + Tracks the original [sqlalchemy.sql#_BindParamClause] objects as well as the + keys/position of each parameter, and can return parameters as a + dictionary or a list. Will process parameter values according to + the ``TypeEngine`` objects present in the ``_BindParamClause`` instances. + """ + + def __init__(self, dialect, positional=None): + self.dialect = dialect + self.__binds = {} + self.positional = positional or [] + + def get_parameter(self, key): + return self.__binds[key] + + def set_parameter(self, bindparam, value, name): + self.__binds[name] = [bindparam, name, value] + + def get_original(self, key): + return self.__binds[key][2] + + def get_type(self, key): + return self.__binds[key][0].type + + def get_processors(self): + """return a dictionary of bind 'processing' functions""" + return dict([ + (key, value) for key, value in + [( + key, + self.__binds[key][0].bind_processor(self.dialect) + ) for key in self.__binds] + if value is not None + ]) + + def get_processed(self, key, processors): + return key in processors and processors[key](self.__binds[key][2]) or self.__binds[key][2] + + def keys(self): + return self.__binds.keys() + + def __iter__(self): + return iter(self.keys()) + + def __getitem__(self, key): + (bind, name, value) = self.__binds[key] + processor = bind.bind_processor(self.dialect) + return processor is not None and processor(value) or value + + def __contains__(self, key): + return key in self.__binds + + def set_value(self, key, value): + self.__binds[key][2] = value + + def get_original_dict(self): + return dict([(name, value) for (b, name, value) in self.__binds.values()]) + + def __get_processed(self, key, processors): + if key in processors: + return processors[key](self.__binds[key][2]) + else: + return self.__binds[key][2] + + def get_raw_list(self, processors): + return [self.__get_processed(key, processors) for key in self.positional] + + def get_raw_dict(self, processors, encode_keys=False): + if encode_keys: + return dict([ + ( + key.encode(self.dialect.encoding), + self.__get_processed(key, processors) + ) + for key in self.keys() + ]) + else: + return dict([ + ( + key, + self.__get_processed(key, processors) + ) + for key in self.keys() + ]) + + def __repr__(self): + return self.__class__.__name__ + ":" + repr(self.get_original_dict()) + + + class TableCollection(object): def __init__(self, tables=None): self.tables = tables or [] @@ -64,7 +157,7 @@ class TableCollection(object): return sequence -class TableFinder(TableCollection, sql.NoColumnVisitor): +class TableFinder(TableCollection, visitors.NoColumnVisitor): """locate all Tables within a clause.""" def __init__(self, clause, check_columns=False, include_aliases=False): @@ -85,7 +178,7 @@ class TableFinder(TableCollection, sql.NoColumnVisitor): if self.check_columns: self.tables.append(column.table) -class ColumnFinder(sql.ClauseVisitor): +class ColumnFinder(visitors.ClauseVisitor): def __init__(self): self.columns = util.Set() @@ -95,7 +188,7 @@ class ColumnFinder(sql.ClauseVisitor): def __iter__(self): return iter(self.columns) -class ColumnsInClause(sql.ClauseVisitor): +class ColumnsInClause(visitors.ClauseVisitor): """Given a selectable, visit clauses and determine if any columns from the clause are in the selectable. """ @@ -108,7 +201,7 @@ class ColumnsInClause(sql.ClauseVisitor): if self.selectable.c.get(column.key) is column: self.result = True -class AbstractClauseProcessor(sql.NoColumnVisitor): +class AbstractClauseProcessor(visitors.NoColumnVisitor): """Traverse a clause and attempt to convert the contents of container elements to a converted element. @@ -224,10 +317,10 @@ class ClauseAdapter(AbstractClauseProcessor): self.equivalents = equivalents def convert_element(self, col): - if isinstance(col, sql.FromClause): + if isinstance(col, expression.FromClause): if self.selectable.is_derived_from(col): return self.selectable - if not isinstance(col, sql.ColumnElement): + if not isinstance(col, expression.ColumnElement): return None if self.include is not None: if col not in self.include: diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py new file mode 100644 index 000000000..98e4de6c3 --- /dev/null +++ b/lib/sqlalchemy/sql/visitors.py @@ -0,0 +1,87 @@ +class ClauseVisitor(object): + """A class that knows how to traverse and visit + ``ClauseElements``. + + Calls visit_XXX() methods dynamically generated for each particualr + ``ClauseElement`` subclass encountered. Traversal of a + hierarchy of ``ClauseElements`` is achieved via the + ``traverse()`` method, which is passed the lead + ``ClauseElement``. + + By default, ``ClauseVisitor`` traverses all elements + fully. Options can be specified at the class level via the + ``__traverse_options__`` dictionary which will be passed + to the ``get_children()`` method of each ``ClauseElement``; + these options can indicate modifications to the set of + elements returned, such as to not return column collections + (column_collections=False) or to return Schema-level items + (schema_visitor=True). + + ``ClauseVisitor`` also supports a simultaneous copy-and-traverse + operation, which will produce a copy of a given ``ClauseElement`` + structure while at the same time allowing ``ClauseVisitor`` subclasses + to modify the new structure in-place. + + """ + __traverse_options__ = {} + + def traverse_single(self, obj, **kwargs): + meth = getattr(self, "visit_%s" % obj.__visit_name__, None) + if meth: + return meth(obj, **kwargs) + + def iterate(self, obj, stop_on=None): + stack = [obj] + traversal = [] + while len(stack) > 0: + t = stack.pop() + if stop_on is None or t not in stop_on: + yield t + traversal.insert(0, t) + for c in t.get_children(**self.__traverse_options__): + stack.append(c) + + def traverse(self, obj, stop_on=None, clone=False): + if clone: + obj = obj._clone() + + stack = [obj] + traversal = [] + while len(stack) > 0: + t = stack.pop() + if stop_on is None or t not in stop_on: + traversal.insert(0, t) + if clone: + t._copy_internals() + for c in t.get_children(**self.__traverse_options__): + stack.append(c) + for target in traversal: + v = self + while v is not None: + meth = getattr(v, "visit_%s" % target.__visit_name__, None) + if meth: + meth(target) + v = getattr(v, '_next', None) + return obj + + def chain(self, visitor): + """'chain' an additional ClauseVisitor onto this ClauseVisitor. + + the chained visitor will receive all visit events after this one.""" + tail = self + while getattr(tail, '_next', None) is not None: + tail = tail._next + tail._next = visitor + return self + +class NoColumnVisitor(ClauseVisitor): + """a ClauseVisitor that will not traverse the exported Column + collections on Table, Alias, Select, and CompoundSelect objects + (i.e. their 'columns' or 'c' attribute). + + this is useful because most traversals don't need those columns, or + in the case of DefaultCompiler it traverses them explicitly; so + skipping their traversal here greatly cuts down on method call overhead. + """ + + __traverse_options__ = {'column_collections':False} diff --git a/test/dialect/mysql.py b/test/dialect/mysql.py index 46e0a7137..294842854 100644 --- a/test/dialect/mysql.py +++ b/test/dialect/mysql.py @@ -154,7 +154,7 @@ class TypesTest(AssertMixin): table_args.append(Column('c%s' % index, type_(*args, **kw))) numeric_table = Table(*table_args) - gen = testbase.db.dialect.schemagenerator(testbase.db, None, None) + gen = testbase.db.dialect.schemagenerator(testbase.db.dialect, testbase.db, None, None) for col in numeric_table.c: index = int(col.name[1:]) @@ -238,7 +238,7 @@ class TypesTest(AssertMixin): table_args.append(Column('c%s' % index, type_(*args, **kw))) charset_table = Table(*table_args) - gen = testbase.db.dialect.schemagenerator(testbase.db, None, None) + gen = testbase.db.dialect.schemagenerator(testbase.db.dialect, testbase.db, None, None) for col in charset_table.c: index = int(col.name[1:]) @@ -707,7 +707,7 @@ class SQLTest(AssertMixin): def colspec(c): - return testbase.db.dialect.schemagenerator( + return testbase.db.dialect.schemagenerator(testbase.db.dialect, testbase.db, None, None).get_column_specification(c) if __name__ == "__main__": diff --git a/test/engine/reflection.py b/test/engine/reflection.py index da6f75149..2345a328a 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -2,7 +2,6 @@ import testbase import pickle, StringIO, unicodedata from sqlalchemy import * -import sqlalchemy.ansisql as ansisql from sqlalchemy.exceptions import NoSuchTableError from testlib import * from testlib import engines @@ -686,7 +685,7 @@ class SchemaTest(PersistTest): def foo(s, p=None): buf.write(s) gen = create_engine(testbase.db.name + "://", strategy="mock", executor=foo) - gen = gen.dialect.schemagenerator(gen) + gen = gen.dialect.schemagenerator(gen.dialect, gen) gen.traverse(table1) gen.traverse(table2) buf = buf.getvalue() diff --git a/test/orm/dynamic.py b/test/orm/dynamic.py index 0c824d372..2e84c83d1 100644 --- a/test/orm/dynamic.py +++ b/test/orm/dynamic.py @@ -1,7 +1,6 @@ import testbase import operator from sqlalchemy import * -from sqlalchemy import ansisql from sqlalchemy.orm import * from testlib import * from testlib.fixtures import * diff --git a/test/orm/query.py b/test/orm/query.py index e0b8bf4f3..e3f6ed42c 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -1,7 +1,8 @@ import testbase import operator +from sqlalchemy.sql import compiler from sqlalchemy import * -from sqlalchemy import ansisql +from sqlalchemy.engine import default from sqlalchemy.orm import * from testlib import * from testlib.fixtures import * @@ -141,7 +142,7 @@ class OperatorTest(QueryTest): """test sql.Comparator implementation for MapperProperties""" def _test(self, clause, expected): - c = str(clause.compile(dialect=ansisql.ANSIDialect())) + c = str(clause.compile(dialect = default.DefaultDialect())) assert c == expected, "%s != %s" % (c, expected) def test_arithmetic(self): @@ -182,7 +183,7 @@ class OperatorTest(QueryTest): # the compiled clause should match either (e.g.): # 'a' < 'b' -or- 'b' > 'a'. - compiled = str(py_op(lhs, rhs).compile(dialect=ansisql.ANSIDialect())) + compiled = str(py_op(lhs, rhs).compile(dialect=default.DefaultDialect())) fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql) rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql) @@ -201,7 +202,7 @@ class OperatorTest(QueryTest): # this one would require adding compile() to InstrumentedScalarAttribute. do we want this ? #(User.id, "users.id") ): - c = expr.compile(dialect=ansisql.ANSIDialect()) + c = expr.compile(dialect=default.DefaultDialect()) assert str(c) == compare, "%s != %s" % (str(c), compare) diff --git a/test/sql/constraints.py b/test/sql/constraints.py index 3120185d5..a8b642b9b 100644 --- a/test/sql/constraints.py +++ b/test/sql/constraints.py @@ -179,7 +179,7 @@ class ConstraintTest(AssertMixin): capt.append(repr(context.parameters)) ex(context) connection._Connection__execute = proxy - schemagen = testbase.db.dialect.schemagenerator(connection) + schemagen = testbase.db.dialect.schemagenerator(testbase.db.dialect, connection) schemagen.traverse(events) assert capt[0].strip().startswith('CREATE TABLE events') diff --git a/test/sql/generative.py b/test/sql/generative.py index d79af577f..9bd97b305 100644 --- a/test/sql/generative.py +++ b/test/sql/generative.py @@ -1,6 +1,7 @@ import testbase from sqlalchemy import * from testlib import * +from sqlalchemy.sql.visitors import * class TraversalTest(AssertMixin): """test ClauseVisitor's traversal, particularly its ability to copy and modify @@ -213,7 +214,7 @@ class ClauseTest(SQLCompileTest): self.assert_compile(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :table1_col2") def test_clause_adapter(self): - from sqlalchemy import sql_util + from sqlalchemy.sql import util as sql_util t1alias = t1.alias('t1alias') diff --git a/test/sql/labels.py b/test/sql/labels.py index 553a3a3bc..dee76428d 100644 --- a/test/sql/labels.py +++ b/test/sql/labels.py @@ -78,7 +78,7 @@ class LongLabelsTest(PersistTest): # this is the test that fails if the "max identifier length" is shorter than the # length of the actual columns created, because the column names get truncated. # if you try to separate "physical columns" from "labels", and only truncate the labels, - # the ansisql.visit_select() logic which auto-labels columns in a subquery (for the purposes of sqlite compat) breaks the code, + # the compiler.DefaultCompiler.visit_select() logic which auto-labels columns in a subquery (for the purposes of sqlite compat) breaks the code, # since it is creating "labels" on the fly but not affecting derived columns, which think they are # still "physical" q = table1.select(table1.c.this_is_the_primarykey_column == 4).alias('foo') diff --git a/test/sql/query.py b/test/sql/query.py index 4f569f1c0..8ec3190b4 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -1,7 +1,8 @@ import testbase import datetime from sqlalchemy import * -from sqlalchemy import exceptions, ansisql +from sqlalchemy import exceptions +from sqlalchemy.engine import default from testlib import * @@ -166,14 +167,14 @@ class QueryTest(PersistTest): assert len(r) == 1 def test_bindparam_detection(self): - dialect = ansisql.ANSIDialect(default_paramstyle='qmark') - prep = lambda q: dialect.compile(sql.text(q)).string + dialect = default.DefaultDialect(default_paramstyle='qmark') + prep = lambda q: str(sql.text(q).compile(dialect=dialect)) def a_eq(got, wanted): if got != wanted: print "Wanted %s" % wanted print "Received %s" % got - self.assert_(got == wanted) + self.assert_(got == wanted, got) a_eq(prep('select foo'), 'select foo') a_eq(prep("time='12:30:00'"), "time='12:30:00'") diff --git a/test/sql/quote.py b/test/sql/quote.py index ad25619df..0c414af3a 100644 --- a/test/sql/quote.py +++ b/test/sql/quote.py @@ -1,7 +1,7 @@ import testbase from sqlalchemy import * from testlib import * - +from sqlalchemy.sql import compiler class QuoteTest(PersistTest): def setUpAll(self): @@ -98,10 +98,10 @@ class QuoteTest(PersistTest): class PreparerTest(PersistTest): - """Test the db-agnostic quoting services of ANSIIdentifierPreparer.""" + """Test the db-agnostic quoting services of IdentifierPreparer.""" def test_unformat(self): - prep = ansisql.ANSIIdentifierPreparer(None) + prep = compiler.IdentifierPreparer(None) unformat = prep.unformat_identifiers def a_eq(have, want): @@ -120,7 +120,7 @@ class PreparerTest(PersistTest): a_eq(unformat('"foo"."b""a""r"."baz"'), ['foo', 'b"a"r', 'baz']) def test_unformat_custom(self): - class Custom(ansisql.ANSIIdentifierPreparer): + class Custom(compiler.IdentifierPreparer): def __init__(self, dialect): super(Custom, self).__init__(dialect, initial_quote='`', final_quote='`') diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index c497fbcbd..4ffb4c591 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -195,7 +195,7 @@ class ColumnsTest(AssertMixin): ) for aCol in testTable.c: - self.assertEquals(expectedResults[aCol.name], db.dialect.schemagenerator(db, None, None).get_column_specification(aCol)) + self.assertEquals(expectedResults[aCol.name], db.dialect.schemagenerator(db.dialect, db, None, None).get_column_specification(aCol)) class UnicodeTest(AssertMixin): """tests the Unicode type. also tests the TypeDecorator with instances in the types package.""" diff --git a/test/testlib/testing.py b/test/testlib/testing.py index 6830fb63c..58fd7c0d1 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -161,7 +161,8 @@ class ExecutionContextWrapper(object): if params is not None and isinstance(params, list) and len(params) == 1: params = params[0] - if isinstance(ctx.compiled_parameters, sql.ClauseParameters): + from sqlalchemy.sql.util import ClauseParameters + if isinstance(ctx.compiled_parameters, ClauseParameters): parameters = ctx.compiled_parameters.get_original_dict() elif isinstance(ctx.compiled_parameters, list): parameters = [p.get_original_dict() for p in ctx.compiled_parameters] |