diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-08-18 21:37:48 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-08-18 21:37:48 +0000 |
commit | 7c6c1b99c2de00829b6f34ffba7e3bb689d34198 (patch) | |
tree | acd6f8dc84cea86fc58b195a5f1068cbe020e955 | |
parent | 534cf5fdbd05e2049ab9feceabf3926a5ab6380c (diff) | |
download | sqlalchemy-7c6c1b99c2de00829b6f34ffba7e3bb689d34198.tar.gz |
1. Module layout. sql.py and related move into a package called "sql".
2. compiler names changed to be less verbose, unused classes removed.
3. Methods on Dialect which return compilers, schema generators, identifier preparers
have changed to direct class references, typically on the Dialect class itself
or optionally as attributes on an individual Dialect instance if conditional behavior is needed.
This takes away the need for Dialect subclasses to know how to instantiate these
objects, and also reduces method overhead by one call for each one.
4. as a result of 3., some internal signatures have changed for things like compiler() (now statement_compiler()), preparer(), etc., mostly in that the dialect needs to be passed explicitly as the first argument (since they are just class references now). The compiler() method on Engine and Connection is now also named statement_compiler(), but as before does not take the dialect as an argument.
5. changed _process_row function on RowProxy to be a class reference, cuts out 50K method calls from insertspeed.py
36 files changed, 517 insertions, 607 deletions
diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py index 6bf8b96e9..4aa773239 100644 --- a/lib/sqlalchemy/databases/access.py +++ b/lib/sqlalchemy/databases/access.py @@ -347,7 +347,7 @@ class AccessDialect(ansisql.ANSIDialect): return names -class AccessCompiler(ansisql.ANSICompiler): +class AccessCompiler(compiler.DefaultCompiler): def visit_select_precolumns(self, select): """Access puts TOP, it's version of LIMIT here """ s = select.distinct and "DISTINCT " or "" @@ -387,7 +387,7 @@ class AccessCompiler(ansisql.ANSICompiler): return '' -class AccessSchemaGenerator(ansisql.ANSISchemaGenerator): +class AccessSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() @@ -410,7 +410,7 @@ class AccessSchemaGenerator(ansisql.ANSISchemaGenerator): return colspec -class AccessSchemaDropper(ansisql.ANSISchemaDropper): +class AccessSchemaDropper(compiler.SchemaDropper): def visit_index(self, index): self.append("\nDROP INDEX [%s].[%s]" % (index.table.name, index.name)) self.execute() @@ -418,7 +418,7 @@ class AccessSchemaDropper(ansisql.ANSISchemaDropper): class AccessDefaultRunner(ansisql.ANSIDefaultRunner): pass -class AccessIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class AccessIdentifierPreparer(compiler.IdentifierPreparer): def __init__(self, dialect): super(AccessIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 307fceb48..9cccb53e8 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -7,9 +7,10 @@ import warnings -from sqlalchemy import util, sql, schema, ansisql, exceptions -import sqlalchemy.engine.default as default -import sqlalchemy.types as sqltypes +from sqlalchemy import util, sql, schema, exceptions +from sqlalchemy.sql import compiler +from sqlalchemy.engine import default, base +from sqlalchemy import types as sqltypes _initialized_kb = False @@ -99,9 +100,9 @@ class FBExecutionContext(default.DefaultExecutionContext): return True -class FBDialect(ansisql.ANSIDialect): +class FBDialect(default.DefaultDialect): def __init__(self, type_conv=200, concurrency_level=1, **kwargs): - ansisql.ANSIDialect.__init__(self, **kwargs) + default.DefaultDialect.__init__(self, **kwargs) self.type_conv = type_conv self.concurrency_level= concurrency_level @@ -135,21 +136,6 @@ class FBDialect(ansisql.ANSIDialect): def supports_sane_rowcount(self): return False - def compiler(self, statement, bindparams, **kwargs): - return FBCompiler(self, statement, bindparams, **kwargs) - - def schemagenerator(self, *args, **kwargs): - return FBSchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return FBSchemaDropper(self, *args, **kwargs) - - def defaultrunner(self, connection): - return FBDefaultRunner(connection) - - def preparer(self): - return FBIdentifierPreparer(self) - def max_identifier_length(self): return 31 @@ -307,7 +293,7 @@ class FBDialect(ansisql.ANSIDialect): connection.commit(True) -class FBCompiler(ansisql.ANSICompiler): +class FBCompiler(compiler.DefaultCompiler): """Firebird specific idiosincrasies""" def visit_alias(self, alias, asfrom=False, **kwargs): @@ -346,7 +332,7 @@ class FBCompiler(ansisql.ANSICompiler): return "" -class FBSchemaGenerator(ansisql.ANSISchemaGenerator): +class FBSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() @@ -365,13 +351,13 @@ class FBSchemaGenerator(ansisql.ANSISchemaGenerator): self.execute() -class FBSchemaDropper(ansisql.ANSISchemaDropper): +class FBSchemaDropper(compiler.SchemaDropper): def visit_sequence(self, sequence): self.append("DROP GENERATOR %s" % sequence.name) self.execute() -class FBDefaultRunner(ansisql.ANSIDefaultRunner): +class FBDefaultRunner(base.DefaultRunner): def exec_default_sql(self, default): c = sql.select([default.arg], from_obj=["rdb$database"]).compile(bind=self.connection) return self.connection.execute_compiled(c).scalar() @@ -421,7 +407,7 @@ RESERVED_WORDS = util.Set( "whenever", "where", "while", "with", "work", "write", "year", "yearday" ]) -class FBIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class FBIdentifierPreparer(compiler.IdentifierPreparer): def __init__(self, dialect): super(FBIdentifierPreparer,self).__init__(dialect, omit_schema=True) @@ -430,3 +416,9 @@ class FBIdentifierPreparer(ansisql.ANSIIdentifierPreparer): dialect = FBDialect +dialect.statement_compiler = FBCompiler +dialect.schemagenerator = FBSchemaGenerator +dialect.schemadropper = FBSchemaDropper +dialect.defaultrunner = FBDefaultRunner +dialect.preparer = FBIdentifierPreparer + diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py index 21ecf1538..ceb56903a 100644 --- a/lib/sqlalchemy/databases/informix.py +++ b/lib/sqlalchemy/databases/informix.py @@ -7,9 +7,10 @@ import datetime, warnings -from sqlalchemy import sql, schema, ansisql, exceptions, pool -import sqlalchemy.engine.default as default -import sqlalchemy.types as sqltypes +from sqlalchemy import sql, schema, exceptions, pool +from sqlalchemy.sql import compiler +from sqlalchemy.engine import default +from sqlalchemy import types as sqltypes # for offset @@ -203,11 +204,11 @@ class InfoExecutionContext(default.DefaultExecutionContext): def create_cursor( self ): return informix_cursor( self.connection.connection ) -class InfoDialect(ansisql.ANSIDialect): +class InfoDialect(default.DefaultDialect): def __init__(self, use_ansi=True,**kwargs): self.use_ansi = use_ansi - ansisql.ANSIDialect.__init__(self, **kwargs) + default.DefaultDialect.__init__(self, **kwargs) self.paramstyle = 'qmark' def dbapi(cls): @@ -252,18 +253,6 @@ class InfoDialect(ansisql.ANSIDialect): def oid_column_name(self,column): return "rowid" - def preparer(self): - return InfoIdentifierPreparer(self) - - def compiler(self, statement, bindparams, **kwargs): - return InfoCompiler(self, statement, bindparams, **kwargs) - - def schemagenerator(self, *args, **kwargs): - return InfoSchemaGenerator( self , *args, **kwargs) - - def schemadropper(self, *args, **params): - return InfoSchemaDroper( self , *args , **params) - def table_names(self, connection, schema): s = "select tabname from systables" return [row[0] for row in connection.execute(s)] @@ -376,14 +365,14 @@ class InfoDialect(ansisql.ANSIDialect): for cons_name, cons_type, local_column in rows: table.primary_key.add( table.c[local_column] ) -class InfoCompiler(ansisql.ANSICompiler): +class InfoCompiler(compiler.DefaultCompiler): """Info compiler modifies the lexical structure of Select statements to work under non-ANSI configured Oracle databases, if the use_ansi flag is False.""" def __init__(self, dialect, statement, parameters=None, **kwargs): self.limit = 0 self.offset = 0 - ansisql.ANSICompiler.__init__( self , dialect , statement , parameters , **kwargs ) + compiler.DefaultCompiler.__init__( self , dialect , statement , parameters , **kwargs ) def default_from(self): return " from systables where tabname = 'systables' " @@ -416,7 +405,7 @@ class InfoCompiler(ansisql.ANSICompiler): if ( __label(c) not in a ) and getattr( c , 'name' , '' ) != 'oid': select.append_column( c ) - return ansisql.ANSICompiler.visit_select(self, select) + return compiler.DefaultCompiler.visit_select(self, select) def limit_clause(self, select): return "" @@ -437,7 +426,7 @@ class InfoCompiler(ansisql.ANSICompiler): elif func.name.lower() in ( 'current_timestamp' , 'now' ): return "CURRENT YEAR TO SECOND" else: - return ansisql.ANSICompiler.visit_function( self , func ) + return compiler.DefaultCompiler.visit_function( self , func ) def visit_clauselist(self, list): try: @@ -446,7 +435,7 @@ class InfoCompiler(ansisql.ANSICompiler): li = [ c for c in list.clauses ] return ', '.join([s for s in [self.process(c) for c in li] if s is not None]) -class InfoSchemaGenerator(ansisql.ANSISchemaGenerator): +class InfoSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, first_pk=False): colspec = self.preparer.format_column(column) if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \ @@ -507,7 +496,7 @@ class InfoSchemaGenerator(ansisql.ANSISchemaGenerator): return super(InfoSchemaGenerator, self).visit_index(index) -class InfoIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class InfoIdentifierPreparer(compiler.IdentifierPreparer): def __init__(self, dialect): super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'") @@ -517,10 +506,14 @@ class InfoIdentifierPreparer(ansisql.ANSIIdentifierPreparer): def _requires_quotes(self, value): return False -class InfoSchemaDroper(ansisql.ANSISchemaDropper): +class InfoSchemaDroper(compiler.SchemaDropper): def drop_foreignkey(self, constraint): if constraint.name is not None: super( InfoSchemaDroper , self ).drop_foreignkey( constraint ) dialect = InfoDialect poolclass = pool.SingletonThreadPool +dialect.statement_compiler = InfoCompiler +dialect.schemagenerator = InfoSchemaGenerator +dialect.schemadropper = InfoSchemaDropper +dialect.preparer = InfoIdentifierPreparer diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 619e072d9..0caccca95 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -39,10 +39,10 @@ Known issues / TODO: import datetime, random, warnings, re -from sqlalchemy import sql, schema, ansisql, exceptions -import sqlalchemy.types as sqltypes -from sqlalchemy.engine import default -import operator, sys +from sqlalchemy import util, sql, schema, exceptions +from sqlalchemy.sql import compiler, expression +from sqlalchemy.engine import default, base +from sqlalchemy import types as sqltypes class MSNumeric(sqltypes.Numeric): def result_processor(self, dialect): @@ -366,7 +366,7 @@ class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext): super(MSSQLExecutionContext_pyodbc, self).post_exec() -class MSSQLDialect(ansisql.ANSIDialect): +class MSSQLDialect(default.DefaultDialect): colspecs = { sqltypes.Unicode : MSNVarchar, sqltypes.Integer : MSInteger, @@ -476,21 +476,6 @@ class MSSQLDialect(ansisql.ANSIDialect): def supports_sane_rowcount(self): raise NotImplementedError() - def compiler(self, statement, bindparams, **kwargs): - return MSSQLCompiler(self, statement, bindparams, **kwargs) - - def schemagenerator(self, *args, **kwargs): - return MSSQLSchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return MSSQLSchemaDropper(self, *args, **kwargs) - - def defaultrunner(self, connection, **kwargs): - return MSSQLDefaultRunner(connection, **kwargs) - - def preparer(self): - return MSSQLIdentifierPreparer(self) - def get_default_schema_name(self, connection): return self.schema_name @@ -878,7 +863,7 @@ dialect_mapping = { } -class MSSQLCompiler(ansisql.ANSICompiler): +class MSSQLCompiler(compiler.DefaultCompiler): def __init__(self, dialect, statement, parameters, **kwargs): super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs) self.tablealiases = {} @@ -931,13 +916,13 @@ class MSSQLCompiler(ansisql.ANSICompiler): def visit_binary(self, binary): """Move bind parameters to the right-hand side of an operator, where possible.""" - if isinstance(binary.left, sql._BindParamClause) and binary.operator == operator.eq: - return self.process(sql._BinaryExpression(binary.right, binary.left, binary.operator)) + if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq: + return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator)) else: return super(MSSQLCompiler, self).visit_binary(binary) def label_select_column(self, select, column): - if isinstance(column, sql._Function): + if isinstance(column, expression._Function): return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:]) else: return super(MSSQLCompiler, self).label_select_column(select, column) @@ -963,7 +948,7 @@ class MSSQLCompiler(ansisql.ANSICompiler): return "" -class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): +class MSSQLSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() @@ -986,7 +971,7 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): return colspec -class MSSQLSchemaDropper(ansisql.ANSISchemaDropper): +class MSSQLSchemaDropper(compiler.SchemaDropper): def visit_index(self, index): self.append("\nDROP INDEX %s.%s" % ( self.preparer.quote_identifier(index.table.name), @@ -995,11 +980,11 @@ class MSSQLSchemaDropper(ansisql.ANSISchemaDropper): self.execute() -class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner): +class MSSQLDefaultRunner(base.DefaultRunner): # TODO: does ms-sql have standalone sequences ? pass -class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class MSSQLIdentifierPreparer(compiler.IdentifierPreparer): def __init__(self, dialect): super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') @@ -1012,6 +997,11 @@ class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): return value dialect = MSSQLDialect +dialect.statement_compiler = MSSQLCompiler +dialect.schemagenerator = MSSQLSchemaGenerator +dialect.schemadropper = MSSQLSchemaDropper +dialect.preparer = MSSQLIdentifierPreparer +dialect.defaultrunner = MSSQLDefaultRunner diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index d5fd3b6c5..41c6ec70f 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -126,11 +126,12 @@ information affecting MySQL in SQLAlchemy. import re, datetime, inspect, warnings, sys from array import array as _array -from sqlalchemy import ansisql, exceptions, logging, schema, sql, util -from sqlalchemy import operators as sql_operators +from sqlalchemy import exceptions, logging, schema, sql, util +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy.sql import compiler from sqlalchemy.engine import base as engine_base, default -import sqlalchemy.types as sqltypes +from sqlalchemy import types as sqltypes __all__ = ( @@ -1328,13 +1329,17 @@ class MySQLExecutionContext(default.DefaultExecutionContext): return AUTOCOMMIT_RE.match(self.statement) -class MySQLDialect(ansisql.ANSIDialect): +class MySQLDialect(default.DefaultDialect): """Details of the MySQL dialect. Not used directly in application code.""" def __init__(self, use_ansiquotes=False, **kwargs): self.use_ansiquotes = use_ansiquotes kwargs.setdefault('default_paramstyle', 'format') - ansisql.ANSIDialect.__init__(self, **kwargs) + if self.use_ansiquotes: + self.preparer = MySQLANSIIdentifierPreparer + else: + self.preparer = MySQLIdentifierPreparer + default.DefaultDialect.__init__(self, **kwargs) def dbapi(cls): import MySQLdb as mysql @@ -1393,7 +1398,7 @@ class MySQLDialect(ansisql.ANSIDialect): return True def compiler(self, statement, bindparams, **kwargs): - return MySQLCompiler(self, statement, bindparams, **kwargs) + return MySQLCompiler(statement, bindparams, dialect=self, **kwargs) def schemagenerator(self, *args, **kwargs): return MySQLSchemaGenerator(self, *args, **kwargs) @@ -1401,12 +1406,6 @@ class MySQLDialect(ansisql.ANSIDialect): def schemadropper(self, *args, **kwargs): return MySQLSchemaDropper(self, *args, **kwargs) - def preparer(self): - if self.use_ansiquotes: - return MySQLANSIIdentifierPreparer(self) - else: - return MySQLIdentifierPreparer(self) - def do_executemany(self, cursor, statement, parameters, context=None, **kwargs): rowcount = cursor.executemany(statement, parameters) @@ -1733,8 +1732,8 @@ class _MySQLPythonRowProxy(object): return item -class MySQLCompiler(ansisql.ANSICompiler): - operators = ansisql.ANSICompiler.operators.copy() +class MySQLCompiler(compiler.DefaultCompiler): + operators = compiler.DefaultCompiler.operators.copy() operators.update( { sql_operators.concat_op: \ @@ -1783,7 +1782,7 @@ class MySQLCompiler(ansisql.ANSICompiler): # In older versions, the indexes must be created explicitly or the # creation of foreign key constraints fails." -class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): +class MySQLSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, override_pk=False, first_pk=False): """Builds column DDL.""" @@ -1827,7 +1826,7 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): return ' '.join(table_opts) -class MySQLSchemaDropper(ansisql.ANSISchemaDropper): +class MySQLSchemaDropper(compiler.SchemaDropper): def visit_index(self, index): self.append("\nDROP INDEX %s ON %s" % (self.preparer.format_index(index), @@ -2368,7 +2367,7 @@ class MySQLSchemaReflector(object): MySQLSchemaReflector.logger = logging.class_logger(MySQLSchemaReflector) -class _MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class _MySQLIdentifierPreparer(compiler.IdentifierPreparer): """MySQL-specific schema identifier configuration.""" def __init__(self, dialect, **kw): @@ -2433,3 +2432,6 @@ def _re_compile(regex): return re.compile(regex, re.I | re.UNICODE) dialect = MySQLDialect +dialect.statement_compiler = MySQLCompiler +dialect.schemagenerator = MySQLSchemaGenerator +dialect.schemadropper = MySQLSchemaDropper diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index a35db1982..2d8f2940f 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -5,11 +5,13 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php -import re, warnings, operator, random +import re, warnings, random -from sqlalchemy import util, sql, schema, ansisql, exceptions, logging +from sqlalchemy import util, sql, schema, exceptions, logging from sqlalchemy.engine import default, base -import sqlalchemy.types as sqltypes +from sqlalchemy.sql import compiler, visitors +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import types as sqltypes import datetime @@ -229,9 +231,9 @@ class OracleExecutionContext(default.DefaultExecutionContext): return base.ResultProxy(self) -class OracleDialect(ansisql.ANSIDialect): +class OracleDialect(default.DefaultDialect): def __init__(self, use_ansi=True, auto_setinputsizes=True, auto_convert_lobs=True, threaded=True, allow_twophase=True, **kwargs): - ansisql.ANSIDialect.__init__(self, default_paramstyle='named', **kwargs) + default.DefaultDialect.__init__(self, default_paramstyle='named', **kwargs) self.use_ansi = use_ansi self.threaded = threaded self.allow_twophase = allow_twophase @@ -333,21 +335,6 @@ class OracleDialect(ansisql.ANSIDialect): def create_execution_context(self, *args, **kwargs): return OracleExecutionContext(self, *args, **kwargs) - def compiler(self, statement, bindparams, **kwargs): - return OracleCompiler(self, statement, bindparams, **kwargs) - - def preparer(self): - return OracleIdentifierPreparer(self) - - def schemagenerator(self, *args, **kwargs): - return OracleSchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return OracleSchemaDropper(self, *args, **kwargs) - - def defaultrunner(self, connection, **kwargs): - return OracleDefaultRunner(connection, **kwargs) - def has_table(self, connection, table_name, schema=None): cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':self._denormalize_name(table_name)}) return bool( cursor.fetchone() is not None ) @@ -560,16 +547,16 @@ class _OuterJoinColumn(sql.ClauseElement): def __init__(self, column): self.column = column -class OracleCompiler(ansisql.ANSICompiler): +class OracleCompiler(compiler.DefaultCompiler): """Oracle compiler modifies the lexical structure of Select statements to work under non-ANSI configured Oracle databases, if the use_ansi flag is False. """ - operators = ansisql.ANSICompiler.operators.copy() + operators = compiler.DefaultCompiler.operators.copy() operators.update( { - operator.mod : lambda x, y:"mod(%s, %s)" % (x, y) + sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y) } ) @@ -590,13 +577,13 @@ class OracleCompiler(ansisql.ANSICompiler): def visit_join(self, join, **kwargs): if self.dialect.use_ansi: - return ansisql.ANSICompiler.visit_join(self, join, **kwargs) + return compiler.DefaultCompiler.visit_join(self, join, **kwargs) (where, parentjoin) = self.__wheres.get(join, (None, None)) - class VisitOn(sql.ClauseVisitor): + class VisitOn(visitors.ClauseVisitor): def visit_binary(s, binary): - if binary.operator == operator.eq: + if binary.operator == sql_operators.eq: if binary.left.table is join.right: binary.left = _OuterJoinColumn(binary.left) elif binary.right.table is join.right: @@ -640,7 +627,7 @@ class OracleCompiler(ansisql.ANSICompiler): for c in insert.table.primary_key: if c.key not in self.parameters: self.parameters[c.key] = None - return ansisql.ANSICompiler.visit_insert(self, insert) + return compiler.DefaultCompiler.visit_insert(self, insert) def _TODO_visit_compound_select(self, select): """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" @@ -672,7 +659,7 @@ class OracleCompiler(ansisql.ANSICompiler): limitselect.append_whereclause("ora_rn<=%d" % select._limit) return self.process(limitselect, **kwargs) else: - return ansisql.ANSICompiler.visit_select(self, select, **kwargs) + return compiler.DefaultCompiler.visit_select(self, select, **kwargs) def limit_clause(self, select): return "" @@ -684,7 +671,7 @@ class OracleCompiler(ansisql.ANSICompiler): return super(OracleCompiler, self).for_update_clause(select) -class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): +class OracleSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() @@ -701,13 +688,13 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() -class OracleSchemaDropper(ansisql.ANSISchemaDropper): +class OracleSchemaDropper(compiler.SchemaDropper): def visit_sequence(self, sequence): if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name): self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() -class OracleDefaultRunner(ansisql.ANSIDefaultRunner): +class OracleDefaultRunner(base.DefaultRunner): def exec_default_sql(self, default): c = sql.select([default.arg], from_obj=["DUAL"]).compile(bind=self.connection) return self.connection.execute(c).scalar() @@ -715,10 +702,15 @@ class OracleDefaultRunner(ansisql.ANSIDefaultRunner): def visit_sequence(self, seq): return self.connection.execute("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL").scalar() -class OracleIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class OracleIdentifierPreparer(compiler.IdentifierPreparer): def format_savepoint(self, savepoint): name = re.sub(r'^_+', '', savepoint.ident) return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name) dialect = OracleDialect +dialect.statement_compiler = OracleCompiler +dialect.schemagenerator = OracleSchemaGenerator +dialect.schemadropper = OracleSchemaDropper +dialect.preparer = OracleIdentifierPreparer +dialect.defaultrunner = OracleDefaultRunner diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 74a3ef13f..29d84ad4d 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -4,11 +4,13 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import re, random, warnings, operator +import re, random, warnings -from sqlalchemy import sql, schema, ansisql, exceptions, util +from sqlalchemy import sql, schema, exceptions, util from sqlalchemy.engine import base, default -import sqlalchemy.types as sqltypes +from sqlalchemy.sql import compiler +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import types as sqltypes class PGInet(sqltypes.TypeEngine): @@ -220,9 +222,9 @@ class PGExecutionContext(default.DefaultExecutionContext): self._last_inserted_ids = [v for v in row] super(PGExecutionContext, self).post_exec() -class PGDialect(ansisql.ANSIDialect): +class PGDialect(default.DefaultDialect): def __init__(self, use_oids=False, server_side_cursors=False, **kwargs): - ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs) + default.DefaultDialect.__init__(self, default_paramstyle='pyformat', **kwargs) self.use_oids = use_oids self.server_side_cursors = server_side_cursors self.paramstyle = 'pyformat' @@ -249,15 +251,6 @@ class PGDialect(ansisql.ANSIDialect): def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - def compiler(self, statement, bindparams, **kwargs): - return PGCompiler(self, statement, bindparams, **kwargs) - - def schemagenerator(self, *args, **kwargs): - return PGSchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return PGSchemaDropper(self, *args, **kwargs) - def do_begin_twophase(self, connection, xid): self.do_begin(connection.connection) @@ -286,12 +279,6 @@ class PGDialect(ansisql.ANSIDialect): resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts")) return [row[0] for row in resultset] - def defaultrunner(self, context, **kwargs): - return PGDefaultRunner(context, **kwargs) - - def preparer(self): - return PGIdentifierPreparer(self) - def get_default_schema_name(self, connection): if not hasattr(self, '_default_schema_name'): self._default_schema_name = connection.scalar("select current_schema()", None) @@ -556,11 +543,11 @@ class PGDialect(ansisql.ANSIDialect): -class PGCompiler(ansisql.ANSICompiler): - operators = ansisql.ANSICompiler.operators.copy() +class PGCompiler(compiler.DefaultCompiler): + operators = compiler.DefaultCompiler.operators.copy() operators.update( { - operator.mod : '%%' + sql_operators.mod : '%%' } ) @@ -597,7 +584,7 @@ class PGCompiler(ansisql.ANSICompiler): else: return super(PGCompiler, self).for_update_clause(select) -class PGSchemaGenerator(ansisql.ANSISchemaGenerator): +class PGSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): @@ -620,13 +607,13 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() -class PGSchemaDropper(ansisql.ANSISchemaDropper): +class PGSchemaDropper(compiler.SchemaDropper): def visit_sequence(self, sequence): if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)): self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() -class PGDefaultRunner(ansisql.ANSIDefaultRunner): +class PGDefaultRunner(base.DefaultRunner): def get_column_default(self, column, isinsert=True): if column.primary_key: # passive defaults on primary keys have to be overridden @@ -642,7 +629,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) return self.connection.execute(exc).scalar() - return super(ansisql.ANSIDefaultRunner, self).get_column_default(column) + return super(PGDefaultRunner, self).get_column_default(column) def visit_sequence(self, seq): if not seq.optional: @@ -650,7 +637,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): else: return None -class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class PGIdentifierPreparer(compiler.IdentifierPreparer): def _fold_identifier_case(self, value): return value.lower() @@ -660,3 +647,8 @@ class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer): return value dialect = PGDialect +dialect.statement_compiler = PGCompiler +dialect.schemagenerator = PGSchemaGenerator +dialect.schemadropper = PGSchemaDropper +dialect.preparer = PGIdentifierPreparer +dialect.defaultrunner = PGDefaultRunner diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index d96422236..c2aced4d0 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -7,11 +7,12 @@ import re -from sqlalchemy import schema, ansisql, exceptions, pool, PassiveDefault -import sqlalchemy.engine.default as default +from sqlalchemy import schema, exceptions, pool, PassiveDefault +from sqlalchemy.engine import default import sqlalchemy.types as sqltypes import datetime,time, warnings import sqlalchemy.util as util +from sqlalchemy.sql import compiler SELECT_REGEXP = re.compile(r'\s*(?:SELECT|PRAGMA)', re.I | re.UNICODE) @@ -172,10 +173,10 @@ class SQLiteExecutionContext(default.DefaultExecutionContext): def is_select(self): return SELECT_REGEXP.match(self.statement) -class SQLiteDialect(ansisql.ANSIDialect): +class SQLiteDialect(default.DefaultDialect): def __init__(self, **kwargs): - ansisql.ANSIDialect.__init__(self, default_paramstyle='qmark', **kwargs) + default.DefaultDialect.__init__(self, default_paramstyle='qmark', **kwargs) def vers(num): return tuple([int(x) for x in num.split('.')]) if self.dbapi is not None: @@ -195,24 +196,12 @@ class SQLiteDialect(ansisql.ANSIDialect): return sqlite dbapi = classmethod(dbapi) - def compiler(self, statement, bindparams, **kwargs): - return SQLiteCompiler(self, statement, bindparams, **kwargs) - - def schemagenerator(self, *args, **kwargs): - return SQLiteSchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return SQLiteSchemaDropper(self, *args, **kwargs) - def server_version_info(self, connection): return self.dbapi.sqlite_version_info def supports_alter(self): return False - def preparer(self): - return SQLiteIdentifierPreparer(self) - def create_connect_args(self, url): filename = url.database or ':memory:' @@ -255,7 +244,7 @@ class SQLiteDialect(ansisql.ANSIDialect): return (row is not None) def reflecttable(self, connection, table, include_columns): - c = connection.execute("PRAGMA table_info(%s)" % self.preparer().format_table(table), {}) + c = connection.execute("PRAGMA table_info(%s)" % self.identifier_preparer.format_table(table), {}) found_table = False while True: row = c.fetchone() @@ -295,7 +284,7 @@ class SQLiteDialect(ansisql.ANSIDialect): if not found_table: raise exceptions.NoSuchTableError(table.name) - c = connection.execute("PRAGMA foreign_key_list(%s)" % self.preparer().format_table(table), {}) + c = connection.execute("PRAGMA foreign_key_list(%s)" % self.identifier_preparer.format_table(table), {}) fks = {} while True: row = c.fetchone() @@ -324,7 +313,7 @@ class SQLiteDialect(ansisql.ANSIDialect): for name, value in fks.iteritems(): table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1])) # check for UNIQUE indexes - c = connection.execute("PRAGMA index_list(%s)" % self.preparer().format_table(table), {}) + c = connection.execute("PRAGMA index_list(%s)" % self.identifier_preparer.format_table(table), {}) unique_indexes = [] while True: row = c.fetchone() @@ -343,7 +332,7 @@ class SQLiteDialect(ansisql.ANSIDialect): cols.append(row[2]) col = table.columns[row[2]] -class SQLiteCompiler(ansisql.ANSICompiler): +class SQLiteCompiler(compiler.DefaultCompiler): def visit_cast(self, cast): if self.dialect.supports_cast: return super(SQLiteCompiler, self).visit_cast(cast) @@ -369,7 +358,8 @@ class SQLiteCompiler(ansisql.ANSICompiler): # sqlite has no "FOR UPDATE" AFAICT return '' -class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): + +class SQLiteSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() @@ -391,12 +381,17 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): # else: # super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint) -class SQLiteSchemaDropper(ansisql.ANSISchemaDropper): +class SQLiteSchemaDropper(compiler.SchemaDropper): pass -class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): def __init__(self, dialect): super(SQLiteIdentifierPreparer, self).__init__(dialect, omit_schema=True) dialect = SQLiteDialect dialect.poolclass = pool.SingletonThreadPool +dialect.statement_compiler = SQLiteCompiler +dialect.schemagenerator = SQLiteSchemaGenerator +dialect.schemadropper = SQLiteSchemaDropper +dialect.preparer = SQLiteIdentifierPreparer + diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index faeb00cc9..553c8df84 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -8,7 +8,8 @@ higher-level statement-construction, connection-management, execution and result contexts.""" -from sqlalchemy import exceptions, sql, schema, util, types, logging +from sqlalchemy import exceptions, schema, util, types, logging +from sqlalchemy.sql import expression, visitors import StringIO, sys @@ -35,6 +36,22 @@ class Dialect(object): encoding type of encoding to use for unicode, usually defaults to 'utf-8' + + schemagenerator + a [sqlalchemy.schema#SchemaVisitor] class which generates schemas. + + schemadropper + a [sqlalchemy.schema#SchemaVisitor] class which drops schemas. + + defaultrunner + a [sqlalchemy.schema#SchemaVisitor] class which executes defaults. + + statement_compiler + a [sqlalchemy.engine.base#Compiled] class used to compile SQL statements + + preparer + a [sqlalchemy.sql.compiler#IdentifierPreparer] class used to quote + identifiers. """ def create_connect_args(self, url): @@ -105,48 +122,6 @@ class Dialect(object): raise NotImplementedError() - def schemagenerator(self, connection, **kwargs): - """Return a [sqlalchemy.schema#SchemaVisitor] instance that can generate schemas. - - connection - a [sqlalchemy.engine#Connection] to use for statement execution - - `schemagenerator()` is called via the `create()` method on Table, - Index, and others. - """ - - raise NotImplementedError() - - def schemadropper(self, connection, **kwargs): - """Return a [sqlalchemy.schema#SchemaVisitor] instance that can drop schemas. - - connection - a [sqlalchemy.engine#Connection] to use for statement execution - - `schemadropper()` is called via the `drop()` method on Table, - Index, and others. - """ - - raise NotImplementedError() - - def defaultrunner(self, execution_context): - """Return a [sqlalchemy.schema#SchemaVisitor] instance that can execute defaults. - - execution_context - a [sqlalchemy.engine#ExecutionContext] to use for statement execution - - """ - - raise NotImplementedError() - - def compiler(self, statement, parameters): - """Return a [sqlalchemy.sql#Compiled] object for the given statement/parameters. - - The returned object is usually a subclass of [sqlalchemy.ansisql#ANSICompiler]. - - """ - - raise NotImplementedError() def server_version_info(self, connection): """Return a tuple of the database's version number.""" @@ -266,16 +241,6 @@ class Dialect(object): raise NotImplementedError() - - def compile(self, clauseelement, parameters=None): - """Compile the given [sqlalchemy.sql#ClauseElement] using this Dialect. - - Returns [sqlalchemy.sql#Compiled]. A convenience method which - flips around the compile() call on ``ClauseElement``. - """ - - return clauseelement.compile(dialect=self, parameters=parameters) - def is_disconnect(self, e): """Return True if the given DBAPI error indicates an invalid connection""" @@ -304,7 +269,7 @@ class ExecutionContext(object): DBAPI cursor procured from the connection compiled - if passed to constructor, sql.Compiled object being executed + if passed to constructor, sqlalchemy.engine.base.Compiled object being executed statement string version of the statement to be executed. Is either @@ -439,6 +404,9 @@ class Compiled(object): def __init__(self, dialect, statement, parameters, bind=None): """Construct a new ``Compiled`` object. + dialect + ``Dialect`` to compile against. + statement ``ClauseElement`` to be compiled. @@ -724,8 +692,8 @@ class Connection(Connectable): def scalar(self, object, *multiparams, **params): return self.execute(object, *multiparams, **params).scalar() - def compiler(self, statement, parameters, **kwargs): - return self.dialect.compiler(statement, parameters, bind=self, **kwargs) + def statement_compiler(self, statement, parameters, **kwargs): + return self.dialect.statement_compiler(self.dialect, statement, parameters, bind=self, **kwargs) def execute(self, object, *multiparams, **params): for c in type(object).__mro__: @@ -822,9 +790,9 @@ class Connection(Connectable): # poor man's multimethod/generic function thingy executors = { - sql._Function : _execute_function, - sql.ClauseElement : _execute_clauseelement, - sql.ClauseVisitor : _execute_compiled, + expression._Function : _execute_function, + expression.ClauseElement : _execute_clauseelement, + visitors.ClauseVisitor : _execute_compiled, schema.SchemaItem:_execute_default, str.__mro__[-2] : _execute_text } @@ -989,14 +957,14 @@ class Engine(Connectable): connection.close() def _func(self): - return sql._FunctionGenerator(bind=self) + return expression._FunctionGenerator(bind=self) func = property(_func) def text(self, text, *args, **kwargs): """Return a sql.text() object for performing literal queries.""" - return sql.text(text, bind=self, *args, **kwargs) + return expression.text(text, bind=self, *args, **kwargs) def _run_visitor(self, visitorcallable, element, connection=None, **kwargs): if connection is None: @@ -1004,7 +972,7 @@ class Engine(Connectable): else: conn = connection try: - visitorcallable(conn, **kwargs).traverse(element) + visitorcallable(self.dialect, conn, **kwargs).traverse(element) finally: if connection is None: conn.close() @@ -1057,8 +1025,8 @@ class Engine(Connectable): connection = self.contextual_connect(close_with_result=True) return connection._execute_compiled(compiled, multiparams, params) - def compiler(self, statement, parameters, **kwargs): - return self.dialect.compiler(statement, parameters, bind=self, **kwargs) + def statement_compiler(self, statement, parameters, **kwargs): + return self.dialect.statement_compiler(self.dialect, statement, parameters, bind=self, **kwargs) def connect(self, **kwargs): """Return a newly allocated Connection object.""" @@ -1159,6 +1127,7 @@ class ResultProxy(object): self.closed = False self.cursor = context.cursor self.__echo = logging.is_debug_enabled(context.engine.logger) + self._process_row = self._row_processor() if context.is_select(): self._init_metadata() self._rowcount = None @@ -1222,7 +1191,7 @@ class ResultProxy(object): rec = props[key] elif isinstance(key, basestring) and key.lower() in props: rec = props[key.lower()] - elif isinstance(key, sql.ColumnElement): + elif isinstance(key, expression.ColumnElement): label = context.column_labels.get(key._label, key.name).lower() if label in props: rec = props[label] @@ -1320,21 +1289,21 @@ class ResultProxy(object): return self.cursor.fetchmany(size) def _fetchall_impl(self): return self.cursor.fetchall() + + def _row_processor(self): + return RowProxy - def _process_row(self, row): - return RowProxy(self, row) - def fetchall(self): """Fetch all rows, just like DBAPI ``cursor.fetchall()``.""" - l = [self._process_row(row) for row in self._fetchall_impl()] + l = [self._process_row(self, row) for row in self._fetchall_impl()] self.close() return l def fetchmany(self, size=None): """Fetch many rows, just like DBAPI ``cursor.fetchmany(size=cursor.arraysize)``.""" - l = [self._process_row(row) for row in self._fetchmany_impl(size)] + l = [self._process_row(self, row) for row in self._fetchmany_impl(size)] if len(l) == 0: self.close() return l @@ -1343,7 +1312,7 @@ class ResultProxy(object): """Fetch one row, just like DBAPI ``cursor.fetchone()``.""" row = self._fetchone_impl() if row is not None: - return self._process_row(row) + return self._process_row(self, row) else: self.close() return None @@ -1353,7 +1322,7 @@ class ResultProxy(object): row = self._fetchone_impl() try: if row is not None: - return self._process_row(row)[0] + return self._process_row(self, row)[0] else: return None finally: @@ -1425,11 +1394,9 @@ class BufferedColumnResultProxy(ResultProxy): def _get_col(self, row, key): rec = self._key_cache[key] return row[rec[2]] - - def _process_row(self, row): - sup = super(BufferedColumnResultProxy, self) - row = [sup._get_col(row, i) for i in xrange(len(row))] - return RowProxy(self, row) + + def _row_processor(self): + return BufferedColumnRow def fetchall(self): l = [] @@ -1523,6 +1490,11 @@ class RowProxy(object): def __len__(self): return len(self.__row) +class BufferedColumnRow(RowProxy): + def __init__(self, parent, row): + row = [ResultProxy._get_col(parent, row, i) for i in xrange(len(row))] + super(BufferedColumnRow, self).__init__(parent, row) + class SchemaIterator(schema.SchemaVisitor): """A visitor that can gather text into a buffer and execute the contents of the buffer.""" @@ -1590,11 +1562,11 @@ class DefaultRunner(schema.SchemaVisitor): return None def exec_default_sql(self, default): - c = sql.select([default.arg]).compile(bind=self.connection) + c = expression.select([default.arg]).compile(bind=self.connection) return self.connection._execute_compiled(c).scalar() def visit_column_onupdate(self, onupdate): - if isinstance(onupdate.arg, sql.ClauseElement): + if isinstance(onupdate.arg, expression.ClauseElement): return self.exec_default_sql(onupdate) elif callable(onupdate.arg): return onupdate.arg(self.context) @@ -1602,7 +1574,7 @@ class DefaultRunner(schema.SchemaVisitor): return onupdate.arg def visit_column_default(self, default): - if isinstance(default.arg, sql.ClauseElement): + if isinstance(default.arg, expression.ClauseElement): return self.exec_default_sql(default) elif callable(default.arg): return default.arg(self.context) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index ccaf080e7..059395921 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -9,7 +9,7 @@ from sqlalchemy import schema, exceptions, sql, util import re, random from sqlalchemy.engine import base - +from sqlalchemy.sql import compiler, expression AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', re.I | re.UNICODE) @@ -18,6 +18,12 @@ SELECT_REGEXP = re.compile(r'\s*SELECT', re.I | re.UNICODE) class DefaultDialect(base.Dialect): """Default implementation of Dialect""" + schemagenerator = compiler.SchemaGenerator + schemadropper = compiler.SchemaDropper + statement_compiler = compiler.DefaultCompiler + preparer = compiler.IdentifierPreparer + defaultrunner = base.DefaultRunner + def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', paramstyle=None, dbapi=None, **kwargs): self.convert_unicode = convert_unicode self.encoding = encoding @@ -25,6 +31,7 @@ class DefaultDialect(base.Dialect): self._ischema = None self.dbapi = dbapi self._figure_paramstyle(paramstyle=paramstyle, default=default_paramstyle) + self.identifier_preparer = self.preparer(self) def dbapi_type_map(self): # most DBAPIs have problems with this (such as, psycocpg2 types @@ -46,6 +53,7 @@ class DefaultDialect(base.Dialect): typeobj = typeobj() return typeobj + def supports_unicode_statements(self): """indicate whether the DBAPI can receive SQL statements as Python unicode strings""" return False @@ -96,13 +104,13 @@ class DefaultDialect(base.Dialect): return "_sa_%032x" % random.randint(0,2**128) def do_savepoint(self, connection, name): - connection.execute(sql.SavepointClause(name)) + connection.execute(expression.SavepointClause(name)) def do_rollback_to_savepoint(self, connection, name): - connection.execute(sql.RollbackToSavepointClause(name)) + connection.execute(expression.RollbackToSavepointClause(name)) def do_release_savepoint(self, connection, name): - connection.execute(sql.ReleaseSavepointClause(name)) + connection.execute(expression.ReleaseSavepointClause(name)) def do_executemany(self, cursor, statement, parameters, **kwargs): cursor.executemany(statement, parameters) @@ -110,8 +118,6 @@ class DefaultDialect(base.Dialect): def do_execute(self, cursor, statement, parameters, **kwargs): cursor.execute(statement, parameters) - def defaultrunner(self, context): - return base.DefaultRunner(context) def is_disconnect(self, e): return False diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py index 65d2ab3d2..258bddb4a 100644 --- a/lib/sqlalchemy/ext/sqlsoup.py +++ b/lib/sqlalchemy/ext/sqlsoup.py @@ -294,7 +294,7 @@ from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.ext.sessioncontext import SessionContext from sqlalchemy.exceptions import * - +from sqlalchemy.sql import expression _testsql = """ CREATE TABLE books ( @@ -415,7 +415,7 @@ def _selectable_name(selectable): return x def class_for_table(selectable, **mapper_kwargs): - selectable = sql._selectable(selectable) + selectable = expression._selectable(selectable) mapname = 'Mapped' + _selectable_name(selectable) if isinstance(selectable, Table): klass = TableClassType(mapname, (object,), {}) @@ -499,7 +499,7 @@ class SqlSoup: def with_labels(self, item): # TODO give meaningful aliases - return self.map(sql._selectable(item).select(use_labels=True).alias('foo')) + return self.map(expression._selectable(item).select(use_labels=True).alias('foo')) def join(self, *args, **kwargs): j = join(*args, **kwargs) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index c54eee438..9000a8df5 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -6,6 +6,7 @@ from sqlalchemy import util, logging, sql +from sqlalchemy.sql import expression __all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension', 'MapperProperty', 'PropComparator', 'StrategizedProperty', @@ -363,7 +364,7 @@ class MapperProperty(object): return operator(self.comparator, value) -class PropComparator(sql.ColumnOperators): +class PropComparator(expression.ColumnOperators): """defines comparison operations for MapperProperty objects""" def expression_element(self): diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index b48368417..60d4526ec 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -6,7 +6,8 @@ import weakref, warnings, operator from sqlalchemy import sql, util, exceptions, logging -from sqlalchemy import sql_util as sqlutil +from sqlalchemy.sql import expression +from sqlalchemy.sql import util as sqlutil from sqlalchemy.orm import util as mapperutil from sqlalchemy.orm.util import ExtensionCarrier from sqlalchemy.orm import sync @@ -77,7 +78,7 @@ class Mapper(object): raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__) for table in (local_table, select_table): - if table is not None and isinstance(table, sql._SelectBaseMixin): + if table is not None and isinstance(table, expression._SelectBaseMixin): # some db's, noteably postgres, dont want to select from a select # without an alias. also if we make our own alias internally, then # the configured properties on the mapper are not matched against the alias @@ -438,7 +439,7 @@ class Mapper(object): # against the "mapped_table" of this mapper. equivalent_columns = self._get_equivalent_columns() - primary_key = sql.ColumnSet() + primary_key = expression.ColumnSet() for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]): c = self.mapped_table.corresponding_column(col, raiseerr=False) @@ -644,9 +645,9 @@ class Mapper(object): props = {} if self.properties is not None: for key, prop in self.properties.iteritems(): - if sql.is_column(prop): + if expression.is_column(prop): props[key] = self.select_table.corresponding_column(prop) - elif (isinstance(prop, list) and sql.is_column(prop[0])): + elif (isinstance(prop, list) and expression.is_column(prop[0])): props[key] = [self.select_table.corresponding_column(c) for c in prop] self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, properties=props, _polymorphic_map=self.polymorphic_map, polymorphic_on=self.select_table.corresponding_column(self.polymorphic_on), primary_key=self.primary_key_argument) @@ -768,7 +769,7 @@ class Mapper(object): def _create_prop_from_column(self, column): column = util.to_list(column) - if not sql.is_column(column[0]): + if not expression.is_column(column[0]): return None mapped_column = [] for c in column: diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 670fcccc9..20cbcb235 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -11,7 +11,8 @@ operations. PropertyLoader also relies upon the dependency.py module to handle flush-time dependency sorting and processing. """ -from sqlalchemy import sql, schema, util, exceptions, sql_util, logging +from sqlalchemy import sql, schema, util, exceptions, logging +from sqlalchemy.sql import util as sql_util from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency from sqlalchemy.orm import session as sessionlib from sqlalchemy.orm import util as mapperutil diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 44329468f..5cbe19ce2 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -4,7 +4,9 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import sql, util, exceptions, sql_util, logging +from sqlalchemy import sql, util, exceptions, logging +from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import expression, visitors from sqlalchemy.orm import mapper, object_mapper from sqlalchemy.orm import util as mapperutil from sqlalchemy.orm.interfaces import OperationContext, LoaderStack @@ -312,7 +314,7 @@ class Query(object): clause = self._from_obj[-1] currenttables = [clause] - class FindJoinedTables(sql.NoColumnVisitor): + class FindJoinedTables(visitors.NoColumnVisitor): def visit_join(self, join): currenttables.append(join.left) currenttables.append(join.right) @@ -836,7 +838,7 @@ class Query(object): # if theres an order by, add those columns to the column list # of the "rowcount" query we're going to make if order_by: - order_by = [sql._literal_as_text(o) for o in util.to_list(order_by) or []] + order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []] cf = sql_util.ColumnFinder() for o in order_by: cf.traverse(o) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 6565c8d77..bdb17e1d6 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -6,7 +6,9 @@ """sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions.""" -from sqlalchemy import sql, util, exceptions, sql_util, logging +from sqlalchemy import sql, util, exceptions, logging +from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import visitors from sqlalchemy.orm import mapper, attributes from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption from sqlalchemy.orm import session as sessionlib @@ -292,7 +294,7 @@ class LazyLoader(AbstractRelationLoader): (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction) bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds]) - class Visitor(sql.ClauseVisitor): + class Visitor(visitors.ClauseVisitor): def visit_bindparam(s, bindparam): mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent if bindparam.key in bind_to_col: @@ -396,7 +398,7 @@ class LazyLoader(AbstractRelationLoader): if not isinstance(expr, sql.ColumnElement): return None columns = [] - class FindColumnInColumnClause(sql.ClauseVisitor): + class FindColumnInColumnClause(visitors.ClauseVisitor): def visit_column(self, c): columns.append(c) FindColumnInColumnClause().traverse(expr) diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index cf48202b0..49661a95e 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -10,9 +10,9 @@ clause that compares column values. """ from sqlalchemy import sql, schema, exceptions +from sqlalchemy.sql import visitors, operators from sqlalchemy import logging from sqlalchemy.orm import util as mapperutil -import operator ONETOMANY = 0 MANYTOONE = 1 @@ -43,7 +43,7 @@ class ClauseSynchronizer(object): def compile_binary(binary): """Assemble a SyncRule given a single binary condition.""" - if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): + if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): return source_column = None @@ -144,7 +144,7 @@ class SyncRule(object): SyncRule.logger = logging.class_logger(SyncRule) -class BinaryVisitor(sql.ClauseVisitor): +class BinaryVisitor(visitors.ClauseVisitor): def __init__(self, func): self.func = func diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 6b0956dc4..b3f58c954 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -4,7 +4,9 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import sql, util, exceptions, sql_util +from sqlalchemy import sql, util, exceptions +from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import visitors from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE all_cascades = util.Set(["delete", "delete-orphan", "all", "merge", @@ -161,7 +163,7 @@ class ExtensionCarrier(MapperExtension): before_delete = _create_do('before_delete') after_delete = _create_do('after_delete') -class BinaryVisitor(sql.ClauseVisitor): +class BinaryVisitor(visitors.ClauseVisitor): def __init__(self, func): self.func = func @@ -196,7 +198,7 @@ class AliasedClauses(object): # for column-level subqueries, swap out its selectable with our # eager version as appropriate, and manually build the # "correlation" list of the subquery. - class ModifySubquery(sql.ClauseVisitor): + class ModifySubquery(visitors.ClauseVisitor): def visit_select(s, select): select._should_correlate = False select.append_correlation(self.alias) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 99803d665..99ca2389b 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -18,8 +18,11 @@ objects as well as the visitor interface, so that the schema package """ import re, inspect -from sqlalchemy import sql, types, exceptions, util, databases +from sqlalchemy import types, exceptions, util, databases +from sqlalchemy.sql import expression, visitors import sqlalchemy + + URL = None __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', @@ -31,7 +34,7 @@ __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', class SchemaItem(object): """Base class for items that define a database schema.""" - __metaclass__ = sql._FigureVisitName + __metaclass__ = expression._FigureVisitName def _init_items(self, *args): """Initialize the list of child items for this SchemaItem.""" @@ -84,7 +87,7 @@ def _get_table_key(name, schema): else: return schema + "." + name -class _TableSingleton(sql._FigureVisitName): +class _TableSingleton(expression._FigureVisitName): """A metaclass used by the ``Table`` object to provide singleton behavior.""" def __call__(self, name, metadata, *args, **kwargs): @@ -124,10 +127,10 @@ class _TableSingleton(sql._FigureVisitName): return table -class Table(SchemaItem, sql.TableClause): +class Table(SchemaItem, expression.TableClause): """Represent a relational database table. - This subclasses ``sql.TableClause`` to provide a table that is + This subclasses ``expression.TableClause`` to provide a table that is associated with an instance of ``MetaData``, which in turn may be associated with an instance of ``Engine``. @@ -229,7 +232,7 @@ class Table(SchemaItem, sql.TableClause): self.schema = kwargs.pop('schema', None) self.indexes = util.Set() self.constraints = util.Set() - self._columns = sql.ColumnCollection() + self._columns = expression.ColumnCollection() self.primary_key = PrimaryKeyConstraint() self._foreign_keys = util.OrderedSet() self.quote = kwargs.pop('quote', False) @@ -291,7 +294,7 @@ class Table(SchemaItem, sql.TableClause): def get_children(self, column_collections=True, schema_visitor=False, **kwargs): if not schema_visitor: - return sql.TableClause.get_children(self, column_collections=column_collections, **kwargs) + return expression.TableClause.get_children(self, column_collections=column_collections, **kwargs) else: if column_collections: return [c for c in self.columns] @@ -338,10 +341,10 @@ class Table(SchemaItem, sql.TableClause): args.append(c.copy()) return Table(self.name, metadata, schema=schema, *args) -class Column(SchemaItem, sql._ColumnClause): +class Column(SchemaItem, expression._ColumnClause): """Represent a column in a database table. - This is a subclass of ``sql.ColumnClause`` and represents an + This is a subclass of ``expression.ColumnClause`` and represents an actual existing table in the database, in a similar fashion as ``TableClause``/``Table``. """ @@ -575,7 +578,7 @@ class Column(SchemaItem, sql._ColumnClause): return [x for x in (self.default, self.onupdate) if x is not None] + \ list(self.foreign_keys) + list(self.constraints) else: - return sql._ColumnClause.get_children(self, **kwargs) + return expression._ColumnClause.get_children(self, **kwargs) class ForeignKey(SchemaItem): @@ -806,7 +809,7 @@ class Constraint(SchemaItem): def __init__(self, name=None): self.name = name - self.columns = sql.ColumnCollection() + self.columns = expression.ColumnCollection() def __contains__(self, x): return self.columns.contains_column(x) @@ -1124,12 +1127,12 @@ class MetaData(SchemaItem): del self.tables[table.key] def table_iterator(self, reverse=True, tables=None): - import sqlalchemy.sql_util + from sqlalchemy.sql import util as sql_util if tables is None: tables = self.tables.values() else: tables = util.Set(tables).intersection(self.tables.values()) - sorter = sqlalchemy.sql_util.TableCollection(list(tables)) + sorter = sql_util.TableCollection(list(tables)) return iter(sorter.sort(reverse=reverse)) def _get_parent(self): @@ -1356,7 +1359,7 @@ class ThreadLocalMetaData(MetaData): e.dispose() -class SchemaVisitor(sql.ClauseVisitor): +class SchemaVisitor(visitors.ClauseVisitor): """Define the visiting for ``SchemaItem`` objects.""" __traverse_options__ = {'schema_visitor':True} diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py new file mode 100644 index 000000000..06b9f1f75 --- /dev/null +++ b/lib/sqlalchemy/sql/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.sql.expression import * +from sqlalchemy.sql.visitors import ClauseVisitor, NoColumnVisitor + diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/sql/compiler.py index 5f5e1c171..6053c72be 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1,22 +1,18 @@ -# ansisql.py +# compiler.py # Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Defines ANSI SQL operations. +"""SQL expression compilation routines and DDL implementations.""" -Contains default implementations for the abstract objects in the sql -module. -""" +import string, re +from sqlalchemy import schema, engine, util, exceptions +from sqlalchemy.sql import operators, visitors +from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import expression as sql -import string, re, sets, operator - -from sqlalchemy import schema, sql, engine, util, exceptions, operators -from sqlalchemy.engine import default - - -ANSI_FUNCS = sets.ImmutableSet([ +ANSI_FUNCS = util.Set([ 'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP', 'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP', 'SESSION_USER', 'USER']) @@ -77,7 +73,6 @@ OPERATORS = { operators.comma_op : ', ', operators.desc_op : 'DESC', operators.asc_op : 'ASC', - operators.from_ : 'FROM', operators.as_ : 'AS', operators.exists : 'EXISTS', @@ -85,36 +80,10 @@ OPERATORS = { operators.isnot : 'IS NOT' } -class ANSIDialect(default.DefaultDialect): - def __init__(self, cache_identifiers=True, **kwargs): - super(ANSIDialect,self).__init__(**kwargs) - self.identifier_preparer = self.preparer() - self.cache_identifiers = cache_identifiers - - def create_connect_args(self): - return ([],{}) - - def schemagenerator(self, *args, **kwargs): - return ANSISchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return ANSISchemaDropper(self, *args, **kwargs) - - def compiler(self, statement, parameters, **kwargs): - return ANSICompiler(self, statement, parameters, **kwargs) - - def preparer(self): - """Return an IdentifierPreparer. - - This object is used to format table and column names including - proper quoting and case conventions. - """ - return ANSIIdentifierPreparer(self) - -class ANSICompiler(engine.Compiled, sql.ClauseVisitor): +class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): """Default implementation of Compiled. - Compiles ClauseElements into ANSI-compliant SQL strings. + Compiles ClauseElements into SQL strings. """ __traverse_options__ = {'column_collections':False, 'entry':True} @@ -122,7 +91,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): operators = OPERATORS def __init__(self, dialect, statement, parameters=None, **kwargs): - """Construct a new ``ANSICompiler`` object. + """Construct a new ``DefaultCompiler`` object. dialect Dialect to be used @@ -139,7 +108,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): correspond to the keys present in the parameters. """ - super(ANSICompiler, self).__init__(dialect, statement, parameters, **kwargs) + super(DefaultCompiler, self).__init__(dialect, statement, parameters, **kwargs) # if we are insert/update. set to true when we visit an INSERT or UPDATE self.isinsert = self.isupdate = False @@ -170,17 +139,17 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): self.bindtemplate = ":%s" # paramstyle from the dialect (comes from DBAPI) - self.paramstyle = dialect.paramstyle + self.paramstyle = self.dialect.paramstyle # true if the paramstyle is positional - self.positional = dialect.positional + self.positional = self.dialect.positional # a list of the compiled's bind parameter names, used to help # formulate a positional argument list self.positiontup = [] - # an ANSIIdentifierPreparer that formats the quoting of identifiers - self.preparer = dialect.identifier_preparer + # an IdentifierPreparer that formats the quoting of identifiers + self.preparer = self.dialect.identifier_preparer # for UPDATE and INSERT statements, a set of columns whos values are being set # from a SQL expression (i.e., not one of the bind parameter values). if present, @@ -244,7 +213,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): return None def construct_params(self, params): - """Return a sql.ClauseParameters object. + """Return a sql.util.ClauseParameters object. Combines the given bind parameter dictionary (string keys to object values) with the _BindParamClause objects stored within this Compiled object @@ -252,7 +221,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): for a single statement execution, or one element of an executemany execution. """ - d = sql.ClauseParameters(self.dialect, self.positiontup) + d = sql_util.ClauseParameters(self.dialect, self.positiontup) pd = self.parameters or {} pd.update(params) @@ -781,7 +750,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): def __str__(self): return self.string -class ANSISchemaBase(engine.SchemaIterator): +class DDLBase(engine.SchemaIterator): def find_alterables(self, tables): alterables = [] class FindAlterables(schema.SchemaVisitor): @@ -794,12 +763,12 @@ class ANSISchemaBase(engine.SchemaIterator): findalterables.traverse(c) return alterables -class ANSISchemaGenerator(ANSISchemaBase): +class SchemaGenerator(DDLBase): def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): - super(ANSISchemaGenerator, self).__init__(connection, **kwargs) + super(SchemaGenerator, self).__init__(connection, **kwargs) self.checkfirst = checkfirst self.tables = tables and util.Set(tables) or None - self.preparer = dialect.preparer() + self.preparer = dialect.identifier_preparer self.dialect = dialect def get_column_specification(self, column, first_pk=False): @@ -860,7 +829,7 @@ class ANSISchemaGenerator(ANSISchemaBase): def _compile(self, tocompile, parameters): """compile the given string/parameters using this SchemaGenerator's dialect.""" - compiler = self.dialect.compiler(tocompile, parameters) + compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters) compiler.compile() return compiler @@ -930,12 +899,12 @@ class ANSISchemaGenerator(ANSISchemaBase): string.join([preparer.format_column(c) for c in index.columns], ', '))) self.execute() -class ANSISchemaDropper(ANSISchemaBase): +class SchemaDropper(DDLBase): def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): - super(ANSISchemaDropper, self).__init__(connection, **kwargs) + super(SchemaDropper, self).__init__(connection, **kwargs) self.checkfirst = checkfirst self.tables = tables - self.preparer = dialect.preparer() + self.preparer = dialect.identifier_preparer self.dialect = dialect def visit_metadata(self, metadata): @@ -964,14 +933,11 @@ class ANSISchemaDropper(ANSISchemaBase): self.append("\nDROP TABLE " + self.preparer.format_table(table)) self.execute() -class ANSIDefaultRunner(engine.DefaultRunner): - pass - -class ANSIIdentifierPreparer(object): +class IdentifierPreparer(object): """Handle quoting and case-folding of identifiers based on options.""" def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False): - """Construct a new ``ANSIIdentifierPreparer`` object. + """Construct a new ``IdentifierPreparer`` object. initial_quote Character that begins a delimited identifier. @@ -1049,20 +1015,14 @@ class ANSIIdentifierPreparer(object): def __generic_obj_format(self, obj, ident): if getattr(obj, 'quote', False): return self.quote_identifier(ident) - if self.dialect.cache_identifiers: - try: - return self.__strings[ident] - except KeyError: - if self._requires_quotes(ident): - self.__strings[ident] = self.quote_identifier(ident) - else: - self.__strings[ident] = ident - return self.__strings[ident] - else: + try: + return self.__strings[ident] + except KeyError: if self._requires_quotes(ident): - return self.quote_identifier(ident) + self.__strings[ident] = self.quote_identifier(ident) else: - return ident + self.__strings[ident] = ident + return self.__strings[ident] def should_quote(self, object): return object.quote or self._requires_quotes(object.name) @@ -1152,5 +1112,3 @@ class ANSIIdentifierPreparer(object): return [self._unescape_identifier(i) for i in [a or b for a, b in r.findall(identifiers)]] - -dialect = ANSIDialect diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql/expression.py index 7c73f7cb7..e117e3f47 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql/expression.py @@ -26,13 +26,14 @@ classes usually have few or no public methods and are less guaranteed to stay the same in future releases. """ -from sqlalchemy import util, exceptions, operators +from sqlalchemy import util, exceptions +from sqlalchemy.sql import operators, visitors from sqlalchemy import types as sqltypes import re __all__ = [ - 'Alias', 'ClauseElement', 'ClauseParameters', - 'ClauseVisitor', 'ColumnCollection', 'ColumnElement', + 'Alias', 'ClauseElement', + 'ColumnCollection', 'ColumnElement', 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', 'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc', 'between', 'bindparam', 'case', 'cast', 'column', 'delete', @@ -810,187 +811,6 @@ def is_column(col): """True if ``col`` is an instance of ``ColumnElement``.""" return isinstance(col, ColumnElement) -class ClauseParameters(object): - """Represent a dictionary/iterator of bind parameter key names/values. - - Tracks the original [sqlalchemy.sql#_BindParamClause] objects - as well as the keys/position of each parameter, and can return - parameters as a dictionary or a list. Will process parameter - values according to the ``TypeEngine`` objects present in the - ``_BindParamClause`` instances. - """ - - def __init__(self, dialect, positional=None): - self.dialect = dialect - self.__binds = {} - self.positional = positional or [] - - def get_parameter(self, key): - return self.__binds[key] - - def set_parameter(self, bindparam, value, name): - self.__binds[name] = [bindparam, name, value] - - def get_original(self, key): - return self.__binds[key][2] - - def get_type(self, key): - return self.__binds[key][0].type - - def get_processors(self): - """return a dictionary of bind 'processing' functions""" - return dict([ - (key, value) for key, value in - [( - key, - self.__binds[key][0].bind_processor(self.dialect) - ) for key in self.__binds] - if value is not None - ]) - - def get_processed(self, key, processors): - return key in processors and processors[key](self.__binds[key][2]) or self.__binds[key][2] - - def keys(self): - return self.__binds.keys() - - def __iter__(self): - return iter(self.keys()) - - def __getitem__(self, key): - (bind, name, value) = self.__binds[key] - processor = bind.bind_processor(self.dialect) - return processor is not None and processor(value) or value - - def __contains__(self, key): - return key in self.__binds - - def set_value(self, key, value): - self.__binds[key][2] = value - - def get_original_dict(self): - return dict([(name, value) for (b, name, value) in self.__binds.values()]) - - def __get_processed(self, key, processors): - if key in processors: - return processors[key](self.__binds[key][2]) - else: - return self.__binds[key][2] - - def get_raw_list(self, processors): - return [self.__get_processed(key, processors) for key in self.positional] - - def get_raw_dict(self, processors, encode_keys=False): - if encode_keys: - return dict([ - ( - key.encode(self.dialect.encoding), - self.__get_processed(key, processors) - ) - for key in self.keys() - ]) - else: - return dict([ - ( - key, - self.__get_processed(key, processors) - ) - for key in self.keys() - ]) - - def __repr__(self): - return self.__class__.__name__ + ":" + repr(self.get_original_dict()) - -class ClauseVisitor(object): - """A class that knows how to traverse and visit ``ClauseElements``. - - Calls visit_XXX() methods dynamically generated for each - particualr ``ClauseElement`` subclass encountered. Traversal of a - hierarchy of ``ClauseElements`` is achieved via the ``traverse()`` - method, which is passed the lead ``ClauseElement``. - - By default, ``ClauseVisitor`` traverses all elements fully. - Options can be specified at the class level via the - ``__traverse_options__`` dictionary which will be passed to the - ``get_children()`` method of each ``ClauseElement``; these options - can indicate modifications to the set of elements returned, such - as to not return column collections (column_collections=False) or - to return Schema-level items (schema_visitor=True). - - ``ClauseVisitor`` also supports a simultaneous copy-and-traverse - operation, which will produce a copy of a given ``ClauseElement`` - structure while at the same time allowing ``ClauseVisitor`` - subclasses to modify the new structure in-place. - """ - - __traverse_options__ = {} - - def traverse_single(self, obj, **kwargs): - meth = getattr(self, "visit_%s" % obj.__visit_name__, None) - if meth: - return meth(obj, **kwargs) - - def iterate(self, obj, stop_on=None): - stack = [obj] - traversal = [] - while len(stack) > 0: - t = stack.pop() - if stop_on is None or t not in stop_on: - yield t - traversal.insert(0, t) - for c in t.get_children(**self.__traverse_options__): - stack.append(c) - - def traverse(self, obj, stop_on=None, clone=False): - if clone: - obj = obj._clone() - - stack = [obj] - traversal = [] - while len(stack) > 0: - t = stack.pop() - if stop_on is None or t not in stop_on: - traversal.insert(0, t) - if clone: - t._copy_internals() - for c in t.get_children(**self.__traverse_options__): - stack.append(c) - for target in traversal: - v = self - while v is not None: - meth = getattr(v, "visit_%s" % target.__visit_name__, None) - if meth: - meth(target) - v = getattr(v, '_next', None) - return obj - - def chain(self, visitor): - """'chain' an additional ClauseVisitor onto this ClauseVisitor. - - The chained visitor will receive all visit events after this one. - """ - - tail = self - while getattr(tail, '_next', None) is not None: - tail = tail._next - tail._next = visitor - return self - -class NoColumnVisitor(ClauseVisitor): - """A ClauseVisitor that will not traverse exported column collections. - - Will not traverse the exported Column collections on Table, Alias, - Select, and CompoundSelect objects (i.e. their 'columns' or 'c' - attribute). - - This is useful because most traversals don't need those columns, - or in the case of ANSICompiler it traverses them explicitly; so - skipping their traversal here greatly cuts down on method call - overhead. - """ - - __traverse_options__ = {'column_collections': False} - class _FigureVisitName(type): def __init__(cls, clsname, bases, dict): @@ -1061,7 +881,7 @@ class ClauseElement(object): elif len(optionaldict) > 1: raise exceptions.ArgumentError("params() takes zero or one positional dictionary argument") - class Vis(ClauseVisitor): + class Vis(visitors.ClauseVisitor): def visit_bindparam(self, bind): if bind.key in kwargs: bind.value = kwargs[bind.key] @@ -1156,7 +976,7 @@ class ClauseElement(object): if any. Finally, if there is no bound ``Engine``, uses an - ``ANSIDialect`` to create a default ``Compiler``. + ``DefaultDialect`` to create a default ``Compiler``. `parameters` is a dictionary representing the default bind parameters to be used with the statement. If `parameters` is @@ -1175,15 +995,16 @@ class ClauseElement(object): if compiler is None: if dialect is not None: - compiler = dialect.compiler(self, parameters) + compiler = dialect.statement_compiler(dialect, self, parameters) elif bind is not None: - compiler = bind.compiler(self, parameters) + compiler = bind.statement_compiler(self, parameters) elif self.bind is not None: - compiler = self.bind.compiler(self, parameters) + compiler = self.bind.statement_compiler(self, parameters) if compiler is None: - import sqlalchemy.ansisql as ansisql - compiler = ansisql.ANSIDialect().compiler(self, parameters=parameters) + from sqlalchemy.engine.default import DefaultDialect + dialect = DefaultDialect() + compiler = dialect.statement_compiler(dialect, self, parameters=parameters) compiler.compile() return compiler @@ -1727,7 +1548,7 @@ class FromClause(Selectable): def _get_all_embedded_columns(self): ret = [] - class FindCols(ClauseVisitor): + class FindCols(visitors.ClauseVisitor): def visit_column(self, col): ret.append(col) FindCols().traverse(self) @@ -1744,8 +1565,8 @@ class FromClause(Selectable): def replace_selectable(self, old, alias): """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``.""" - from sqlalchemy import sql_util - return sql_util.ClauseAdapter(alias).traverse(self, clone=True) + from sqlalchemy.sql import util + return util.ClauseAdapter(alias).traverse(self, clone=True) def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False): """Given a ``ColumnElement``, return the exported ``ColumnElement`` @@ -2376,7 +2197,7 @@ class Join(FromClause): else: equivs[x] = util.Set([y]) - class BinaryVisitor(ClauseVisitor): + class BinaryVisitor(visitors.ClauseVisitor): def visit_binary(self, binary): if binary.operator == operators.eq: add_equiv(binary.left, binary.right) @@ -2460,7 +2281,7 @@ class Join(FromClause): return self.__folded_equivalents if equivs is None: equivs = util.Set() - class LocateEquivs(NoColumnVisitor): + class LocateEquivs(visitors.NoColumnVisitor): def visit_binary(self, binary): if binary.operator == operators.eq and binary.left.name == binary.right.name: equivs.add(binary.right) @@ -3331,7 +3152,7 @@ class Select(_SelectBaseMixin, FromClause): return intersect_all(self, other, **kwargs) def _table_iterator(self): - for t in NoColumnVisitor().iterate(self): + for t in visitors.NoColumnVisitor().iterate(self): if isinstance(t, TableClause): yield t diff --git a/lib/sqlalchemy/operators.py b/lib/sqlalchemy/sql/operators.py index b8aca3d26..b8aca3d26 100644 --- a/lib/sqlalchemy/operators.py +++ b/lib/sqlalchemy/sql/operators.py diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql/util.py index cc6325822..2c7294e66 100644 --- a/lib/sqlalchemy/sql_util.py +++ b/lib/sqlalchemy/sql/util.py @@ -1,7 +1,100 @@ -from sqlalchemy import sql, util, schema, topological +from sqlalchemy import util, schema, topological +from sqlalchemy.sql import expression, visitors """Utility functions that build upon SQL and Schema constructs.""" +class ClauseParameters(object): + """Represent a dictionary/iterator of bind parameter key names/values. + + Tracks the original [sqlalchemy.sql#_BindParamClause] objects as well as the + keys/position of each parameter, and can return parameters as a + dictionary or a list. Will process parameter values according to + the ``TypeEngine`` objects present in the ``_BindParamClause`` instances. + """ + + def __init__(self, dialect, positional=None): + self.dialect = dialect + self.__binds = {} + self.positional = positional or [] + + def get_parameter(self, key): + return self.__binds[key] + + def set_parameter(self, bindparam, value, name): + self.__binds[name] = [bindparam, name, value] + + def get_original(self, key): + return self.__binds[key][2] + + def get_type(self, key): + return self.__binds[key][0].type + + def get_processors(self): + """return a dictionary of bind 'processing' functions""" + return dict([ + (key, value) for key, value in + [( + key, + self.__binds[key][0].bind_processor(self.dialect) + ) for key in self.__binds] + if value is not None + ]) + + def get_processed(self, key, processors): + return key in processors and processors[key](self.__binds[key][2]) or self.__binds[key][2] + + def keys(self): + return self.__binds.keys() + + def __iter__(self): + return iter(self.keys()) + + def __getitem__(self, key): + (bind, name, value) = self.__binds[key] + processor = bind.bind_processor(self.dialect) + return processor is not None and processor(value) or value + + def __contains__(self, key): + return key in self.__binds + + def set_value(self, key, value): + self.__binds[key][2] = value + + def get_original_dict(self): + return dict([(name, value) for (b, name, value) in self.__binds.values()]) + + def __get_processed(self, key, processors): + if key in processors: + return processors[key](self.__binds[key][2]) + else: + return self.__binds[key][2] + + def get_raw_list(self, processors): + return [self.__get_processed(key, processors) for key in self.positional] + + def get_raw_dict(self, processors, encode_keys=False): + if encode_keys: + return dict([ + ( + key.encode(self.dialect.encoding), + self.__get_processed(key, processors) + ) + for key in self.keys() + ]) + else: + return dict([ + ( + key, + self.__get_processed(key, processors) + ) + for key in self.keys() + ]) + + def __repr__(self): + return self.__class__.__name__ + ":" + repr(self.get_original_dict()) + + + class TableCollection(object): def __init__(self, tables=None): self.tables = tables or [] @@ -64,7 +157,7 @@ class TableCollection(object): return sequence -class TableFinder(TableCollection, sql.NoColumnVisitor): +class TableFinder(TableCollection, visitors.NoColumnVisitor): """locate all Tables within a clause.""" def __init__(self, clause, check_columns=False, include_aliases=False): @@ -85,7 +178,7 @@ class TableFinder(TableCollection, sql.NoColumnVisitor): if self.check_columns: self.tables.append(column.table) -class ColumnFinder(sql.ClauseVisitor): +class ColumnFinder(visitors.ClauseVisitor): def __init__(self): self.columns = util.Set() @@ -95,7 +188,7 @@ class ColumnFinder(sql.ClauseVisitor): def __iter__(self): return iter(self.columns) -class ColumnsInClause(sql.ClauseVisitor): +class ColumnsInClause(visitors.ClauseVisitor): """Given a selectable, visit clauses and determine if any columns from the clause are in the selectable. """ @@ -108,7 +201,7 @@ class ColumnsInClause(sql.ClauseVisitor): if self.selectable.c.get(column.key) is column: self.result = True -class AbstractClauseProcessor(sql.NoColumnVisitor): +class AbstractClauseProcessor(visitors.NoColumnVisitor): """Traverse a clause and attempt to convert the contents of container elements to a converted element. @@ -224,10 +317,10 @@ class ClauseAdapter(AbstractClauseProcessor): self.equivalents = equivalents def convert_element(self, col): - if isinstance(col, sql.FromClause): + if isinstance(col, expression.FromClause): if self.selectable.is_derived_from(col): return self.selectable - if not isinstance(col, sql.ColumnElement): + if not isinstance(col, expression.ColumnElement): return None if self.include is not None: if col not in self.include: diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py new file mode 100644 index 000000000..98e4de6c3 --- /dev/null +++ b/lib/sqlalchemy/sql/visitors.py @@ -0,0 +1,87 @@ +class ClauseVisitor(object): + """A class that knows how to traverse and visit + ``ClauseElements``. + + Calls visit_XXX() methods dynamically generated for each particualr + ``ClauseElement`` subclass encountered. Traversal of a + hierarchy of ``ClauseElements`` is achieved via the + ``traverse()`` method, which is passed the lead + ``ClauseElement``. + + By default, ``ClauseVisitor`` traverses all elements + fully. Options can be specified at the class level via the + ``__traverse_options__`` dictionary which will be passed + to the ``get_children()`` method of each ``ClauseElement``; + these options can indicate modifications to the set of + elements returned, such as to not return column collections + (column_collections=False) or to return Schema-level items + (schema_visitor=True). + + ``ClauseVisitor`` also supports a simultaneous copy-and-traverse + operation, which will produce a copy of a given ``ClauseElement`` + structure while at the same time allowing ``ClauseVisitor`` subclasses + to modify the new structure in-place. + + """ + __traverse_options__ = {} + + def traverse_single(self, obj, **kwargs): + meth = getattr(self, "visit_%s" % obj.__visit_name__, None) + if meth: + return meth(obj, **kwargs) + + def iterate(self, obj, stop_on=None): + stack = [obj] + traversal = [] + while len(stack) > 0: + t = stack.pop() + if stop_on is None or t not in stop_on: + yield t + traversal.insert(0, t) + for c in t.get_children(**self.__traverse_options__): + stack.append(c) + + def traverse(self, obj, stop_on=None, clone=False): + if clone: + obj = obj._clone() + + stack = [obj] + traversal = [] + while len(stack) > 0: + t = stack.pop() + if stop_on is None or t not in stop_on: + traversal.insert(0, t) + if clone: + t._copy_internals() + for c in t.get_children(**self.__traverse_options__): + stack.append(c) + for target in traversal: + v = self + while v is not None: + meth = getattr(v, "visit_%s" % target.__visit_name__, None) + if meth: + meth(target) + v = getattr(v, '_next', None) + return obj + + def chain(self, visitor): + """'chain' an additional ClauseVisitor onto this ClauseVisitor. + + the chained visitor will receive all visit events after this one.""" + tail = self + while getattr(tail, '_next', None) is not None: + tail = tail._next + tail._next = visitor + return self + +class NoColumnVisitor(ClauseVisitor): + """a ClauseVisitor that will not traverse the exported Column + collections on Table, Alias, Select, and CompoundSelect objects + (i.e. their 'columns' or 'c' attribute). + + this is useful because most traversals don't need those columns, or + in the case of DefaultCompiler it traverses them explicitly; so + skipping their traversal here greatly cuts down on method call overhead. + """ + + __traverse_options__ = {'column_collections':False} diff --git a/test/dialect/mysql.py b/test/dialect/mysql.py index 46e0a7137..294842854 100644 --- a/test/dialect/mysql.py +++ b/test/dialect/mysql.py @@ -154,7 +154,7 @@ class TypesTest(AssertMixin): table_args.append(Column('c%s' % index, type_(*args, **kw))) numeric_table = Table(*table_args) - gen = testbase.db.dialect.schemagenerator(testbase.db, None, None) + gen = testbase.db.dialect.schemagenerator(testbase.db.dialect, testbase.db, None, None) for col in numeric_table.c: index = int(col.name[1:]) @@ -238,7 +238,7 @@ class TypesTest(AssertMixin): table_args.append(Column('c%s' % index, type_(*args, **kw))) charset_table = Table(*table_args) - gen = testbase.db.dialect.schemagenerator(testbase.db, None, None) + gen = testbase.db.dialect.schemagenerator(testbase.db.dialect, testbase.db, None, None) for col in charset_table.c: index = int(col.name[1:]) @@ -707,7 +707,7 @@ class SQLTest(AssertMixin): def colspec(c): - return testbase.db.dialect.schemagenerator( + return testbase.db.dialect.schemagenerator(testbase.db.dialect, testbase.db, None, None).get_column_specification(c) if __name__ == "__main__": diff --git a/test/engine/reflection.py b/test/engine/reflection.py index da6f75149..2345a328a 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -2,7 +2,6 @@ import testbase import pickle, StringIO, unicodedata from sqlalchemy import * -import sqlalchemy.ansisql as ansisql from sqlalchemy.exceptions import NoSuchTableError from testlib import * from testlib import engines @@ -686,7 +685,7 @@ class SchemaTest(PersistTest): def foo(s, p=None): buf.write(s) gen = create_engine(testbase.db.name + "://", strategy="mock", executor=foo) - gen = gen.dialect.schemagenerator(gen) + gen = gen.dialect.schemagenerator(gen.dialect, gen) gen.traverse(table1) gen.traverse(table2) buf = buf.getvalue() diff --git a/test/orm/dynamic.py b/test/orm/dynamic.py index 0c824d372..2e84c83d1 100644 --- a/test/orm/dynamic.py +++ b/test/orm/dynamic.py @@ -1,7 +1,6 @@ import testbase import operator from sqlalchemy import * -from sqlalchemy import ansisql from sqlalchemy.orm import * from testlib import * from testlib.fixtures import * diff --git a/test/orm/query.py b/test/orm/query.py index e0b8bf4f3..e3f6ed42c 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -1,7 +1,8 @@ import testbase import operator +from sqlalchemy.sql import compiler from sqlalchemy import * -from sqlalchemy import ansisql +from sqlalchemy.engine import default from sqlalchemy.orm import * from testlib import * from testlib.fixtures import * @@ -141,7 +142,7 @@ class OperatorTest(QueryTest): """test sql.Comparator implementation for MapperProperties""" def _test(self, clause, expected): - c = str(clause.compile(dialect=ansisql.ANSIDialect())) + c = str(clause.compile(dialect = default.DefaultDialect())) assert c == expected, "%s != %s" % (c, expected) def test_arithmetic(self): @@ -182,7 +183,7 @@ class OperatorTest(QueryTest): # the compiled clause should match either (e.g.): # 'a' < 'b' -or- 'b' > 'a'. - compiled = str(py_op(lhs, rhs).compile(dialect=ansisql.ANSIDialect())) + compiled = str(py_op(lhs, rhs).compile(dialect=default.DefaultDialect())) fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql) rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql) @@ -201,7 +202,7 @@ class OperatorTest(QueryTest): # this one would require adding compile() to InstrumentedScalarAttribute. do we want this ? #(User.id, "users.id") ): - c = expr.compile(dialect=ansisql.ANSIDialect()) + c = expr.compile(dialect=default.DefaultDialect()) assert str(c) == compare, "%s != %s" % (str(c), compare) diff --git a/test/sql/constraints.py b/test/sql/constraints.py index 3120185d5..a8b642b9b 100644 --- a/test/sql/constraints.py +++ b/test/sql/constraints.py @@ -179,7 +179,7 @@ class ConstraintTest(AssertMixin): capt.append(repr(context.parameters)) ex(context) connection._Connection__execute = proxy - schemagen = testbase.db.dialect.schemagenerator(connection) + schemagen = testbase.db.dialect.schemagenerator(testbase.db.dialect, connection) schemagen.traverse(events) assert capt[0].strip().startswith('CREATE TABLE events') diff --git a/test/sql/generative.py b/test/sql/generative.py index d79af577f..9bd97b305 100644 --- a/test/sql/generative.py +++ b/test/sql/generative.py @@ -1,6 +1,7 @@ import testbase from sqlalchemy import * from testlib import * +from sqlalchemy.sql.visitors import * class TraversalTest(AssertMixin): """test ClauseVisitor's traversal, particularly its ability to copy and modify @@ -213,7 +214,7 @@ class ClauseTest(SQLCompileTest): self.assert_compile(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :table1_col2") def test_clause_adapter(self): - from sqlalchemy import sql_util + from sqlalchemy.sql import util as sql_util t1alias = t1.alias('t1alias') diff --git a/test/sql/labels.py b/test/sql/labels.py index 553a3a3bc..dee76428d 100644 --- a/test/sql/labels.py +++ b/test/sql/labels.py @@ -78,7 +78,7 @@ class LongLabelsTest(PersistTest): # this is the test that fails if the "max identifier length" is shorter than the # length of the actual columns created, because the column names get truncated. # if you try to separate "physical columns" from "labels", and only truncate the labels, - # the ansisql.visit_select() logic which auto-labels columns in a subquery (for the purposes of sqlite compat) breaks the code, + # the compiler.DefaultCompiler.visit_select() logic which auto-labels columns in a subquery (for the purposes of sqlite compat) breaks the code, # since it is creating "labels" on the fly but not affecting derived columns, which think they are # still "physical" q = table1.select(table1.c.this_is_the_primarykey_column == 4).alias('foo') diff --git a/test/sql/query.py b/test/sql/query.py index 4f569f1c0..8ec3190b4 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -1,7 +1,8 @@ import testbase import datetime from sqlalchemy import * -from sqlalchemy import exceptions, ansisql +from sqlalchemy import exceptions +from sqlalchemy.engine import default from testlib import * @@ -166,14 +167,14 @@ class QueryTest(PersistTest): assert len(r) == 1 def test_bindparam_detection(self): - dialect = ansisql.ANSIDialect(default_paramstyle='qmark') - prep = lambda q: dialect.compile(sql.text(q)).string + dialect = default.DefaultDialect(default_paramstyle='qmark') + prep = lambda q: str(sql.text(q).compile(dialect=dialect)) def a_eq(got, wanted): if got != wanted: print "Wanted %s" % wanted print "Received %s" % got - self.assert_(got == wanted) + self.assert_(got == wanted, got) a_eq(prep('select foo'), 'select foo') a_eq(prep("time='12:30:00'"), "time='12:30:00'") diff --git a/test/sql/quote.py b/test/sql/quote.py index ad25619df..0c414af3a 100644 --- a/test/sql/quote.py +++ b/test/sql/quote.py @@ -1,7 +1,7 @@ import testbase from sqlalchemy import * from testlib import * - +from sqlalchemy.sql import compiler class QuoteTest(PersistTest): def setUpAll(self): @@ -98,10 +98,10 @@ class QuoteTest(PersistTest): class PreparerTest(PersistTest): - """Test the db-agnostic quoting services of ANSIIdentifierPreparer.""" + """Test the db-agnostic quoting services of IdentifierPreparer.""" def test_unformat(self): - prep = ansisql.ANSIIdentifierPreparer(None) + prep = compiler.IdentifierPreparer(None) unformat = prep.unformat_identifiers def a_eq(have, want): @@ -120,7 +120,7 @@ class PreparerTest(PersistTest): a_eq(unformat('"foo"."b""a""r"."baz"'), ['foo', 'b"a"r', 'baz']) def test_unformat_custom(self): - class Custom(ansisql.ANSIIdentifierPreparer): + class Custom(compiler.IdentifierPreparer): def __init__(self, dialect): super(Custom, self).__init__(dialect, initial_quote='`', final_quote='`') diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index c497fbcbd..4ffb4c591 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -195,7 +195,7 @@ class ColumnsTest(AssertMixin): ) for aCol in testTable.c: - self.assertEquals(expectedResults[aCol.name], db.dialect.schemagenerator(db, None, None).get_column_specification(aCol)) + self.assertEquals(expectedResults[aCol.name], db.dialect.schemagenerator(db.dialect, db, None, None).get_column_specification(aCol)) class UnicodeTest(AssertMixin): """tests the Unicode type. also tests the TypeDecorator with instances in the types package.""" diff --git a/test/testlib/testing.py b/test/testlib/testing.py index 6830fb63c..58fd7c0d1 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -161,7 +161,8 @@ class ExecutionContextWrapper(object): if params is not None and isinstance(params, list) and len(params) == 1: params = params[0] - if isinstance(ctx.compiled_parameters, sql.ClauseParameters): + from sqlalchemy.sql.util import ClauseParameters + if isinstance(ctx.compiled_parameters, ClauseParameters): parameters = ctx.compiled_parameters.get_original_dict() elif isinstance(ctx.compiled_parameters, list): parameters = [p.get_original_dict() for p in ctx.compiled_parameters] |