summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/databases/access.py8
-rw-r--r--lib/sqlalchemy/databases/firebird.py42
-rw-r--r--lib/sqlalchemy/databases/informix.py41
-rw-r--r--lib/sqlalchemy/databases/mssql.py46
-rw-r--r--lib/sqlalchemy/databases/mysql.py36
-rw-r--r--lib/sqlalchemy/databases/oracle.py56
-rw-r--r--lib/sqlalchemy/databases/postgres.py48
-rw-r--r--lib/sqlalchemy/databases/sqlite.py41
-rw-r--r--lib/sqlalchemy/engine/base.py132
-rw-r--r--lib/sqlalchemy/engine/default.py18
-rw-r--r--lib/sqlalchemy/ext/sqlsoup.py6
-rw-r--r--lib/sqlalchemy/orm/interfaces.py3
-rw-r--r--lib/sqlalchemy/orm/mapper.py13
-rw-r--r--lib/sqlalchemy/orm/properties.py3
-rw-r--r--lib/sqlalchemy/orm/query.py8
-rw-r--r--lib/sqlalchemy/orm/strategies.py8
-rw-r--r--lib/sqlalchemy/orm/sync.py6
-rw-r--r--lib/sqlalchemy/orm/util.py8
-rw-r--r--lib/sqlalchemy/schema.py31
-rw-r--r--lib/sqlalchemy/sql/__init__.py3
-rw-r--r--lib/sqlalchemy/sql/compiler.py (renamed from lib/sqlalchemy/ansisql.py)110
-rw-r--r--lib/sqlalchemy/sql/expression.py (renamed from lib/sqlalchemy/sql.py)215
-rw-r--r--lib/sqlalchemy/sql/operators.py (renamed from lib/sqlalchemy/operators.py)0
-rw-r--r--lib/sqlalchemy/sql/util.py (renamed from lib/sqlalchemy/sql_util.py)107
-rw-r--r--lib/sqlalchemy/sql/visitors.py87
25 files changed, 492 insertions, 584 deletions
diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py
index 6bf8b96e9..4aa773239 100644
--- a/lib/sqlalchemy/databases/access.py
+++ b/lib/sqlalchemy/databases/access.py
@@ -347,7 +347,7 @@ class AccessDialect(ansisql.ANSIDialect):
return names
-class AccessCompiler(ansisql.ANSICompiler):
+class AccessCompiler(compiler.DefaultCompiler):
def visit_select_precolumns(self, select):
"""Access puts TOP, it's version of LIMIT here """
s = select.distinct and "DISTINCT " or ""
@@ -387,7 +387,7 @@ class AccessCompiler(ansisql.ANSICompiler):
return ''
-class AccessSchemaGenerator(ansisql.ANSISchemaGenerator):
+class AccessSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
@@ -410,7 +410,7 @@ class AccessSchemaGenerator(ansisql.ANSISchemaGenerator):
return colspec
-class AccessSchemaDropper(ansisql.ANSISchemaDropper):
+class AccessSchemaDropper(compiler.SchemaDropper):
def visit_index(self, index):
self.append("\nDROP INDEX [%s].[%s]" % (index.table.name, index.name))
self.execute()
@@ -418,7 +418,7 @@ class AccessSchemaDropper(ansisql.ANSISchemaDropper):
class AccessDefaultRunner(ansisql.ANSIDefaultRunner):
pass
-class AccessIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class AccessIdentifierPreparer(compiler.IdentifierPreparer):
def __init__(self, dialect):
super(AccessIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py
index 307fceb48..9cccb53e8 100644
--- a/lib/sqlalchemy/databases/firebird.py
+++ b/lib/sqlalchemy/databases/firebird.py
@@ -7,9 +7,10 @@
import warnings
-from sqlalchemy import util, sql, schema, ansisql, exceptions
-import sqlalchemy.engine.default as default
-import sqlalchemy.types as sqltypes
+from sqlalchemy import util, sql, schema, exceptions
+from sqlalchemy.sql import compiler
+from sqlalchemy.engine import default, base
+from sqlalchemy import types as sqltypes
_initialized_kb = False
@@ -99,9 +100,9 @@ class FBExecutionContext(default.DefaultExecutionContext):
return True
-class FBDialect(ansisql.ANSIDialect):
+class FBDialect(default.DefaultDialect):
def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
- ansisql.ANSIDialect.__init__(self, **kwargs)
+ default.DefaultDialect.__init__(self, **kwargs)
self.type_conv = type_conv
self.concurrency_level= concurrency_level
@@ -135,21 +136,6 @@ class FBDialect(ansisql.ANSIDialect):
def supports_sane_rowcount(self):
return False
- def compiler(self, statement, bindparams, **kwargs):
- return FBCompiler(self, statement, bindparams, **kwargs)
-
- def schemagenerator(self, *args, **kwargs):
- return FBSchemaGenerator(self, *args, **kwargs)
-
- def schemadropper(self, *args, **kwargs):
- return FBSchemaDropper(self, *args, **kwargs)
-
- def defaultrunner(self, connection):
- return FBDefaultRunner(connection)
-
- def preparer(self):
- return FBIdentifierPreparer(self)
-
def max_identifier_length(self):
return 31
@@ -307,7 +293,7 @@ class FBDialect(ansisql.ANSIDialect):
connection.commit(True)
-class FBCompiler(ansisql.ANSICompiler):
+class FBCompiler(compiler.DefaultCompiler):
"""Firebird specific idiosincrasies"""
def visit_alias(self, alias, asfrom=False, **kwargs):
@@ -346,7 +332,7 @@ class FBCompiler(ansisql.ANSICompiler):
return ""
-class FBSchemaGenerator(ansisql.ANSISchemaGenerator):
+class FBSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column)
colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
@@ -365,13 +351,13 @@ class FBSchemaGenerator(ansisql.ANSISchemaGenerator):
self.execute()
-class FBSchemaDropper(ansisql.ANSISchemaDropper):
+class FBSchemaDropper(compiler.SchemaDropper):
def visit_sequence(self, sequence):
self.append("DROP GENERATOR %s" % sequence.name)
self.execute()
-class FBDefaultRunner(ansisql.ANSIDefaultRunner):
+class FBDefaultRunner(base.DefaultRunner):
def exec_default_sql(self, default):
c = sql.select([default.arg], from_obj=["rdb$database"]).compile(bind=self.connection)
return self.connection.execute_compiled(c).scalar()
@@ -421,7 +407,7 @@ RESERVED_WORDS = util.Set(
"whenever", "where", "while", "with", "work", "write", "year", "yearday" ])
-class FBIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class FBIdentifierPreparer(compiler.IdentifierPreparer):
def __init__(self, dialect):
super(FBIdentifierPreparer,self).__init__(dialect, omit_schema=True)
@@ -430,3 +416,9 @@ class FBIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
dialect = FBDialect
+dialect.statement_compiler = FBCompiler
+dialect.schemagenerator = FBSchemaGenerator
+dialect.schemadropper = FBSchemaDropper
+dialect.defaultrunner = FBDefaultRunner
+dialect.preparer = FBIdentifierPreparer
+
diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py
index 21ecf1538..ceb56903a 100644
--- a/lib/sqlalchemy/databases/informix.py
+++ b/lib/sqlalchemy/databases/informix.py
@@ -7,9 +7,10 @@
import datetime, warnings
-from sqlalchemy import sql, schema, ansisql, exceptions, pool
-import sqlalchemy.engine.default as default
-import sqlalchemy.types as sqltypes
+from sqlalchemy import sql, schema, exceptions, pool
+from sqlalchemy.sql import compiler
+from sqlalchemy.engine import default
+from sqlalchemy import types as sqltypes
# for offset
@@ -203,11 +204,11 @@ class InfoExecutionContext(default.DefaultExecutionContext):
def create_cursor( self ):
return informix_cursor( self.connection.connection )
-class InfoDialect(ansisql.ANSIDialect):
+class InfoDialect(default.DefaultDialect):
def __init__(self, use_ansi=True,**kwargs):
self.use_ansi = use_ansi
- ansisql.ANSIDialect.__init__(self, **kwargs)
+ default.DefaultDialect.__init__(self, **kwargs)
self.paramstyle = 'qmark'
def dbapi(cls):
@@ -252,18 +253,6 @@ class InfoDialect(ansisql.ANSIDialect):
def oid_column_name(self,column):
return "rowid"
- def preparer(self):
- return InfoIdentifierPreparer(self)
-
- def compiler(self, statement, bindparams, **kwargs):
- return InfoCompiler(self, statement, bindparams, **kwargs)
-
- def schemagenerator(self, *args, **kwargs):
- return InfoSchemaGenerator( self , *args, **kwargs)
-
- def schemadropper(self, *args, **params):
- return InfoSchemaDroper( self , *args , **params)
-
def table_names(self, connection, schema):
s = "select tabname from systables"
return [row[0] for row in connection.execute(s)]
@@ -376,14 +365,14 @@ class InfoDialect(ansisql.ANSIDialect):
for cons_name, cons_type, local_column in rows:
table.primary_key.add( table.c[local_column] )
-class InfoCompiler(ansisql.ANSICompiler):
+class InfoCompiler(compiler.DefaultCompiler):
"""Info compiler modifies the lexical structure of Select statements to work under
non-ANSI configured Oracle databases, if the use_ansi flag is False."""
def __init__(self, dialect, statement, parameters=None, **kwargs):
self.limit = 0
self.offset = 0
- ansisql.ANSICompiler.__init__( self , dialect , statement , parameters , **kwargs )
+ compiler.DefaultCompiler.__init__( self , dialect , statement , parameters , **kwargs )
def default_from(self):
return " from systables where tabname = 'systables' "
@@ -416,7 +405,7 @@ class InfoCompiler(ansisql.ANSICompiler):
if ( __label(c) not in a ) and getattr( c , 'name' , '' ) != 'oid':
select.append_column( c )
- return ansisql.ANSICompiler.visit_select(self, select)
+ return compiler.DefaultCompiler.visit_select(self, select)
def limit_clause(self, select):
return ""
@@ -437,7 +426,7 @@ class InfoCompiler(ansisql.ANSICompiler):
elif func.name.lower() in ( 'current_timestamp' , 'now' ):
return "CURRENT YEAR TO SECOND"
else:
- return ansisql.ANSICompiler.visit_function( self , func )
+ return compiler.DefaultCompiler.visit_function( self , func )
def visit_clauselist(self, list):
try:
@@ -446,7 +435,7 @@ class InfoCompiler(ansisql.ANSICompiler):
li = [ c for c in list.clauses ]
return ', '.join([s for s in [self.process(c) for c in li] if s is not None])
-class InfoSchemaGenerator(ansisql.ANSISchemaGenerator):
+class InfoSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, first_pk=False):
colspec = self.preparer.format_column(column)
if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \
@@ -507,7 +496,7 @@ class InfoSchemaGenerator(ansisql.ANSISchemaGenerator):
return
super(InfoSchemaGenerator, self).visit_index(index)
-class InfoIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class InfoIdentifierPreparer(compiler.IdentifierPreparer):
def __init__(self, dialect):
super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'")
@@ -517,10 +506,14 @@ class InfoIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
def _requires_quotes(self, value):
return False
-class InfoSchemaDroper(ansisql.ANSISchemaDropper):
+class InfoSchemaDroper(compiler.SchemaDropper):
def drop_foreignkey(self, constraint):
if constraint.name is not None:
super( InfoSchemaDroper , self ).drop_foreignkey( constraint )
dialect = InfoDialect
poolclass = pool.SingletonThreadPool
+dialect.statement_compiler = InfoCompiler
+dialect.schemagenerator = InfoSchemaGenerator
+dialect.schemadropper = InfoSchemaDropper
+dialect.preparer = InfoIdentifierPreparer
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
index 619e072d9..0caccca95 100644
--- a/lib/sqlalchemy/databases/mssql.py
+++ b/lib/sqlalchemy/databases/mssql.py
@@ -39,10 +39,10 @@ Known issues / TODO:
import datetime, random, warnings, re
-from sqlalchemy import sql, schema, ansisql, exceptions
-import sqlalchemy.types as sqltypes
-from sqlalchemy.engine import default
-import operator, sys
+from sqlalchemy import util, sql, schema, exceptions
+from sqlalchemy.sql import compiler, expression
+from sqlalchemy.engine import default, base
+from sqlalchemy import types as sqltypes
class MSNumeric(sqltypes.Numeric):
def result_processor(self, dialect):
@@ -366,7 +366,7 @@ class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
super(MSSQLExecutionContext_pyodbc, self).post_exec()
-class MSSQLDialect(ansisql.ANSIDialect):
+class MSSQLDialect(default.DefaultDialect):
colspecs = {
sqltypes.Unicode : MSNVarchar,
sqltypes.Integer : MSInteger,
@@ -476,21 +476,6 @@ class MSSQLDialect(ansisql.ANSIDialect):
def supports_sane_rowcount(self):
raise NotImplementedError()
- def compiler(self, statement, bindparams, **kwargs):
- return MSSQLCompiler(self, statement, bindparams, **kwargs)
-
- def schemagenerator(self, *args, **kwargs):
- return MSSQLSchemaGenerator(self, *args, **kwargs)
-
- def schemadropper(self, *args, **kwargs):
- return MSSQLSchemaDropper(self, *args, **kwargs)
-
- def defaultrunner(self, connection, **kwargs):
- return MSSQLDefaultRunner(connection, **kwargs)
-
- def preparer(self):
- return MSSQLIdentifierPreparer(self)
-
def get_default_schema_name(self, connection):
return self.schema_name
@@ -878,7 +863,7 @@ dialect_mapping = {
}
-class MSSQLCompiler(ansisql.ANSICompiler):
+class MSSQLCompiler(compiler.DefaultCompiler):
def __init__(self, dialect, statement, parameters, **kwargs):
super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs)
self.tablealiases = {}
@@ -931,13 +916,13 @@ class MSSQLCompiler(ansisql.ANSICompiler):
def visit_binary(self, binary):
"""Move bind parameters to the right-hand side of an operator, where possible."""
- if isinstance(binary.left, sql._BindParamClause) and binary.operator == operator.eq:
- return self.process(sql._BinaryExpression(binary.right, binary.left, binary.operator))
+ if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq:
+ return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator))
else:
return super(MSSQLCompiler, self).visit_binary(binary)
def label_select_column(self, select, column):
- if isinstance(column, sql._Function):
+ if isinstance(column, expression._Function):
return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:])
else:
return super(MSSQLCompiler, self).label_select_column(select, column)
@@ -963,7 +948,7 @@ class MSSQLCompiler(ansisql.ANSICompiler):
return ""
-class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
+class MSSQLSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
@@ -986,7 +971,7 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
return colspec
-class MSSQLSchemaDropper(ansisql.ANSISchemaDropper):
+class MSSQLSchemaDropper(compiler.SchemaDropper):
def visit_index(self, index):
self.append("\nDROP INDEX %s.%s" % (
self.preparer.quote_identifier(index.table.name),
@@ -995,11 +980,11 @@ class MSSQLSchemaDropper(ansisql.ANSISchemaDropper):
self.execute()
-class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner):
+class MSSQLDefaultRunner(base.DefaultRunner):
# TODO: does ms-sql have standalone sequences ?
pass
-class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class MSSQLIdentifierPreparer(compiler.IdentifierPreparer):
def __init__(self, dialect):
super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
@@ -1012,6 +997,11 @@ class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
return value
dialect = MSSQLDialect
+dialect.statement_compiler = MSSQLCompiler
+dialect.schemagenerator = MSSQLSchemaGenerator
+dialect.schemadropper = MSSQLSchemaDropper
+dialect.preparer = MSSQLIdentifierPreparer
+dialect.defaultrunner = MSSQLDefaultRunner
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index d5fd3b6c5..41c6ec70f 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -126,11 +126,12 @@ information affecting MySQL in SQLAlchemy.
import re, datetime, inspect, warnings, sys
from array import array as _array
-from sqlalchemy import ansisql, exceptions, logging, schema, sql, util
-from sqlalchemy import operators as sql_operators
+from sqlalchemy import exceptions, logging, schema, sql, util
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy.sql import compiler
from sqlalchemy.engine import base as engine_base, default
-import sqlalchemy.types as sqltypes
+from sqlalchemy import types as sqltypes
__all__ = (
@@ -1328,13 +1329,17 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
return AUTOCOMMIT_RE.match(self.statement)
-class MySQLDialect(ansisql.ANSIDialect):
+class MySQLDialect(default.DefaultDialect):
"""Details of the MySQL dialect. Not used directly in application code."""
def __init__(self, use_ansiquotes=False, **kwargs):
self.use_ansiquotes = use_ansiquotes
kwargs.setdefault('default_paramstyle', 'format')
- ansisql.ANSIDialect.__init__(self, **kwargs)
+ if self.use_ansiquotes:
+ self.preparer = MySQLANSIIdentifierPreparer
+ else:
+ self.preparer = MySQLIdentifierPreparer
+ default.DefaultDialect.__init__(self, **kwargs)
def dbapi(cls):
import MySQLdb as mysql
@@ -1393,7 +1398,7 @@ class MySQLDialect(ansisql.ANSIDialect):
return True
def compiler(self, statement, bindparams, **kwargs):
- return MySQLCompiler(self, statement, bindparams, **kwargs)
+ return MySQLCompiler(statement, bindparams, dialect=self, **kwargs)
def schemagenerator(self, *args, **kwargs):
return MySQLSchemaGenerator(self, *args, **kwargs)
@@ -1401,12 +1406,6 @@ class MySQLDialect(ansisql.ANSIDialect):
def schemadropper(self, *args, **kwargs):
return MySQLSchemaDropper(self, *args, **kwargs)
- def preparer(self):
- if self.use_ansiquotes:
- return MySQLANSIIdentifierPreparer(self)
- else:
- return MySQLIdentifierPreparer(self)
-
def do_executemany(self, cursor, statement, parameters,
context=None, **kwargs):
rowcount = cursor.executemany(statement, parameters)
@@ -1733,8 +1732,8 @@ class _MySQLPythonRowProxy(object):
return item
-class MySQLCompiler(ansisql.ANSICompiler):
- operators = ansisql.ANSICompiler.operators.copy()
+class MySQLCompiler(compiler.DefaultCompiler):
+ operators = compiler.DefaultCompiler.operators.copy()
operators.update(
{
sql_operators.concat_op: \
@@ -1783,7 +1782,7 @@ class MySQLCompiler(ansisql.ANSICompiler):
# In older versions, the indexes must be created explicitly or the
# creation of foreign key constraints fails."
-class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
+class MySQLSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, override_pk=False,
first_pk=False):
"""Builds column DDL."""
@@ -1827,7 +1826,7 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
return ' '.join(table_opts)
-class MySQLSchemaDropper(ansisql.ANSISchemaDropper):
+class MySQLSchemaDropper(compiler.SchemaDropper):
def visit_index(self, index):
self.append("\nDROP INDEX %s ON %s" %
(self.preparer.format_index(index),
@@ -2368,7 +2367,7 @@ class MySQLSchemaReflector(object):
MySQLSchemaReflector.logger = logging.class_logger(MySQLSchemaReflector)
-class _MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class _MySQLIdentifierPreparer(compiler.IdentifierPreparer):
"""MySQL-specific schema identifier configuration."""
def __init__(self, dialect, **kw):
@@ -2433,3 +2432,6 @@ def _re_compile(regex):
return re.compile(regex, re.I | re.UNICODE)
dialect = MySQLDialect
+dialect.statement_compiler = MySQLCompiler
+dialect.schemagenerator = MySQLSchemaGenerator
+dialect.schemadropper = MySQLSchemaDropper
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py
index a35db1982..2d8f2940f 100644
--- a/lib/sqlalchemy/databases/oracle.py
+++ b/lib/sqlalchemy/databases/oracle.py
@@ -5,11 +5,13 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import re, warnings, operator, random
+import re, warnings, random
-from sqlalchemy import util, sql, schema, ansisql, exceptions, logging
+from sqlalchemy import util, sql, schema, exceptions, logging
from sqlalchemy.engine import default, base
-import sqlalchemy.types as sqltypes
+from sqlalchemy.sql import compiler, visitors
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy import types as sqltypes
import datetime
@@ -229,9 +231,9 @@ class OracleExecutionContext(default.DefaultExecutionContext):
return base.ResultProxy(self)
-class OracleDialect(ansisql.ANSIDialect):
+class OracleDialect(default.DefaultDialect):
def __init__(self, use_ansi=True, auto_setinputsizes=True, auto_convert_lobs=True, threaded=True, allow_twophase=True, **kwargs):
- ansisql.ANSIDialect.__init__(self, default_paramstyle='named', **kwargs)
+ default.DefaultDialect.__init__(self, default_paramstyle='named', **kwargs)
self.use_ansi = use_ansi
self.threaded = threaded
self.allow_twophase = allow_twophase
@@ -333,21 +335,6 @@ class OracleDialect(ansisql.ANSIDialect):
def create_execution_context(self, *args, **kwargs):
return OracleExecutionContext(self, *args, **kwargs)
- def compiler(self, statement, bindparams, **kwargs):
- return OracleCompiler(self, statement, bindparams, **kwargs)
-
- def preparer(self):
- return OracleIdentifierPreparer(self)
-
- def schemagenerator(self, *args, **kwargs):
- return OracleSchemaGenerator(self, *args, **kwargs)
-
- def schemadropper(self, *args, **kwargs):
- return OracleSchemaDropper(self, *args, **kwargs)
-
- def defaultrunner(self, connection, **kwargs):
- return OracleDefaultRunner(connection, **kwargs)
-
def has_table(self, connection, table_name, schema=None):
cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':self._denormalize_name(table_name)})
return bool( cursor.fetchone() is not None )
@@ -560,16 +547,16 @@ class _OuterJoinColumn(sql.ClauseElement):
def __init__(self, column):
self.column = column
-class OracleCompiler(ansisql.ANSICompiler):
+class OracleCompiler(compiler.DefaultCompiler):
"""Oracle compiler modifies the lexical structure of Select
statements to work under non-ANSI configured Oracle databases, if
the use_ansi flag is False.
"""
- operators = ansisql.ANSICompiler.operators.copy()
+ operators = compiler.DefaultCompiler.operators.copy()
operators.update(
{
- operator.mod : lambda x, y:"mod(%s, %s)" % (x, y)
+ sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y)
}
)
@@ -590,13 +577,13 @@ class OracleCompiler(ansisql.ANSICompiler):
def visit_join(self, join, **kwargs):
if self.dialect.use_ansi:
- return ansisql.ANSICompiler.visit_join(self, join, **kwargs)
+ return compiler.DefaultCompiler.visit_join(self, join, **kwargs)
(where, parentjoin) = self.__wheres.get(join, (None, None))
- class VisitOn(sql.ClauseVisitor):
+ class VisitOn(visitors.ClauseVisitor):
def visit_binary(s, binary):
- if binary.operator == operator.eq:
+ if binary.operator == sql_operators.eq:
if binary.left.table is join.right:
binary.left = _OuterJoinColumn(binary.left)
elif binary.right.table is join.right:
@@ -640,7 +627,7 @@ class OracleCompiler(ansisql.ANSICompiler):
for c in insert.table.primary_key:
if c.key not in self.parameters:
self.parameters[c.key] = None
- return ansisql.ANSICompiler.visit_insert(self, insert)
+ return compiler.DefaultCompiler.visit_insert(self, insert)
def _TODO_visit_compound_select(self, select):
"""Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
@@ -672,7 +659,7 @@ class OracleCompiler(ansisql.ANSICompiler):
limitselect.append_whereclause("ora_rn<=%d" % select._limit)
return self.process(limitselect, **kwargs)
else:
- return ansisql.ANSICompiler.visit_select(self, select, **kwargs)
+ return compiler.DefaultCompiler.visit_select(self, select, **kwargs)
def limit_clause(self, select):
return ""
@@ -684,7 +671,7 @@ class OracleCompiler(ansisql.ANSICompiler):
return super(OracleCompiler, self).for_update_clause(select)
-class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
+class OracleSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column)
colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
@@ -701,13 +688,13 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
-class OracleSchemaDropper(ansisql.ANSISchemaDropper):
+class OracleSchemaDropper(compiler.SchemaDropper):
def visit_sequence(self, sequence):
if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name):
self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
-class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
+class OracleDefaultRunner(base.DefaultRunner):
def exec_default_sql(self, default):
c = sql.select([default.arg], from_obj=["DUAL"]).compile(bind=self.connection)
return self.connection.execute(c).scalar()
@@ -715,10 +702,15 @@ class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
def visit_sequence(self, seq):
return self.connection.execute("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL").scalar()
-class OracleIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class OracleIdentifierPreparer(compiler.IdentifierPreparer):
def format_savepoint(self, savepoint):
name = re.sub(r'^_+', '', savepoint.ident)
return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
dialect = OracleDialect
+dialect.statement_compiler = OracleCompiler
+dialect.schemagenerator = OracleSchemaGenerator
+dialect.schemadropper = OracleSchemaDropper
+dialect.preparer = OracleIdentifierPreparer
+dialect.defaultrunner = OracleDefaultRunner
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index 74a3ef13f..29d84ad4d 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -4,11 +4,13 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import re, random, warnings, operator
+import re, random, warnings
-from sqlalchemy import sql, schema, ansisql, exceptions, util
+from sqlalchemy import sql, schema, exceptions, util
from sqlalchemy.engine import base, default
-import sqlalchemy.types as sqltypes
+from sqlalchemy.sql import compiler
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy import types as sqltypes
class PGInet(sqltypes.TypeEngine):
@@ -220,9 +222,9 @@ class PGExecutionContext(default.DefaultExecutionContext):
self._last_inserted_ids = [v for v in row]
super(PGExecutionContext, self).post_exec()
-class PGDialect(ansisql.ANSIDialect):
+class PGDialect(default.DefaultDialect):
def __init__(self, use_oids=False, server_side_cursors=False, **kwargs):
- ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs)
+ default.DefaultDialect.__init__(self, default_paramstyle='pyformat', **kwargs)
self.use_oids = use_oids
self.server_side_cursors = server_side_cursors
self.paramstyle = 'pyformat'
@@ -249,15 +251,6 @@ class PGDialect(ansisql.ANSIDialect):
def type_descriptor(self, typeobj):
return sqltypes.adapt_type(typeobj, colspecs)
- def compiler(self, statement, bindparams, **kwargs):
- return PGCompiler(self, statement, bindparams, **kwargs)
-
- def schemagenerator(self, *args, **kwargs):
- return PGSchemaGenerator(self, *args, **kwargs)
-
- def schemadropper(self, *args, **kwargs):
- return PGSchemaDropper(self, *args, **kwargs)
-
def do_begin_twophase(self, connection, xid):
self.do_begin(connection.connection)
@@ -286,12 +279,6 @@ class PGDialect(ansisql.ANSIDialect):
resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts"))
return [row[0] for row in resultset]
- def defaultrunner(self, context, **kwargs):
- return PGDefaultRunner(context, **kwargs)
-
- def preparer(self):
- return PGIdentifierPreparer(self)
-
def get_default_schema_name(self, connection):
if not hasattr(self, '_default_schema_name'):
self._default_schema_name = connection.scalar("select current_schema()", None)
@@ -556,11 +543,11 @@ class PGDialect(ansisql.ANSIDialect):
-class PGCompiler(ansisql.ANSICompiler):
- operators = ansisql.ANSICompiler.operators.copy()
+class PGCompiler(compiler.DefaultCompiler):
+ operators = compiler.DefaultCompiler.operators.copy()
operators.update(
{
- operator.mod : '%%'
+ sql_operators.mod : '%%'
}
)
@@ -597,7 +584,7 @@ class PGCompiler(ansisql.ANSICompiler):
else:
return super(PGCompiler, self).for_update_clause(select)
-class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
+class PGSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column)
if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
@@ -620,13 +607,13 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
-class PGSchemaDropper(ansisql.ANSISchemaDropper):
+class PGSchemaDropper(compiler.SchemaDropper):
def visit_sequence(self, sequence):
if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)):
self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
-class PGDefaultRunner(ansisql.ANSIDefaultRunner):
+class PGDefaultRunner(base.DefaultRunner):
def get_column_default(self, column, isinsert=True):
if column.primary_key:
# passive defaults on primary keys have to be overridden
@@ -642,7 +629,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
return self.connection.execute(exc).scalar()
- return super(ansisql.ANSIDefaultRunner, self).get_column_default(column)
+ return super(PGDefaultRunner, self).get_column_default(column)
def visit_sequence(self, seq):
if not seq.optional:
@@ -650,7 +637,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
else:
return None
-class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class PGIdentifierPreparer(compiler.IdentifierPreparer):
def _fold_identifier_case(self, value):
return value.lower()
@@ -660,3 +647,8 @@ class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
return value
dialect = PGDialect
+dialect.statement_compiler = PGCompiler
+dialect.schemagenerator = PGSchemaGenerator
+dialect.schemadropper = PGSchemaDropper
+dialect.preparer = PGIdentifierPreparer
+dialect.defaultrunner = PGDefaultRunner
diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py
index d96422236..c2aced4d0 100644
--- a/lib/sqlalchemy/databases/sqlite.py
+++ b/lib/sqlalchemy/databases/sqlite.py
@@ -7,11 +7,12 @@
import re
-from sqlalchemy import schema, ansisql, exceptions, pool, PassiveDefault
-import sqlalchemy.engine.default as default
+from sqlalchemy import schema, exceptions, pool, PassiveDefault
+from sqlalchemy.engine import default
import sqlalchemy.types as sqltypes
import datetime,time, warnings
import sqlalchemy.util as util
+from sqlalchemy.sql import compiler
SELECT_REGEXP = re.compile(r'\s*(?:SELECT|PRAGMA)', re.I | re.UNICODE)
@@ -172,10 +173,10 @@ class SQLiteExecutionContext(default.DefaultExecutionContext):
def is_select(self):
return SELECT_REGEXP.match(self.statement)
-class SQLiteDialect(ansisql.ANSIDialect):
+class SQLiteDialect(default.DefaultDialect):
def __init__(self, **kwargs):
- ansisql.ANSIDialect.__init__(self, default_paramstyle='qmark', **kwargs)
+ default.DefaultDialect.__init__(self, default_paramstyle='qmark', **kwargs)
def vers(num):
return tuple([int(x) for x in num.split('.')])
if self.dbapi is not None:
@@ -195,24 +196,12 @@ class SQLiteDialect(ansisql.ANSIDialect):
return sqlite
dbapi = classmethod(dbapi)
- def compiler(self, statement, bindparams, **kwargs):
- return SQLiteCompiler(self, statement, bindparams, **kwargs)
-
- def schemagenerator(self, *args, **kwargs):
- return SQLiteSchemaGenerator(self, *args, **kwargs)
-
- def schemadropper(self, *args, **kwargs):
- return SQLiteSchemaDropper(self, *args, **kwargs)
-
def server_version_info(self, connection):
return self.dbapi.sqlite_version_info
def supports_alter(self):
return False
- def preparer(self):
- return SQLiteIdentifierPreparer(self)
-
def create_connect_args(self, url):
filename = url.database or ':memory:'
@@ -255,7 +244,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
return (row is not None)
def reflecttable(self, connection, table, include_columns):
- c = connection.execute("PRAGMA table_info(%s)" % self.preparer().format_table(table), {})
+ c = connection.execute("PRAGMA table_info(%s)" % self.identifier_preparer.format_table(table), {})
found_table = False
while True:
row = c.fetchone()
@@ -295,7 +284,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
if not found_table:
raise exceptions.NoSuchTableError(table.name)
- c = connection.execute("PRAGMA foreign_key_list(%s)" % self.preparer().format_table(table), {})
+ c = connection.execute("PRAGMA foreign_key_list(%s)" % self.identifier_preparer.format_table(table), {})
fks = {}
while True:
row = c.fetchone()
@@ -324,7 +313,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
for name, value in fks.iteritems():
table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1]))
# check for UNIQUE indexes
- c = connection.execute("PRAGMA index_list(%s)" % self.preparer().format_table(table), {})
+ c = connection.execute("PRAGMA index_list(%s)" % self.identifier_preparer.format_table(table), {})
unique_indexes = []
while True:
row = c.fetchone()
@@ -343,7 +332,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
cols.append(row[2])
col = table.columns[row[2]]
-class SQLiteCompiler(ansisql.ANSICompiler):
+class SQLiteCompiler(compiler.DefaultCompiler):
def visit_cast(self, cast):
if self.dialect.supports_cast:
return super(SQLiteCompiler, self).visit_cast(cast)
@@ -369,7 +358,8 @@ class SQLiteCompiler(ansisql.ANSICompiler):
# sqlite has no "FOR UPDATE" AFAICT
return ''
-class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
+
+class SQLiteSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
@@ -391,12 +381,17 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
# else:
# super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint)
-class SQLiteSchemaDropper(ansisql.ANSISchemaDropper):
+class SQLiteSchemaDropper(compiler.SchemaDropper):
pass
-class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
def __init__(self, dialect):
super(SQLiteIdentifierPreparer, self).__init__(dialect, omit_schema=True)
dialect = SQLiteDialect
dialect.poolclass = pool.SingletonThreadPool
+dialect.statement_compiler = SQLiteCompiler
+dialect.schemagenerator = SQLiteSchemaGenerator
+dialect.schemadropper = SQLiteSchemaDropper
+dialect.preparer = SQLiteIdentifierPreparer
+
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index faeb00cc9..553c8df84 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -8,7 +8,8 @@
higher-level statement-construction, connection-management,
execution and result contexts."""
-from sqlalchemy import exceptions, sql, schema, util, types, logging
+from sqlalchemy import exceptions, schema, util, types, logging
+from sqlalchemy.sql import expression, visitors
import StringIO, sys
@@ -35,6 +36,22 @@ class Dialect(object):
encoding
type of encoding to use for unicode, usually defaults to 'utf-8'
+
+ schemagenerator
+ a [sqlalchemy.schema#SchemaVisitor] class which generates schemas.
+
+ schemadropper
+ a [sqlalchemy.schema#SchemaVisitor] class which drops schemas.
+
+ defaultrunner
+ a [sqlalchemy.schema#SchemaVisitor] class which executes defaults.
+
+ statement_compiler
+ a [sqlalchemy.engine.base#Compiled] class used to compile SQL statements
+
+ preparer
+ a [sqlalchemy.sql.compiler#IdentifierPreparer] class used to quote
+ identifiers.
"""
def create_connect_args(self, url):
@@ -105,48 +122,6 @@ class Dialect(object):
raise NotImplementedError()
- def schemagenerator(self, connection, **kwargs):
- """Return a [sqlalchemy.schema#SchemaVisitor] instance that can generate schemas.
-
- connection
- a [sqlalchemy.engine#Connection] to use for statement execution
-
- `schemagenerator()` is called via the `create()` method on Table,
- Index, and others.
- """
-
- raise NotImplementedError()
-
- def schemadropper(self, connection, **kwargs):
- """Return a [sqlalchemy.schema#SchemaVisitor] instance that can drop schemas.
-
- connection
- a [sqlalchemy.engine#Connection] to use for statement execution
-
- `schemadropper()` is called via the `drop()` method on Table,
- Index, and others.
- """
-
- raise NotImplementedError()
-
- def defaultrunner(self, execution_context):
- """Return a [sqlalchemy.schema#SchemaVisitor] instance that can execute defaults.
-
- execution_context
- a [sqlalchemy.engine#ExecutionContext] to use for statement execution
-
- """
-
- raise NotImplementedError()
-
- def compiler(self, statement, parameters):
- """Return a [sqlalchemy.sql#Compiled] object for the given statement/parameters.
-
- The returned object is usually a subclass of [sqlalchemy.ansisql#ANSICompiler].
-
- """
-
- raise NotImplementedError()
def server_version_info(self, connection):
"""Return a tuple of the database's version number."""
@@ -266,16 +241,6 @@ class Dialect(object):
raise NotImplementedError()
-
- def compile(self, clauseelement, parameters=None):
- """Compile the given [sqlalchemy.sql#ClauseElement] using this Dialect.
-
- Returns [sqlalchemy.sql#Compiled]. A convenience method which
- flips around the compile() call on ``ClauseElement``.
- """
-
- return clauseelement.compile(dialect=self, parameters=parameters)
-
def is_disconnect(self, e):
"""Return True if the given DBAPI error indicates an invalid connection"""
@@ -304,7 +269,7 @@ class ExecutionContext(object):
DBAPI cursor procured from the connection
compiled
- if passed to constructor, sql.Compiled object being executed
+ if passed to constructor, sqlalchemy.engine.base.Compiled object being executed
statement
string version of the statement to be executed. Is either
@@ -439,6 +404,9 @@ class Compiled(object):
def __init__(self, dialect, statement, parameters, bind=None):
"""Construct a new ``Compiled`` object.
+ dialect
+ ``Dialect`` to compile against.
+
statement
``ClauseElement`` to be compiled.
@@ -724,8 +692,8 @@ class Connection(Connectable):
def scalar(self, object, *multiparams, **params):
return self.execute(object, *multiparams, **params).scalar()
- def compiler(self, statement, parameters, **kwargs):
- return self.dialect.compiler(statement, parameters, bind=self, **kwargs)
+ def statement_compiler(self, statement, parameters, **kwargs):
+ return self.dialect.statement_compiler(self.dialect, statement, parameters, bind=self, **kwargs)
def execute(self, object, *multiparams, **params):
for c in type(object).__mro__:
@@ -822,9 +790,9 @@ class Connection(Connectable):
# poor man's multimethod/generic function thingy
executors = {
- sql._Function : _execute_function,
- sql.ClauseElement : _execute_clauseelement,
- sql.ClauseVisitor : _execute_compiled,
+ expression._Function : _execute_function,
+ expression.ClauseElement : _execute_clauseelement,
+ visitors.ClauseVisitor : _execute_compiled,
schema.SchemaItem:_execute_default,
str.__mro__[-2] : _execute_text
}
@@ -989,14 +957,14 @@ class Engine(Connectable):
connection.close()
def _func(self):
- return sql._FunctionGenerator(bind=self)
+ return expression._FunctionGenerator(bind=self)
func = property(_func)
def text(self, text, *args, **kwargs):
"""Return a sql.text() object for performing literal queries."""
- return sql.text(text, bind=self, *args, **kwargs)
+ return expression.text(text, bind=self, *args, **kwargs)
def _run_visitor(self, visitorcallable, element, connection=None, **kwargs):
if connection is None:
@@ -1004,7 +972,7 @@ class Engine(Connectable):
else:
conn = connection
try:
- visitorcallable(conn, **kwargs).traverse(element)
+ visitorcallable(self.dialect, conn, **kwargs).traverse(element)
finally:
if connection is None:
conn.close()
@@ -1057,8 +1025,8 @@ class Engine(Connectable):
connection = self.contextual_connect(close_with_result=True)
return connection._execute_compiled(compiled, multiparams, params)
- def compiler(self, statement, parameters, **kwargs):
- return self.dialect.compiler(statement, parameters, bind=self, **kwargs)
+ def statement_compiler(self, statement, parameters, **kwargs):
+ return self.dialect.statement_compiler(self.dialect, statement, parameters, bind=self, **kwargs)
def connect(self, **kwargs):
"""Return a newly allocated Connection object."""
@@ -1159,6 +1127,7 @@ class ResultProxy(object):
self.closed = False
self.cursor = context.cursor
self.__echo = logging.is_debug_enabled(context.engine.logger)
+ self._process_row = self._row_processor()
if context.is_select():
self._init_metadata()
self._rowcount = None
@@ -1222,7 +1191,7 @@ class ResultProxy(object):
rec = props[key]
elif isinstance(key, basestring) and key.lower() in props:
rec = props[key.lower()]
- elif isinstance(key, sql.ColumnElement):
+ elif isinstance(key, expression.ColumnElement):
label = context.column_labels.get(key._label, key.name).lower()
if label in props:
rec = props[label]
@@ -1320,21 +1289,21 @@ class ResultProxy(object):
return self.cursor.fetchmany(size)
def _fetchall_impl(self):
return self.cursor.fetchall()
+
+ def _row_processor(self):
+ return RowProxy
- def _process_row(self, row):
- return RowProxy(self, row)
-
def fetchall(self):
"""Fetch all rows, just like DBAPI ``cursor.fetchall()``."""
- l = [self._process_row(row) for row in self._fetchall_impl()]
+ l = [self._process_row(self, row) for row in self._fetchall_impl()]
self.close()
return l
def fetchmany(self, size=None):
"""Fetch many rows, just like DBAPI ``cursor.fetchmany(size=cursor.arraysize)``."""
- l = [self._process_row(row) for row in self._fetchmany_impl(size)]
+ l = [self._process_row(self, row) for row in self._fetchmany_impl(size)]
if len(l) == 0:
self.close()
return l
@@ -1343,7 +1312,7 @@ class ResultProxy(object):
"""Fetch one row, just like DBAPI ``cursor.fetchone()``."""
row = self._fetchone_impl()
if row is not None:
- return self._process_row(row)
+ return self._process_row(self, row)
else:
self.close()
return None
@@ -1353,7 +1322,7 @@ class ResultProxy(object):
row = self._fetchone_impl()
try:
if row is not None:
- return self._process_row(row)[0]
+ return self._process_row(self, row)[0]
else:
return None
finally:
@@ -1425,11 +1394,9 @@ class BufferedColumnResultProxy(ResultProxy):
def _get_col(self, row, key):
rec = self._key_cache[key]
return row[rec[2]]
-
- def _process_row(self, row):
- sup = super(BufferedColumnResultProxy, self)
- row = [sup._get_col(row, i) for i in xrange(len(row))]
- return RowProxy(self, row)
+
+ def _row_processor(self):
+ return BufferedColumnRow
def fetchall(self):
l = []
@@ -1523,6 +1490,11 @@ class RowProxy(object):
def __len__(self):
return len(self.__row)
+class BufferedColumnRow(RowProxy):
+ def __init__(self, parent, row):
+ row = [ResultProxy._get_col(parent, row, i) for i in xrange(len(row))]
+ super(BufferedColumnRow, self).__init__(parent, row)
+
class SchemaIterator(schema.SchemaVisitor):
"""A visitor that can gather text into a buffer and execute the contents of the buffer."""
@@ -1590,11 +1562,11 @@ class DefaultRunner(schema.SchemaVisitor):
return None
def exec_default_sql(self, default):
- c = sql.select([default.arg]).compile(bind=self.connection)
+ c = expression.select([default.arg]).compile(bind=self.connection)
return self.connection._execute_compiled(c).scalar()
def visit_column_onupdate(self, onupdate):
- if isinstance(onupdate.arg, sql.ClauseElement):
+ if isinstance(onupdate.arg, expression.ClauseElement):
return self.exec_default_sql(onupdate)
elif callable(onupdate.arg):
return onupdate.arg(self.context)
@@ -1602,7 +1574,7 @@ class DefaultRunner(schema.SchemaVisitor):
return onupdate.arg
def visit_column_default(self, default):
- if isinstance(default.arg, sql.ClauseElement):
+ if isinstance(default.arg, expression.ClauseElement):
return self.exec_default_sql(default)
elif callable(default.arg):
return default.arg(self.context)
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index ccaf080e7..059395921 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -9,7 +9,7 @@
from sqlalchemy import schema, exceptions, sql, util
import re, random
from sqlalchemy.engine import base
-
+from sqlalchemy.sql import compiler, expression
AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
re.I | re.UNICODE)
@@ -18,6 +18,12 @@ SELECT_REGEXP = re.compile(r'\s*SELECT', re.I | re.UNICODE)
class DefaultDialect(base.Dialect):
"""Default implementation of Dialect"""
+ schemagenerator = compiler.SchemaGenerator
+ schemadropper = compiler.SchemaDropper
+ statement_compiler = compiler.DefaultCompiler
+ preparer = compiler.IdentifierPreparer
+ defaultrunner = base.DefaultRunner
+
def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', paramstyle=None, dbapi=None, **kwargs):
self.convert_unicode = convert_unicode
self.encoding = encoding
@@ -25,6 +31,7 @@ class DefaultDialect(base.Dialect):
self._ischema = None
self.dbapi = dbapi
self._figure_paramstyle(paramstyle=paramstyle, default=default_paramstyle)
+ self.identifier_preparer = self.preparer(self)
def dbapi_type_map(self):
# most DBAPIs have problems with this (such as, psycocpg2 types
@@ -46,6 +53,7 @@ class DefaultDialect(base.Dialect):
typeobj = typeobj()
return typeobj
+
def supports_unicode_statements(self):
"""indicate whether the DBAPI can receive SQL statements as Python unicode strings"""
return False
@@ -96,13 +104,13 @@ class DefaultDialect(base.Dialect):
return "_sa_%032x" % random.randint(0,2**128)
def do_savepoint(self, connection, name):
- connection.execute(sql.SavepointClause(name))
+ connection.execute(expression.SavepointClause(name))
def do_rollback_to_savepoint(self, connection, name):
- connection.execute(sql.RollbackToSavepointClause(name))
+ connection.execute(expression.RollbackToSavepointClause(name))
def do_release_savepoint(self, connection, name):
- connection.execute(sql.ReleaseSavepointClause(name))
+ connection.execute(expression.ReleaseSavepointClause(name))
def do_executemany(self, cursor, statement, parameters, **kwargs):
cursor.executemany(statement, parameters)
@@ -110,8 +118,6 @@ class DefaultDialect(base.Dialect):
def do_execute(self, cursor, statement, parameters, **kwargs):
cursor.execute(statement, parameters)
- def defaultrunner(self, context):
- return base.DefaultRunner(context)
def is_disconnect(self, e):
return False
diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py
index 65d2ab3d2..258bddb4a 100644
--- a/lib/sqlalchemy/ext/sqlsoup.py
+++ b/lib/sqlalchemy/ext/sqlsoup.py
@@ -294,7 +294,7 @@ from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.ext.sessioncontext import SessionContext
from sqlalchemy.exceptions import *
-
+from sqlalchemy.sql import expression
_testsql = """
CREATE TABLE books (
@@ -415,7 +415,7 @@ def _selectable_name(selectable):
return x
def class_for_table(selectable, **mapper_kwargs):
- selectable = sql._selectable(selectable)
+ selectable = expression._selectable(selectable)
mapname = 'Mapped' + _selectable_name(selectable)
if isinstance(selectable, Table):
klass = TableClassType(mapname, (object,), {})
@@ -499,7 +499,7 @@ class SqlSoup:
def with_labels(self, item):
# TODO give meaningful aliases
- return self.map(sql._selectable(item).select(use_labels=True).alias('foo'))
+ return self.map(expression._selectable(item).select(use_labels=True).alias('foo'))
def join(self, *args, **kwargs):
j = join(*args, **kwargs)
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index c54eee438..9000a8df5 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -6,6 +6,7 @@
from sqlalchemy import util, logging, sql
+from sqlalchemy.sql import expression
__all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension',
'MapperProperty', 'PropComparator', 'StrategizedProperty',
@@ -363,7 +364,7 @@ class MapperProperty(object):
return operator(self.comparator, value)
-class PropComparator(sql.ColumnOperators):
+class PropComparator(expression.ColumnOperators):
"""defines comparison operations for MapperProperty objects"""
def expression_element(self):
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index b48368417..60d4526ec 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -6,7 +6,8 @@
import weakref, warnings, operator
from sqlalchemy import sql, util, exceptions, logging
-from sqlalchemy import sql_util as sqlutil
+from sqlalchemy.sql import expression
+from sqlalchemy.sql import util as sqlutil
from sqlalchemy.orm import util as mapperutil
from sqlalchemy.orm.util import ExtensionCarrier
from sqlalchemy.orm import sync
@@ -77,7 +78,7 @@ class Mapper(object):
raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__)
for table in (local_table, select_table):
- if table is not None and isinstance(table, sql._SelectBaseMixin):
+ if table is not None and isinstance(table, expression._SelectBaseMixin):
# some db's, noteably postgres, dont want to select from a select
# without an alias. also if we make our own alias internally, then
# the configured properties on the mapper are not matched against the alias
@@ -438,7 +439,7 @@ class Mapper(object):
# against the "mapped_table" of this mapper.
equivalent_columns = self._get_equivalent_columns()
- primary_key = sql.ColumnSet()
+ primary_key = expression.ColumnSet()
for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]):
c = self.mapped_table.corresponding_column(col, raiseerr=False)
@@ -644,9 +645,9 @@ class Mapper(object):
props = {}
if self.properties is not None:
for key, prop in self.properties.iteritems():
- if sql.is_column(prop):
+ if expression.is_column(prop):
props[key] = self.select_table.corresponding_column(prop)
- elif (isinstance(prop, list) and sql.is_column(prop[0])):
+ elif (isinstance(prop, list) and expression.is_column(prop[0])):
props[key] = [self.select_table.corresponding_column(c) for c in prop]
self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, properties=props, _polymorphic_map=self.polymorphic_map, polymorphic_on=self.select_table.corresponding_column(self.polymorphic_on), primary_key=self.primary_key_argument)
@@ -768,7 +769,7 @@ class Mapper(object):
def _create_prop_from_column(self, column):
column = util.to_list(column)
- if not sql.is_column(column[0]):
+ if not expression.is_column(column[0]):
return None
mapped_column = []
for c in column:
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 670fcccc9..20cbcb235 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -11,7 +11,8 @@ operations. PropertyLoader also relies upon the dependency.py module
to handle flush-time dependency sorting and processing.
"""
-from sqlalchemy import sql, schema, util, exceptions, sql_util, logging
+from sqlalchemy import sql, schema, util, exceptions, logging
+from sqlalchemy.sql import util as sql_util
from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency
from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 44329468f..5cbe19ce2 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -4,7 +4,9 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import sql, util, exceptions, sql_util, logging
+from sqlalchemy import sql, util, exceptions, logging
+from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql import expression, visitors
from sqlalchemy.orm import mapper, object_mapper
from sqlalchemy.orm import util as mapperutil
from sqlalchemy.orm.interfaces import OperationContext, LoaderStack
@@ -312,7 +314,7 @@ class Query(object):
clause = self._from_obj[-1]
currenttables = [clause]
- class FindJoinedTables(sql.NoColumnVisitor):
+ class FindJoinedTables(visitors.NoColumnVisitor):
def visit_join(self, join):
currenttables.append(join.left)
currenttables.append(join.right)
@@ -836,7 +838,7 @@ class Query(object):
# if theres an order by, add those columns to the column list
# of the "rowcount" query we're going to make
if order_by:
- order_by = [sql._literal_as_text(o) for o in util.to_list(order_by) or []]
+ order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []]
cf = sql_util.ColumnFinder()
for o in order_by:
cf.traverse(o)
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 6565c8d77..bdb17e1d6 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -6,7 +6,9 @@
"""sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions."""
-from sqlalchemy import sql, util, exceptions, sql_util, logging
+from sqlalchemy import sql, util, exceptions, logging
+from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql import visitors
from sqlalchemy.orm import mapper, attributes
from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption
from sqlalchemy.orm import session as sessionlib
@@ -292,7 +294,7 @@ class LazyLoader(AbstractRelationLoader):
(criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
- class Visitor(sql.ClauseVisitor):
+ class Visitor(visitors.ClauseVisitor):
def visit_bindparam(s, bindparam):
mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
if bindparam.key in bind_to_col:
@@ -396,7 +398,7 @@ class LazyLoader(AbstractRelationLoader):
if not isinstance(expr, sql.ColumnElement):
return None
columns = []
- class FindColumnInColumnClause(sql.ClauseVisitor):
+ class FindColumnInColumnClause(visitors.ClauseVisitor):
def visit_column(self, c):
columns.append(c)
FindColumnInColumnClause().traverse(expr)
diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py
index cf48202b0..49661a95e 100644
--- a/lib/sqlalchemy/orm/sync.py
+++ b/lib/sqlalchemy/orm/sync.py
@@ -10,9 +10,9 @@ clause that compares column values.
"""
from sqlalchemy import sql, schema, exceptions
+from sqlalchemy.sql import visitors, operators
from sqlalchemy import logging
from sqlalchemy.orm import util as mapperutil
-import operator
ONETOMANY = 0
MANYTOONE = 1
@@ -43,7 +43,7 @@ class ClauseSynchronizer(object):
def compile_binary(binary):
"""Assemble a SyncRule given a single binary condition."""
- if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+ if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
source_column = None
@@ -144,7 +144,7 @@ class SyncRule(object):
SyncRule.logger = logging.class_logger(SyncRule)
-class BinaryVisitor(sql.ClauseVisitor):
+class BinaryVisitor(visitors.ClauseVisitor):
def __init__(self, func):
self.func = func
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 6b0956dc4..b3f58c954 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -4,7 +4,9 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import sql, util, exceptions, sql_util
+from sqlalchemy import sql, util, exceptions
+from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql import visitors
from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE
all_cascades = util.Set(["delete", "delete-orphan", "all", "merge",
@@ -161,7 +163,7 @@ class ExtensionCarrier(MapperExtension):
before_delete = _create_do('before_delete')
after_delete = _create_do('after_delete')
-class BinaryVisitor(sql.ClauseVisitor):
+class BinaryVisitor(visitors.ClauseVisitor):
def __init__(self, func):
self.func = func
@@ -196,7 +198,7 @@ class AliasedClauses(object):
# for column-level subqueries, swap out its selectable with our
# eager version as appropriate, and manually build the
# "correlation" list of the subquery.
- class ModifySubquery(sql.ClauseVisitor):
+ class ModifySubquery(visitors.ClauseVisitor):
def visit_select(s, select):
select._should_correlate = False
select.append_correlation(self.alias)
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 99803d665..99ca2389b 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -18,8 +18,11 @@ objects as well as the visitor interface, so that the schema package
"""
import re, inspect
-from sqlalchemy import sql, types, exceptions, util, databases
+from sqlalchemy import types, exceptions, util, databases
+from sqlalchemy.sql import expression, visitors
import sqlalchemy
+
+
URL = None
__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index',
@@ -31,7 +34,7 @@ __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index',
class SchemaItem(object):
"""Base class for items that define a database schema."""
- __metaclass__ = sql._FigureVisitName
+ __metaclass__ = expression._FigureVisitName
def _init_items(self, *args):
"""Initialize the list of child items for this SchemaItem."""
@@ -84,7 +87,7 @@ def _get_table_key(name, schema):
else:
return schema + "." + name
-class _TableSingleton(sql._FigureVisitName):
+class _TableSingleton(expression._FigureVisitName):
"""A metaclass used by the ``Table`` object to provide singleton behavior."""
def __call__(self, name, metadata, *args, **kwargs):
@@ -124,10 +127,10 @@ class _TableSingleton(sql._FigureVisitName):
return table
-class Table(SchemaItem, sql.TableClause):
+class Table(SchemaItem, expression.TableClause):
"""Represent a relational database table.
- This subclasses ``sql.TableClause`` to provide a table that is
+ This subclasses ``expression.TableClause`` to provide a table that is
associated with an instance of ``MetaData``, which in turn
may be associated with an instance of ``Engine``.
@@ -229,7 +232,7 @@ class Table(SchemaItem, sql.TableClause):
self.schema = kwargs.pop('schema', None)
self.indexes = util.Set()
self.constraints = util.Set()
- self._columns = sql.ColumnCollection()
+ self._columns = expression.ColumnCollection()
self.primary_key = PrimaryKeyConstraint()
self._foreign_keys = util.OrderedSet()
self.quote = kwargs.pop('quote', False)
@@ -291,7 +294,7 @@ class Table(SchemaItem, sql.TableClause):
def get_children(self, column_collections=True, schema_visitor=False, **kwargs):
if not schema_visitor:
- return sql.TableClause.get_children(self, column_collections=column_collections, **kwargs)
+ return expression.TableClause.get_children(self, column_collections=column_collections, **kwargs)
else:
if column_collections:
return [c for c in self.columns]
@@ -338,10 +341,10 @@ class Table(SchemaItem, sql.TableClause):
args.append(c.copy())
return Table(self.name, metadata, schema=schema, *args)
-class Column(SchemaItem, sql._ColumnClause):
+class Column(SchemaItem, expression._ColumnClause):
"""Represent a column in a database table.
- This is a subclass of ``sql.ColumnClause`` and represents an
+ This is a subclass of ``expression.ColumnClause`` and represents an
actual existing table in the database, in a similar fashion as
``TableClause``/``Table``.
"""
@@ -575,7 +578,7 @@ class Column(SchemaItem, sql._ColumnClause):
return [x for x in (self.default, self.onupdate) if x is not None] + \
list(self.foreign_keys) + list(self.constraints)
else:
- return sql._ColumnClause.get_children(self, **kwargs)
+ return expression._ColumnClause.get_children(self, **kwargs)
class ForeignKey(SchemaItem):
@@ -806,7 +809,7 @@ class Constraint(SchemaItem):
def __init__(self, name=None):
self.name = name
- self.columns = sql.ColumnCollection()
+ self.columns = expression.ColumnCollection()
def __contains__(self, x):
return self.columns.contains_column(x)
@@ -1124,12 +1127,12 @@ class MetaData(SchemaItem):
del self.tables[table.key]
def table_iterator(self, reverse=True, tables=None):
- import sqlalchemy.sql_util
+ from sqlalchemy.sql import util as sql_util
if tables is None:
tables = self.tables.values()
else:
tables = util.Set(tables).intersection(self.tables.values())
- sorter = sqlalchemy.sql_util.TableCollection(list(tables))
+ sorter = sql_util.TableCollection(list(tables))
return iter(sorter.sort(reverse=reverse))
def _get_parent(self):
@@ -1356,7 +1359,7 @@ class ThreadLocalMetaData(MetaData):
e.dispose()
-class SchemaVisitor(sql.ClauseVisitor):
+class SchemaVisitor(visitors.ClauseVisitor):
"""Define the visiting for ``SchemaItem`` objects."""
__traverse_options__ = {'schema_visitor':True}
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
new file mode 100644
index 000000000..06b9f1f75
--- /dev/null
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -0,0 +1,3 @@
+from sqlalchemy.sql.expression import *
+from sqlalchemy.sql.visitors import ClauseVisitor, NoColumnVisitor
+
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/sql/compiler.py
index 5f5e1c171..6053c72be 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1,22 +1,18 @@
-# ansisql.py
+# compiler.py
# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""Defines ANSI SQL operations.
+"""SQL expression compilation routines and DDL implementations."""
-Contains default implementations for the abstract objects in the sql
-module.
-"""
+import string, re
+from sqlalchemy import schema, engine, util, exceptions
+from sqlalchemy.sql import operators, visitors
+from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql import expression as sql
-import string, re, sets, operator
-
-from sqlalchemy import schema, sql, engine, util, exceptions, operators
-from sqlalchemy.engine import default
-
-
-ANSI_FUNCS = sets.ImmutableSet([
+ANSI_FUNCS = util.Set([
'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP',
'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP',
'SESSION_USER', 'USER'])
@@ -77,7 +73,6 @@ OPERATORS = {
operators.comma_op : ', ',
operators.desc_op : 'DESC',
operators.asc_op : 'ASC',
-
operators.from_ : 'FROM',
operators.as_ : 'AS',
operators.exists : 'EXISTS',
@@ -85,36 +80,10 @@ OPERATORS = {
operators.isnot : 'IS NOT'
}
-class ANSIDialect(default.DefaultDialect):
- def __init__(self, cache_identifiers=True, **kwargs):
- super(ANSIDialect,self).__init__(**kwargs)
- self.identifier_preparer = self.preparer()
- self.cache_identifiers = cache_identifiers
-
- def create_connect_args(self):
- return ([],{})
-
- def schemagenerator(self, *args, **kwargs):
- return ANSISchemaGenerator(self, *args, **kwargs)
-
- def schemadropper(self, *args, **kwargs):
- return ANSISchemaDropper(self, *args, **kwargs)
-
- def compiler(self, statement, parameters, **kwargs):
- return ANSICompiler(self, statement, parameters, **kwargs)
-
- def preparer(self):
- """Return an IdentifierPreparer.
-
- This object is used to format table and column names including
- proper quoting and case conventions.
- """
- return ANSIIdentifierPreparer(self)
-
-class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
+class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
"""Default implementation of Compiled.
- Compiles ClauseElements into ANSI-compliant SQL strings.
+ Compiles ClauseElements into SQL strings.
"""
__traverse_options__ = {'column_collections':False, 'entry':True}
@@ -122,7 +91,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
operators = OPERATORS
def __init__(self, dialect, statement, parameters=None, **kwargs):
- """Construct a new ``ANSICompiler`` object.
+ """Construct a new ``DefaultCompiler`` object.
dialect
Dialect to be used
@@ -139,7 +108,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
correspond to the keys present in the parameters.
"""
- super(ANSICompiler, self).__init__(dialect, statement, parameters, **kwargs)
+ super(DefaultCompiler, self).__init__(dialect, statement, parameters, **kwargs)
# if we are insert/update. set to true when we visit an INSERT or UPDATE
self.isinsert = self.isupdate = False
@@ -170,17 +139,17 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
self.bindtemplate = ":%s"
# paramstyle from the dialect (comes from DBAPI)
- self.paramstyle = dialect.paramstyle
+ self.paramstyle = self.dialect.paramstyle
# true if the paramstyle is positional
- self.positional = dialect.positional
+ self.positional = self.dialect.positional
# a list of the compiled's bind parameter names, used to help
# formulate a positional argument list
self.positiontup = []
- # an ANSIIdentifierPreparer that formats the quoting of identifiers
- self.preparer = dialect.identifier_preparer
+ # an IdentifierPreparer that formats the quoting of identifiers
+ self.preparer = self.dialect.identifier_preparer
# for UPDATE and INSERT statements, a set of columns whos values are being set
# from a SQL expression (i.e., not one of the bind parameter values). if present,
@@ -244,7 +213,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
return None
def construct_params(self, params):
- """Return a sql.ClauseParameters object.
+ """Return a sql.util.ClauseParameters object.
Combines the given bind parameter dictionary (string keys to object values)
with the _BindParamClause objects stored within this Compiled object
@@ -252,7 +221,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
for a single statement execution, or one element of an executemany execution.
"""
- d = sql.ClauseParameters(self.dialect, self.positiontup)
+ d = sql_util.ClauseParameters(self.dialect, self.positiontup)
pd = self.parameters or {}
pd.update(params)
@@ -781,7 +750,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
def __str__(self):
return self.string
-class ANSISchemaBase(engine.SchemaIterator):
+class DDLBase(engine.SchemaIterator):
def find_alterables(self, tables):
alterables = []
class FindAlterables(schema.SchemaVisitor):
@@ -794,12 +763,12 @@ class ANSISchemaBase(engine.SchemaIterator):
findalterables.traverse(c)
return alterables
-class ANSISchemaGenerator(ANSISchemaBase):
+class SchemaGenerator(DDLBase):
def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
- super(ANSISchemaGenerator, self).__init__(connection, **kwargs)
+ super(SchemaGenerator, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables and util.Set(tables) or None
- self.preparer = dialect.preparer()
+ self.preparer = dialect.identifier_preparer
self.dialect = dialect
def get_column_specification(self, column, first_pk=False):
@@ -860,7 +829,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
def _compile(self, tocompile, parameters):
"""compile the given string/parameters using this SchemaGenerator's dialect."""
- compiler = self.dialect.compiler(tocompile, parameters)
+ compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters)
compiler.compile()
return compiler
@@ -930,12 +899,12 @@ class ANSISchemaGenerator(ANSISchemaBase):
string.join([preparer.format_column(c) for c in index.columns], ', ')))
self.execute()
-class ANSISchemaDropper(ANSISchemaBase):
+class SchemaDropper(DDLBase):
def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
- super(ANSISchemaDropper, self).__init__(connection, **kwargs)
+ super(SchemaDropper, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables
- self.preparer = dialect.preparer()
+ self.preparer = dialect.identifier_preparer
self.dialect = dialect
def visit_metadata(self, metadata):
@@ -964,14 +933,11 @@ class ANSISchemaDropper(ANSISchemaBase):
self.append("\nDROP TABLE " + self.preparer.format_table(table))
self.execute()
-class ANSIDefaultRunner(engine.DefaultRunner):
- pass
-
-class ANSIIdentifierPreparer(object):
+class IdentifierPreparer(object):
"""Handle quoting and case-folding of identifiers based on options."""
def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False):
- """Construct a new ``ANSIIdentifierPreparer`` object.
+ """Construct a new ``IdentifierPreparer`` object.
initial_quote
Character that begins a delimited identifier.
@@ -1049,20 +1015,14 @@ class ANSIIdentifierPreparer(object):
def __generic_obj_format(self, obj, ident):
if getattr(obj, 'quote', False):
return self.quote_identifier(ident)
- if self.dialect.cache_identifiers:
- try:
- return self.__strings[ident]
- except KeyError:
- if self._requires_quotes(ident):
- self.__strings[ident] = self.quote_identifier(ident)
- else:
- self.__strings[ident] = ident
- return self.__strings[ident]
- else:
+ try:
+ return self.__strings[ident]
+ except KeyError:
if self._requires_quotes(ident):
- return self.quote_identifier(ident)
+ self.__strings[ident] = self.quote_identifier(ident)
else:
- return ident
+ self.__strings[ident] = ident
+ return self.__strings[ident]
def should_quote(self, object):
return object.quote or self._requires_quotes(object.name)
@@ -1152,5 +1112,3 @@ class ANSIIdentifierPreparer(object):
return [self._unescape_identifier(i)
for i in [a or b for a, b in r.findall(identifiers)]]
-
-dialect = ANSIDialect
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql/expression.py
index 7c73f7cb7..e117e3f47 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -26,13 +26,14 @@ classes usually have few or no public methods and are less guaranteed
to stay the same in future releases.
"""
-from sqlalchemy import util, exceptions, operators
+from sqlalchemy import util, exceptions
+from sqlalchemy.sql import operators, visitors
from sqlalchemy import types as sqltypes
import re
__all__ = [
- 'Alias', 'ClauseElement', 'ClauseParameters',
- 'ClauseVisitor', 'ColumnCollection', 'ColumnElement',
+ 'Alias', 'ClauseElement',
+ 'ColumnCollection', 'ColumnElement',
'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join',
'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc',
'between', 'bindparam', 'case', 'cast', 'column', 'delete',
@@ -810,187 +811,6 @@ def is_column(col):
"""True if ``col`` is an instance of ``ColumnElement``."""
return isinstance(col, ColumnElement)
-class ClauseParameters(object):
- """Represent a dictionary/iterator of bind parameter key names/values.
-
- Tracks the original [sqlalchemy.sql#_BindParamClause] objects
- as well as the keys/position of each parameter, and can return
- parameters as a dictionary or a list. Will process parameter
- values according to the ``TypeEngine`` objects present in the
- ``_BindParamClause`` instances.
- """
-
- def __init__(self, dialect, positional=None):
- self.dialect = dialect
- self.__binds = {}
- self.positional = positional or []
-
- def get_parameter(self, key):
- return self.__binds[key]
-
- def set_parameter(self, bindparam, value, name):
- self.__binds[name] = [bindparam, name, value]
-
- def get_original(self, key):
- return self.__binds[key][2]
-
- def get_type(self, key):
- return self.__binds[key][0].type
-
- def get_processors(self):
- """return a dictionary of bind 'processing' functions"""
- return dict([
- (key, value) for key, value in
- [(
- key,
- self.__binds[key][0].bind_processor(self.dialect)
- ) for key in self.__binds]
- if value is not None
- ])
-
- def get_processed(self, key, processors):
- return key in processors and processors[key](self.__binds[key][2]) or self.__binds[key][2]
-
- def keys(self):
- return self.__binds.keys()
-
- def __iter__(self):
- return iter(self.keys())
-
- def __getitem__(self, key):
- (bind, name, value) = self.__binds[key]
- processor = bind.bind_processor(self.dialect)
- return processor is not None and processor(value) or value
-
- def __contains__(self, key):
- return key in self.__binds
-
- def set_value(self, key, value):
- self.__binds[key][2] = value
-
- def get_original_dict(self):
- return dict([(name, value) for (b, name, value) in self.__binds.values()])
-
- def __get_processed(self, key, processors):
- if key in processors:
- return processors[key](self.__binds[key][2])
- else:
- return self.__binds[key][2]
-
- def get_raw_list(self, processors):
- return [self.__get_processed(key, processors) for key in self.positional]
-
- def get_raw_dict(self, processors, encode_keys=False):
- if encode_keys:
- return dict([
- (
- key.encode(self.dialect.encoding),
- self.__get_processed(key, processors)
- )
- for key in self.keys()
- ])
- else:
- return dict([
- (
- key,
- self.__get_processed(key, processors)
- )
- for key in self.keys()
- ])
-
- def __repr__(self):
- return self.__class__.__name__ + ":" + repr(self.get_original_dict())
-
-class ClauseVisitor(object):
- """A class that knows how to traverse and visit ``ClauseElements``.
-
- Calls visit_XXX() methods dynamically generated for each
- particualr ``ClauseElement`` subclass encountered. Traversal of a
- hierarchy of ``ClauseElements`` is achieved via the ``traverse()``
- method, which is passed the lead ``ClauseElement``.
-
- By default, ``ClauseVisitor`` traverses all elements fully.
- Options can be specified at the class level via the
- ``__traverse_options__`` dictionary which will be passed to the
- ``get_children()`` method of each ``ClauseElement``; these options
- can indicate modifications to the set of elements returned, such
- as to not return column collections (column_collections=False) or
- to return Schema-level items (schema_visitor=True).
-
- ``ClauseVisitor`` also supports a simultaneous copy-and-traverse
- operation, which will produce a copy of a given ``ClauseElement``
- structure while at the same time allowing ``ClauseVisitor``
- subclasses to modify the new structure in-place.
- """
-
- __traverse_options__ = {}
-
- def traverse_single(self, obj, **kwargs):
- meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
- if meth:
- return meth(obj, **kwargs)
-
- def iterate(self, obj, stop_on=None):
- stack = [obj]
- traversal = []
- while len(stack) > 0:
- t = stack.pop()
- if stop_on is None or t not in stop_on:
- yield t
- traversal.insert(0, t)
- for c in t.get_children(**self.__traverse_options__):
- stack.append(c)
-
- def traverse(self, obj, stop_on=None, clone=False):
- if clone:
- obj = obj._clone()
-
- stack = [obj]
- traversal = []
- while len(stack) > 0:
- t = stack.pop()
- if stop_on is None or t not in stop_on:
- traversal.insert(0, t)
- if clone:
- t._copy_internals()
- for c in t.get_children(**self.__traverse_options__):
- stack.append(c)
- for target in traversal:
- v = self
- while v is not None:
- meth = getattr(v, "visit_%s" % target.__visit_name__, None)
- if meth:
- meth(target)
- v = getattr(v, '_next', None)
- return obj
-
- def chain(self, visitor):
- """'chain' an additional ClauseVisitor onto this ClauseVisitor.
-
- The chained visitor will receive all visit events after this one.
- """
-
- tail = self
- while getattr(tail, '_next', None) is not None:
- tail = tail._next
- tail._next = visitor
- return self
-
-class NoColumnVisitor(ClauseVisitor):
- """A ClauseVisitor that will not traverse exported column collections.
-
- Will not traverse the exported Column collections on Table, Alias,
- Select, and CompoundSelect objects (i.e. their 'columns' or 'c'
- attribute).
-
- This is useful because most traversals don't need those columns,
- or in the case of ANSICompiler it traverses them explicitly; so
- skipping their traversal here greatly cuts down on method call
- overhead.
- """
-
- __traverse_options__ = {'column_collections': False}
-
class _FigureVisitName(type):
def __init__(cls, clsname, bases, dict):
@@ -1061,7 +881,7 @@ class ClauseElement(object):
elif len(optionaldict) > 1:
raise exceptions.ArgumentError("params() takes zero or one positional dictionary argument")
- class Vis(ClauseVisitor):
+ class Vis(visitors.ClauseVisitor):
def visit_bindparam(self, bind):
if bind.key in kwargs:
bind.value = kwargs[bind.key]
@@ -1156,7 +976,7 @@ class ClauseElement(object):
if any.
Finally, if there is no bound ``Engine``, uses an
- ``ANSIDialect`` to create a default ``Compiler``.
+ ``DefaultDialect`` to create a default ``Compiler``.
`parameters` is a dictionary representing the default bind
parameters to be used with the statement. If `parameters` is
@@ -1175,15 +995,16 @@ class ClauseElement(object):
if compiler is None:
if dialect is not None:
- compiler = dialect.compiler(self, parameters)
+ compiler = dialect.statement_compiler(dialect, self, parameters)
elif bind is not None:
- compiler = bind.compiler(self, parameters)
+ compiler = bind.statement_compiler(self, parameters)
elif self.bind is not None:
- compiler = self.bind.compiler(self, parameters)
+ compiler = self.bind.statement_compiler(self, parameters)
if compiler is None:
- import sqlalchemy.ansisql as ansisql
- compiler = ansisql.ANSIDialect().compiler(self, parameters=parameters)
+ from sqlalchemy.engine.default import DefaultDialect
+ dialect = DefaultDialect()
+ compiler = dialect.statement_compiler(dialect, self, parameters=parameters)
compiler.compile()
return compiler
@@ -1727,7 +1548,7 @@ class FromClause(Selectable):
def _get_all_embedded_columns(self):
ret = []
- class FindCols(ClauseVisitor):
+ class FindCols(visitors.ClauseVisitor):
def visit_column(self, col):
ret.append(col)
FindCols().traverse(self)
@@ -1744,8 +1565,8 @@ class FromClause(Selectable):
def replace_selectable(self, old, alias):
"""replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``."""
- from sqlalchemy import sql_util
- return sql_util.ClauseAdapter(alias).traverse(self, clone=True)
+ from sqlalchemy.sql import util
+ return util.ClauseAdapter(alias).traverse(self, clone=True)
def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False):
"""Given a ``ColumnElement``, return the exported ``ColumnElement``
@@ -2376,7 +2197,7 @@ class Join(FromClause):
else:
equivs[x] = util.Set([y])
- class BinaryVisitor(ClauseVisitor):
+ class BinaryVisitor(visitors.ClauseVisitor):
def visit_binary(self, binary):
if binary.operator == operators.eq:
add_equiv(binary.left, binary.right)
@@ -2460,7 +2281,7 @@ class Join(FromClause):
return self.__folded_equivalents
if equivs is None:
equivs = util.Set()
- class LocateEquivs(NoColumnVisitor):
+ class LocateEquivs(visitors.NoColumnVisitor):
def visit_binary(self, binary):
if binary.operator == operators.eq and binary.left.name == binary.right.name:
equivs.add(binary.right)
@@ -3331,7 +3152,7 @@ class Select(_SelectBaseMixin, FromClause):
return intersect_all(self, other, **kwargs)
def _table_iterator(self):
- for t in NoColumnVisitor().iterate(self):
+ for t in visitors.NoColumnVisitor().iterate(self):
if isinstance(t, TableClause):
yield t
diff --git a/lib/sqlalchemy/operators.py b/lib/sqlalchemy/sql/operators.py
index b8aca3d26..b8aca3d26 100644
--- a/lib/sqlalchemy/operators.py
+++ b/lib/sqlalchemy/sql/operators.py
diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql/util.py
index cc6325822..2c7294e66 100644
--- a/lib/sqlalchemy/sql_util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -1,7 +1,100 @@
-from sqlalchemy import sql, util, schema, topological
+from sqlalchemy import util, schema, topological
+from sqlalchemy.sql import expression, visitors
"""Utility functions that build upon SQL and Schema constructs."""
+class ClauseParameters(object):
+ """Represent a dictionary/iterator of bind parameter key names/values.
+
+ Tracks the original [sqlalchemy.sql#_BindParamClause] objects as well as the
+ keys/position of each parameter, and can return parameters as a
+ dictionary or a list. Will process parameter values according to
+ the ``TypeEngine`` objects present in the ``_BindParamClause`` instances.
+ """
+
+ def __init__(self, dialect, positional=None):
+ self.dialect = dialect
+ self.__binds = {}
+ self.positional = positional or []
+
+ def get_parameter(self, key):
+ return self.__binds[key]
+
+ def set_parameter(self, bindparam, value, name):
+ self.__binds[name] = [bindparam, name, value]
+
+ def get_original(self, key):
+ return self.__binds[key][2]
+
+ def get_type(self, key):
+ return self.__binds[key][0].type
+
+ def get_processors(self):
+ """return a dictionary of bind 'processing' functions"""
+ return dict([
+ (key, value) for key, value in
+ [(
+ key,
+ self.__binds[key][0].bind_processor(self.dialect)
+ ) for key in self.__binds]
+ if value is not None
+ ])
+
+ def get_processed(self, key, processors):
+ return key in processors and processors[key](self.__binds[key][2]) or self.__binds[key][2]
+
+ def keys(self):
+ return self.__binds.keys()
+
+ def __iter__(self):
+ return iter(self.keys())
+
+ def __getitem__(self, key):
+ (bind, name, value) = self.__binds[key]
+ processor = bind.bind_processor(self.dialect)
+ return processor is not None and processor(value) or value
+
+ def __contains__(self, key):
+ return key in self.__binds
+
+ def set_value(self, key, value):
+ self.__binds[key][2] = value
+
+ def get_original_dict(self):
+ return dict([(name, value) for (b, name, value) in self.__binds.values()])
+
+ def __get_processed(self, key, processors):
+ if key in processors:
+ return processors[key](self.__binds[key][2])
+ else:
+ return self.__binds[key][2]
+
+ def get_raw_list(self, processors):
+ return [self.__get_processed(key, processors) for key in self.positional]
+
+ def get_raw_dict(self, processors, encode_keys=False):
+ if encode_keys:
+ return dict([
+ (
+ key.encode(self.dialect.encoding),
+ self.__get_processed(key, processors)
+ )
+ for key in self.keys()
+ ])
+ else:
+ return dict([
+ (
+ key,
+ self.__get_processed(key, processors)
+ )
+ for key in self.keys()
+ ])
+
+ def __repr__(self):
+ return self.__class__.__name__ + ":" + repr(self.get_original_dict())
+
+
+
class TableCollection(object):
def __init__(self, tables=None):
self.tables = tables or []
@@ -64,7 +157,7 @@ class TableCollection(object):
return sequence
-class TableFinder(TableCollection, sql.NoColumnVisitor):
+class TableFinder(TableCollection, visitors.NoColumnVisitor):
"""locate all Tables within a clause."""
def __init__(self, clause, check_columns=False, include_aliases=False):
@@ -85,7 +178,7 @@ class TableFinder(TableCollection, sql.NoColumnVisitor):
if self.check_columns:
self.tables.append(column.table)
-class ColumnFinder(sql.ClauseVisitor):
+class ColumnFinder(visitors.ClauseVisitor):
def __init__(self):
self.columns = util.Set()
@@ -95,7 +188,7 @@ class ColumnFinder(sql.ClauseVisitor):
def __iter__(self):
return iter(self.columns)
-class ColumnsInClause(sql.ClauseVisitor):
+class ColumnsInClause(visitors.ClauseVisitor):
"""Given a selectable, visit clauses and determine if any columns
from the clause are in the selectable.
"""
@@ -108,7 +201,7 @@ class ColumnsInClause(sql.ClauseVisitor):
if self.selectable.c.get(column.key) is column:
self.result = True
-class AbstractClauseProcessor(sql.NoColumnVisitor):
+class AbstractClauseProcessor(visitors.NoColumnVisitor):
"""Traverse a clause and attempt to convert the contents of container elements
to a converted element.
@@ -224,10 +317,10 @@ class ClauseAdapter(AbstractClauseProcessor):
self.equivalents = equivalents
def convert_element(self, col):
- if isinstance(col, sql.FromClause):
+ if isinstance(col, expression.FromClause):
if self.selectable.is_derived_from(col):
return self.selectable
- if not isinstance(col, sql.ColumnElement):
+ if not isinstance(col, expression.ColumnElement):
return None
if self.include is not None:
if col not in self.include:
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
new file mode 100644
index 000000000..98e4de6c3
--- /dev/null
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -0,0 +1,87 @@
+class ClauseVisitor(object):
+ """A class that knows how to traverse and visit
+ ``ClauseElements``.
+
+ Calls visit_XXX() methods dynamically generated for each particualr
+ ``ClauseElement`` subclass encountered. Traversal of a
+ hierarchy of ``ClauseElements`` is achieved via the
+ ``traverse()`` method, which is passed the lead
+ ``ClauseElement``.
+
+ By default, ``ClauseVisitor`` traverses all elements
+ fully. Options can be specified at the class level via the
+ ``__traverse_options__`` dictionary which will be passed
+ to the ``get_children()`` method of each ``ClauseElement``;
+ these options can indicate modifications to the set of
+ elements returned, such as to not return column collections
+ (column_collections=False) or to return Schema-level items
+ (schema_visitor=True).
+
+ ``ClauseVisitor`` also supports a simultaneous copy-and-traverse
+ operation, which will produce a copy of a given ``ClauseElement``
+ structure while at the same time allowing ``ClauseVisitor`` subclasses
+ to modify the new structure in-place.
+
+ """
+ __traverse_options__ = {}
+
+ def traverse_single(self, obj, **kwargs):
+ meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
+ if meth:
+ return meth(obj, **kwargs)
+
+ def iterate(self, obj, stop_on=None):
+ stack = [obj]
+ traversal = []
+ while len(stack) > 0:
+ t = stack.pop()
+ if stop_on is None or t not in stop_on:
+ yield t
+ traversal.insert(0, t)
+ for c in t.get_children(**self.__traverse_options__):
+ stack.append(c)
+
+ def traverse(self, obj, stop_on=None, clone=False):
+ if clone:
+ obj = obj._clone()
+
+ stack = [obj]
+ traversal = []
+ while len(stack) > 0:
+ t = stack.pop()
+ if stop_on is None or t not in stop_on:
+ traversal.insert(0, t)
+ if clone:
+ t._copy_internals()
+ for c in t.get_children(**self.__traverse_options__):
+ stack.append(c)
+ for target in traversal:
+ v = self
+ while v is not None:
+ meth = getattr(v, "visit_%s" % target.__visit_name__, None)
+ if meth:
+ meth(target)
+ v = getattr(v, '_next', None)
+ return obj
+
+ def chain(self, visitor):
+ """'chain' an additional ClauseVisitor onto this ClauseVisitor.
+
+ the chained visitor will receive all visit events after this one."""
+ tail = self
+ while getattr(tail, '_next', None) is not None:
+ tail = tail._next
+ tail._next = visitor
+ return self
+
+class NoColumnVisitor(ClauseVisitor):
+ """a ClauseVisitor that will not traverse the exported Column
+ collections on Table, Alias, Select, and CompoundSelect objects
+ (i.e. their 'columns' or 'c' attribute).
+
+ this is useful because most traversals don't need those columns, or
+ in the case of DefaultCompiler it traverses them explicitly; so
+ skipping their traversal here greatly cuts down on method call overhead.
+ """
+
+ __traverse_options__ = {'column_collections':False}