diff options
-rw-r--r-- | lib/sqlalchemy/databases/access.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/informix.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/sybase.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 57 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/dependency.py | 69 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/sync.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/util.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 172 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 78 |
13 files changed, 194 insertions, 203 deletions
diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py index 4e95b5bf3..38dba17a5 100644 --- a/lib/sqlalchemy/databases/access.py +++ b/lib/sqlalchemy/databases/access.py @@ -5,11 +5,11 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import random from sqlalchemy import sql, schema, types, exceptions, pool from sqlalchemy.sql import compiler, expression from sqlalchemy.engine import default, base + class AcNumeric(types.Numeric): def result_processor(self, dialect): return None diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py index 03fbcd67e..400a2761e 100644 --- a/lib/sqlalchemy/databases/informix.py +++ b/lib/sqlalchemy/databases/informix.py @@ -8,7 +8,7 @@ import datetime -from sqlalchemy import sql, schema, exceptions, pool +from sqlalchemy import sql, schema, exceptions, pool, util from sqlalchemy.sql import compiler from sqlalchemy.engine import default from sqlalchemy import types as sqltypes diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index b9706da74..ef941ed87 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -38,7 +38,7 @@ Known issues / TODO: """ -import datetime, operator, random, re, sys +import datetime, operator, re, sys from sqlalchemy import sql, schema, exceptions, util from sqlalchemy.sql import compiler, expression, operators as sqlops, functions as sql_functions diff --git a/lib/sqlalchemy/databases/sybase.py b/lib/sqlalchemy/databases/sybase.py index bf2b6b7d6..2551e90c5 100644 --- a/lib/sqlalchemy/databases/sybase.py +++ b/lib/sqlalchemy/databases/sybase.py @@ -22,7 +22,7 @@ Known issues / TODO: * Tested on 'Adaptive Server Anywhere 9' (version 9.0.1.1751) """ -import datetime, operator, random +import datetime, operator from sqlalchemy import util, sql, schema, exceptions from sqlalchemy.sql import compiler, expression diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index f426d93a6..28951f900 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -14,7 +14,7 @@ and result contexts. import StringIO, sys from sqlalchemy import exceptions, schema, util, types, logging -from sqlalchemy.sql import expression, visitors +from sqlalchemy.sql import expression class Dialect(object): diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 5c6a67b28..cc4eb60d2 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -14,7 +14,6 @@ as the base class for their own corresponding classes. import re, random -from sqlalchemy import util from sqlalchemy.engine import base from sqlalchemy.sql import compiler, expression @@ -40,7 +39,7 @@ class DefaultDialect(base.Dialect): preexecute_pk_sequences = False supports_pk_autoincrement = True dbapi_type_map = {} - + def __init__(self, convert_unicode=False, assert_unicode=False, encoding='utf-8', default_paramstyle='named', paramstyle=None, dbapi=None, **kwargs): self.convert_unicode = convert_unicode self.assert_unicode = assert_unicode @@ -81,7 +80,7 @@ class DefaultDialect(base.Dialect): typeobj = typeobj() return typeobj - + def oid_column_name(self, column): return None @@ -105,15 +104,15 @@ class DefaultDialect(base.Dialect): """ connection.commit() - + def create_xid(self): """Create a random two-phase transaction ID. - + This id will be passed to do_begin_twophase(), do_rollback_twophase(), do_commit_twophase(). Its format is unspecified.""" - + return "_sa_%032x" % random.randint(0,2**128) - + def do_savepoint(self, connection, name): connection.execute(expression.SavepointClause(name)) @@ -139,27 +138,27 @@ class DefaultExecutionContext(base.ExecutionContext): self._connection = self.root_connection = connection self.compiled = compiled self.engine = connection.engine - + if compiled is not None: # compiled clauseelement. process bind params, process table defaults, # track collections used by ResultProxy to target and process results - + self.processors = dict([ - (key, value) for key, value in + (key, value) for key, value in [( compiled.bind_names[bindparam], bindparam.bind_processor(self.dialect) ) for bindparam in compiled.bind_names] if value is not None ]) - + self.result_map = compiled.result_map - + if not dialect.supports_unicode_statements: self.statement = unicode(compiled).encode(self.dialect.encoding) else: self.statement = unicode(compiled) - + self.isinsert = compiled.isinsert self.isupdate = compiled.isupdate if isinstance(compiled.statement, expression._TextClause): @@ -168,7 +167,7 @@ class DefaultExecutionContext(base.ExecutionContext): else: self.returns_rows = self.returns_rows_compiled(compiled) self.should_autocommit = getattr(compiled.statement, '_autocommit', False) or self.should_autocommit_compiled(compiled) - + if not parameters: self.compiled_parameters = [compiled.construct_params()] self.executemany = False @@ -181,7 +180,7 @@ class DefaultExecutionContext(base.ExecutionContext): self.parameters = self.__convert_compiled_params(self.compiled_parameters) elif statement is not None: - # plain text statement. + # plain text statement. self.result_map = None self.parameters = self.__encode_param_keys(parameters) self.executemany = len(parameters) > 1 @@ -198,14 +197,14 @@ class DefaultExecutionContext(base.ExecutionContext): self.statement = None self.isinsert = self.isupdate = self.executemany = self.returns_rows = self.should_autocommit = False self.cursor = self.create_cursor() - + connection = property(lambda s:s._connection._branch()) - + def __encode_param_keys(self, params): """apply string encoding to the keys of dictionary-based bind parameters. - + This is only used executing textual, non-compiled SQL expressions.""" - + if self.dialect.positional or self.dialect.supports_unicode_statements: if params: return params @@ -226,7 +225,7 @@ class DefaultExecutionContext(base.ExecutionContext): """convert the dictionary of bind parameter values into a dict or list to be sent to the DBAPI's execute() or executemany() method. """ - + processors = self.processors parameters = [] if self.dialect.positional: @@ -257,10 +256,10 @@ class DefaultExecutionContext(base.ExecutionContext): param[key] = compiled_params[key] parameters.append(param) return parameters - + def returns_rows_compiled(self, compiled): return isinstance(compiled.statement, expression.Selectable) - + def returns_rows_text(self, statement): return SELECT_REGEXP.match(statement) @@ -276,10 +275,10 @@ class DefaultExecutionContext(base.ExecutionContext): def pre_execution(self): self.pre_exec() - + def post_execution(self): self.post_exec() - + def result(self): return self.get_result_proxy() @@ -330,7 +329,7 @@ class DefaultExecutionContext(base.ExecutionContext): if self.dialect.positional: inputsizes = [] for key in self.compiled.positiontup: - typeengine = types[key] + typeengine = types[key] dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) if dbtype is not None: inputsizes.append(dbtype) @@ -362,7 +361,7 @@ class DefaultExecutionContext(base.ExecutionContext): drunner = self.dialect.defaultrunner(self) params = self.compiled_parameters for param in params: - # assign each dict of params to self.compiled_parameters; + # assign each dict of params to self.compiled_parameters; # this allows user-defined default generators to access the full # set of bind params for the row self.compiled_parameters = param @@ -374,17 +373,17 @@ class DefaultExecutionContext(base.ExecutionContext): if val is not None: param[c.key] = val self.compiled_parameters = params - + else: compiled_parameters = self.compiled_parameters[0] drunner = self.dialect.defaultrunner(self) - + for c in self.compiled.prefetch: if self.isinsert: val = drunner.get_column_default(c) else: val = drunner.get_column_onupdate(c) - + if val is not None: compiled_parameters[c.key] = val diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index 707165215..8b41d93dd 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -10,10 +10,10 @@ dependencies at flush time. """ -from sqlalchemy.orm import sync, attributes +from sqlalchemy.orm import sync from sqlalchemy.orm.sync import ONETOMANY,MANYTOONE,MANYTOMANY from sqlalchemy import sql, util, exceptions -from sqlalchemy.orm import session as sessionlib + def create_dependency_processor(prop): types = { @@ -28,7 +28,7 @@ def create_dependency_processor(prop): class DependencyProcessor(object): no_dependencies = False - + def __init__(self, prop): self.prop = prop self.cascade = prop.cascade @@ -55,12 +55,12 @@ class DependencyProcessor(object): return getattr(self.parent.class_, self.key) def hasparent(self, state): - """return True if the given object instance has a parent, + """return True if the given object instance has a parent, according to the ``InstrumentedAttribute`` handled by this ``DependencyProcessor``.""" - + # TODO: use correct API for this return self._get_instrumented_attribute().impl.hasparent(state) - + def register_dependencies(self, uowcommit): """Tell a ``UOWTransaction`` what mappers are dependent on which, with regards to the two or three mappers handled by @@ -113,7 +113,7 @@ class DependencyProcessor(object): return if state is not None and not self.mapper._canload(state): raise exceptions.FlushError("Attempting to flush an item of type %s on collection '%s', which is handled by mapper '%s' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship ? Set 'enable_typechecks=False' on the relation() to disable this exception. Mismatched typeloading may cause bi-directional relationships (backrefs) to not function properly." % (state.class_, self.prop, self.mapper)) - + def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): """Called during a flush to synchronize primary key identifier values between a parent/child object, as well as to an @@ -124,7 +124,7 @@ class DependencyProcessor(object): def _compile_synchronizers(self): """Assemble a list of *synchronization rules*. - + These are fired to populate attributes from one side of a relation to another. """ @@ -156,10 +156,10 @@ class DependencyProcessor(object): if x is not None: uowcommit.register_object(state, postupdate=True, post_update_cols=self.syncrules.dest_columns()) break - + def _pks_changed(self, uowcommit, state): return self.syncrules.source_changes(uowcommit, state) - + def __str__(self): return "%s(%s)" % (self.__class__.__name__, str(self.prop)) @@ -205,12 +205,12 @@ class OneToManyDP(DependencyProcessor): for child in deleted: if not self.cascade.delete_orphan and not self.hasparent(child): self._synchronize(state, child, None, True, uowcommit) - + if self._pks_changed(uowcommit, state): if unchanged: for child in unchanged: self._synchronize(state, child, None, False, uowcommit) - + def preprocess_dependencies(self, task, deplist, uowcommit, delete = False): #print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " preprocess_dep isdelete " + repr(delete) + " direction " + repr(self.direction) @@ -247,7 +247,7 @@ class OneToManyDP(DependencyProcessor): if unchanged: for child in unchanged: uowcommit.register_object(child) - + def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): if child is not None: child = getattr(child, '_state', child) @@ -263,45 +263,45 @@ class DetectKeySwitch(DependencyProcessor): child items who have changed their referenced key.""" no_dependencies = True - + def register_dependencies(self, uowcommit): uowcommit.register_processor(self.parent, self, self.mapper) - + def preprocess_dependencies(self, task, deplist, uowcommit, delete=False): # for non-passive updates, register in the preprocess stage # so that mapper save_obj() gets a hold of changes if not delete and not self.passive_updates: self._process_key_switches(deplist, uowcommit) - + def process_dependencies(self, task, deplist, uowcommit, delete=False): # for passive updates, register objects in the process stage - # so that we avoid ManyToOneDP's registering the object without + # so that we avoid ManyToOneDP's registering the object without # the listonly flag in its own preprocess stage (results in UPDATE) # statements being emitted if not delete and self.passive_updates: self._process_key_switches(deplist, uowcommit) - - def _process_key_switches(self, deplist, uowcommit): + + def _process_key_switches(self, deplist, uowcommit): switchers = util.Set([s for s in deplist if self._pks_changed(uowcommit, s)]) if switchers: - # yes, we're doing a linear search right now through the UOW. only + # yes, we're doing a linear search right now through the UOW. only # takes effect when primary key values have actually changed. # a possible optimization might be to enhance the "hasparents" capability of # attributes to actually store all parent references, but this introduces # more complicated attribute accounting. - for s in [elem for elem in uowcommit.session.identity_map.all_states() - if issubclass(elem.class_, self.parent.class_) and - self.key in elem.dict and + for s in [elem for elem in uowcommit.session.identity_map.all_states() + if issubclass(elem.class_, self.parent.class_) and + self.key in elem.dict and elem.dict[self.key]._state in switchers ]: uowcommit.register_object(s, listonly=self.passive_updates) self.syncrules.execute(s.dict[self.key]._state, s, None, None, False) - + class ManyToOneDP(DependencyProcessor): def __init__(self, prop): DependencyProcessor.__init__(self, prop) self.mapper._dependency_processors.append(DetectKeySwitch(prop)) - + def register_dependencies(self, uowcommit): if self.post_update: if not self.is_backref: @@ -312,7 +312,7 @@ class ManyToOneDP(DependencyProcessor): else: uowcommit.register_dependency(self.mapper, self.parent) uowcommit.register_processor(self.mapper, self, self.parent) - + def process_dependencies(self, task, deplist, uowcommit, delete = False): #print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction) @@ -387,12 +387,12 @@ class ManyToManyDP(DependencyProcessor): secondary_delete = [] secondary_insert = [] secondary_update = [] - + if self.prop._reverse_property: reverse_dep = getattr(self.prop._reverse_property, '_dependency_processor', None) else: reverse_dep = None - + if delete: for state in deplist: (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes) @@ -422,13 +422,13 @@ class ManyToManyDP(DependencyProcessor): self._synchronize(state, child, associationrow, False, uowcommit) uowcommit.attributes[(self, "manytomany", state, child)] = True secondary_delete.append(associationrow) - + if not self.passive_updates and unchanged and self._pks_changed(uowcommit, state): for child in unchanged: associationrow = {} self.syncrules.update(associationrow, state, child, "old_") secondary_update.append(associationrow) - + if secondary_delete: secondary_delete.sort() # TODO: precompile the delete/insert queries? @@ -436,13 +436,13 @@ class ManyToManyDP(DependencyProcessor): result = connection.execute(statement, secondary_delete) if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_delete): raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of secondary table rows deleted from table '%s': %d" % (result.rowcount, self.secondary.description, len(secondary_delete))) - + if secondary_update: statement = self.secondary.update(sql.and_(*[c == sql.bindparam("old_" + c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow])) result = connection.execute(statement, secondary_update) if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_update): raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of secondary table rows updated from table '%s': %d" % (result.rowcount, self.secondary.description, len(secondary_update))) - + if secondary_insert: statement = self.secondary.insert() connection.execute(statement, secondary_insert) @@ -481,7 +481,7 @@ class MapperStub(object): """ __metaclass__ = util.ArgSingleton - + def __init__(self, parent, mapper, key): self.mapper = mapper self.base_mapper = self @@ -490,7 +490,7 @@ class MapperStub(object): def polymorphic_iterator(self): return iter([self]) - + def _register_dependencies(self, uowcommit): pass @@ -502,4 +502,3 @@ class MapperStub(object): def primary_mapper(self): return self - diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index aa3a42291..2c3a98c88 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -13,9 +13,9 @@ to mappers. The remainder of this module is generally private to the ORM. """ -from sqlalchemy import util, logging, exceptions -from sqlalchemy.sql import expression from itertools import chain +from sqlalchemy import exceptions, logging +from sqlalchemy.sql import expression class_mapper = None __all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension', diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index a80252c84..d95009a47 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -12,7 +12,7 @@ clause that compares column values. from sqlalchemy import schema, exceptions, util from sqlalchemy.sql import visitors, operators from sqlalchemy import logging -from sqlalchemy.orm import util as mapperutil, attributes +from sqlalchemy.orm import util as mapperutil ONETOMANY = 0 MANYTOONE = 1 diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index a801210f9..9ddfcd278 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -8,7 +8,7 @@ from sqlalchemy import sql, util, exceptions from sqlalchemy.sql import util as sql_util from sqlalchemy.sql.util import row_adapter as create_row_adapter from sqlalchemy.sql import visitors -from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, build_path +from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE all_cascades = util.Set(["delete", "delete-orphan", "all", "merge", "expunge", "save-update", "refresh-expire", "none"]) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 3f32778d6..8d8cfa38f 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -7,21 +7,20 @@ """Base SQL and DDL compiler implementations. Provides the [sqlalchemy.sql.compiler#DefaultCompiler] class, which is -responsible for generating all SQL query strings, as well as +responsible for generating all SQL query strings, as well as [sqlalchemy.sql.compiler#SchemaGenerator] and [sqlalchemy.sql.compiler#SchemaDropper] which issue CREATE and DROP DDL for tables, sequences, and indexes. The elements in this module are used by public-facing constructs like [sqlalchemy.sql.expression#ClauseElement] and [sqlalchemy.engine#Engine]. While dialect authors will want to be familiar with this module for the purpose of -creating database-specific compilers and schema generators, the module +creating database-specific compilers and schema generators, the module is otherwise internal to SQLAlchemy. """ import string, re from sqlalchemy import schema, engine, util, exceptions -from sqlalchemy.sql import operators, visitors, functions -from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import operators, functions from sqlalchemy.sql import expression as sql RESERVED_WORDS = util.Set([ @@ -57,7 +56,7 @@ BIND_TEMPLATES = { 'numeric':"%(position)s", 'named':":%(name)s" } - + OPERATORS = { operators.and_ : 'AND', @@ -96,14 +95,14 @@ OPERATORS = { FUNCTIONS = { functions.coalesce : 'coalesce%(expr)s', - functions.current_date: 'CURRENT_DATE', - functions.current_time: 'CURRENT_TIME', + functions.current_date: 'CURRENT_DATE', + functions.current_time: 'CURRENT_TIME', functions.current_timestamp: 'CURRENT_TIMESTAMP', - functions.current_user: 'CURRENT_USER', - functions.localtime: 'LOCALTIME', + functions.current_user: 'CURRENT_USER', + functions.localtime: 'LOCALTIME', functions.localtimestamp: 'LOCALTIMESTAMP', functions.sysdate: 'sysdate', - functions.session_user :'SESSION_USER', + functions.session_user :'SESSION_USER', functions.user: 'USER' } @@ -118,7 +117,7 @@ class DefaultCompiler(engine.Compiled): operators = OPERATORS functions = FUNCTIONS - + def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. @@ -132,35 +131,35 @@ class DefaultCompiler(engine.Compiled): a list of column names to be compiled into an INSERT or UPDATE statement. """ - + super(DefaultCompiler, self).__init__(dialect, statement, column_keys, **kwargs) # if we are insert/update/delete. set to true when we visit an INSERT, UPDATE or DELETE self.isdelete = self.isinsert = self.isupdate = False - + # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute) self.inline = inline or getattr(statement, 'inline', False) - + # a dictionary of bind parameter keys to _BindParamClause instances. self.binds = {} - + # a dictionary of _BindParamClause instances to "compiled" names that are # actually present in the generated SQL self.bind_names = {} # a stack. what recursive compiler doesn't have a stack ? :) self.stack = [] - + # relates label names in the final SQL to # a tuple of local column/label name, ColumnElement object (if any) and TypeEngine. # ResultProxy uses this for type processing and column targeting self.result_map = {} - + # a dictionary of ClauseElement subclasses to counters, which are used to # generate truncated identifier names or "anonymous" identifiers such as # for aliases self.generated_ids = {} - + # paramstyle from the dialect (comes from DB-API) self.paramstyle = self.dialect.paramstyle @@ -168,17 +167,17 @@ class DefaultCompiler(engine.Compiled): self.positional = self.dialect.positional self.bindtemplate = BIND_TEMPLATES[self.paramstyle] - + # a list of the compiled's bind parameter names, used to help # formulate a positional argument list self.positiontup = [] # an IdentifierPreparer that formats the quoting of identifiers self.preparer = self.dialect.identifier_preparer - + def compile(self): self.string = self.process(self.statement) - + def process(self, obj, stack=None, **kwargs): if stack: self.stack.append(stack) @@ -189,23 +188,23 @@ class DefaultCompiler(engine.Compiled): finally: if stack: self.stack.pop(-1) - + def is_subquery(self, select): return self.stack and self.stack[-1].get('is_subquery') - + def get_whereclause(self, obj): - """given a FROM clause, return an additional WHERE condition that should be - applied to a SELECT. - + """given a FROM clause, return an additional WHERE condition that should be + applied to a SELECT. + Currently used by Oracle to provide WHERE criterion for JOIN and OUTER JOIN constructs in non-ansi mode. """ - + return None def construct_params(self, params=None): """return a dictionary of bind parameter keys and values""" - + if params: pd = {} for bindparam, name in self.bind_names.iteritems(): @@ -218,9 +217,9 @@ class DefaultCompiler(engine.Compiled): return pd else: return dict([(self.bind_names[bindparam], bindparam.value) for bindparam in self.bind_names]) - + params = property(construct_params) - + def default_from(self): """Called when a SELECT statement has no froms, and no FROM clause is to be appended. @@ -228,22 +227,22 @@ class DefaultCompiler(engine.Compiled): """ return "" - + def visit_grouping(self, grouping, **kwargs): return "(" + self.process(grouping.elem) + ")" - + def visit_label(self, label, result_map=None): labelname = self._truncated_identifier("colident", label.name) - + if result_map is not None: result_map[labelname.lower()] = (label.name, (label, label.obj, labelname), label.obj.type) - + return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)]) - + def visit_column(self, column, result_map=None, use_schema=False, **kwargs): # there is actually somewhat of a ruleset when you would *not* necessarily - # want to truncate a column identifier, if its mapped to the name of a - # physical column. but thats very hard to identify at this point, and + # want to truncate a column identifier, if its mapped to the name of a + # physical column. but thats very hard to identify at this point, and # the identifier length should be greater than the id lengths of any physical # columns so should not matter. @@ -259,7 +258,7 @@ class DefaultCompiler(engine.Compiled): if result_map is not None: result_map[name.lower()] = (name, (column, ), column.type) - + if column._is_oid: n = self.dialect.oid_column_name(column) if n is not None: @@ -288,7 +287,7 @@ class DefaultCompiler(engine.Compiled): # TODO: some dialects might need different behavior here return text.replace('%', '%%') - + def visit_fromclause(self, fromclause, **kwargs): return fromclause.name @@ -302,7 +301,7 @@ class DefaultCompiler(engine.Compiled): if textclause.typemap is not None: for colname, type_ in textclause.typemap.iteritems(): self.result_map[colname.lower()] = (colname, None, type_) - + def do_bindparam(m): name = m.group(1) if name in textclause.bindparams: @@ -311,7 +310,7 @@ class DefaultCompiler(engine.Compiled): return self.bindparam_string(name) # un-escape any \:params - return BIND_PARAMS_ESC.sub(lambda m: m.group(1), + return BIND_PARAMS_ESC.sub(lambda m: m.group(1), BIND_PARAMS.sub(do_bindparam, textclause.text) ) @@ -339,37 +338,37 @@ class DefaultCompiler(engine.Compiled): result_map[func.name.lower()] = (func.name, None, func.type) name = self.function_string(func) - + if callable(name): return name(*[self.process(x) for x in func.clause_expr]) else: return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)} - + def function_argspec(self, func): return self.process(func.clause_expr) - + def function_string(self, func): return self.functions.get(func.__class__, func.name + "%(expr)s") def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs): stack_entry = {'select':cs} - + if asfrom: stack_entry['is_subquery'] = True elif self.stack and self.stack[-1].get('select'): stack_entry['is_subquery'] = True self.stack.append(stack_entry) - + text = string.join([self.process(c, asfrom=asfrom, parens=False) for c in cs.selects], " " + cs.keyword + " ") group_by = self.process(cs._group_by_clause, asfrom=asfrom) if group_by: text += " GROUP BY " + group_by - text += self.order_by_clause(cs) + text += self.order_by_clause(cs) text += (cs._limit or cs._offset) and self.limit_clause(cs) or "" - + self.stack.pop(-1) - + if asfrom and parens: return "(" + text + ")" else: @@ -382,19 +381,17 @@ class DefaultCompiler(engine.Compiled): if unary.modifier: s = s + " " + self.operator_string(unary.modifier) return s - + def visit_binary(self, binary, **kwargs): op = self.operator_string(binary.operator) if callable(op): return op(self.process(binary.left), self.process(binary.right)) else: return self.process(binary.left) + " " + op + " " + self.process(binary.right) - - return ret - + def operator_string(self, operator): return self.operators.get(operator, str(operator)) - + def visit_bindparam(self, bindparam, **kwargs): name = self._truncate_bindparam(bindparam) if name in self.binds: @@ -403,22 +400,22 @@ class DefaultCompiler(engine.Compiled): raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) self.binds[bindparam.key] = self.binds[name] = bindparam return self.bindparam_string(name) - + def _truncate_bindparam(self, bindparam): if bindparam in self.bind_names: return self.bind_names[bindparam] - + bind_name = bindparam.key bind_name = self._truncated_identifier("bindparam", bind_name) # add to bind_names for translation self.bind_names[bindparam] = bind_name - + return bind_name - + def _truncated_identifier(self, ident_class, name): if (ident_class, name) in self.generated_ids: return self.generated_ids[(ident_class, name)] - + anonname = ANONYMOUS_LABEL.sub(self._process_anon, name) if len(anonname) > self.dialect.max_identifier_length: @@ -441,14 +438,14 @@ class DefaultCompiler(engine.Compiled): self.generated_ids[('anon_counter', derived)] = anonymous_counter + 1 self.generated_ids[key] = newname return newname - + def _anonymize(self, name): return ANONYMOUS_LABEL.sub(self._process_anon, name) - + def bindparam_string(self, name): if self.positional: self.positiontup.append(name) - + return self.bindtemplate % {'name':name, 'position':len(self.positiontup)} def visit_alias(self, alias, asfrom=False, **kwargs): @@ -459,13 +456,13 @@ class DefaultCompiler(engine.Compiled): def label_select_column(self, select, column, asfrom): """label columns present in a select().""" - + if isinstance(column, sql._Label): return column - + if select.use_labels and getattr(column, '_label', None): return column.label(column._label) - + if \ asfrom and \ isinstance(column, sql._ColumnClause) and \ @@ -494,12 +491,12 @@ class DefaultCompiler(engine.Compiled): stack_entry['iswrapper'] = True else: column_clause_args = {'result_map':self.result_map} - + if self.stack and 'from' in self.stack[-1]: existingfroms = self.stack[-1]['from'] else: existingfroms = None - + froms = select._get_display_froms(existingfroms) correlate_froms = util.Set() @@ -510,17 +507,17 @@ class DefaultCompiler(engine.Compiled): # TODO: might want to propigate existing froms for select(select(select)) # where innermost select should correlate to outermost # if existingfroms: -# correlate_froms = correlate_froms.union(existingfroms) +# correlate_froms = correlate_froms.union(existingfroms) stack_entry['from'] = correlate_froms self.stack.append(stack_entry) # the actual list of columns to print in the SELECT column list. inner_columns = util.OrderedSet() - + for co in select.inner_columns: l = self.label_select_column(select, co, asfrom=asfrom) inner_columns.add(self.process(l, **column_clause_args)) - + collist = string.join(inner_columns.difference(util.Set([None])), ', ') text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " " @@ -539,7 +536,7 @@ class DefaultCompiler(engine.Compiled): whereclause = sql.and_(w, whereclause) else: whereclause = w - + if froms: text += " \nFROM " text += string.join(from_strings, ', ') @@ -559,7 +556,7 @@ class DefaultCompiler(engine.Compiled): t = self.process(select._having) if t: text += " \nHAVING " + t - + text += self.order_by_clause(select) text += (select._limit or select._offset) and self.limit_clause(select) or "" text += self.for_update_clause(select) @@ -625,10 +622,10 @@ class DefaultCompiler(engine.Compiled): ', '.join([preparer.quote(c[0], c[0].name) for c in colparams]), ', '.join([c[1] for c in colparams]))) - + def visit_update(self, update_stmt): self.stack.append({'from':util.Set([update_stmt.table])}) - + self.isupdate = True colparams = self._get_colparams(update_stmt) @@ -636,15 +633,15 @@ class DefaultCompiler(engine.Compiled): if update_stmt._whereclause: text += " WHERE " + self.process(update_stmt._whereclause) - + self.stack.pop(-1) - + return text def _get_colparams(self, stmt): - """create a set of tuples representing column/string pairs for use + """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. - + """ def create_bind_param(col, value): @@ -654,7 +651,7 @@ class DefaultCompiler(engine.Compiled): self.postfetch = [] self.prefetch = [] - + # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: @@ -688,7 +685,7 @@ class DefaultCompiler(engine.Compiled): if (((isinstance(c.default, schema.Sequence) and not c.default.optional) or not self.dialect.supports_pk_autoincrement) or - (c.default is not None and + (c.default is not None and not isinstance(c.default, schema.Sequence))): values.append((c, create_bind_param(c, None))) self.prefetch.append(c) @@ -732,18 +729,18 @@ class DefaultCompiler(engine.Compiled): text += " WHERE " + self.process(delete_stmt._whereclause) self.stack.pop(-1) - + return text - + def visit_savepoint(self, savepoint_stmt): return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) def visit_rollback_to_savepoint(self, savepoint_stmt): return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) - + def visit_release_savepoint(self, savepoint_stmt): return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) - + def __str__(self): return self.string or '' @@ -1072,10 +1069,10 @@ class IdentifierPreparer(object): def format_column(self, column, use_table=False, name=None, table_name=None): """Prepare a quoted column name. - + deprecated. use preparer.quote(col, column.name) or combine with format_table() """ - + if name is None: name = column.name if not getattr(column, 'is_literal', False): @@ -1121,7 +1118,6 @@ class IdentifierPreparer(object): 'final': final, 'escaped': escaped_final }) self._r_identifiers = r - + return [self._unescape_identifier(i) for i in [a or b for a, b in r.findall(identifiers)]] - diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 0b7684803..b39e406da 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -25,8 +25,7 @@ classes usually have few or no public methods and are less guaranteed to stay the same in future releases. """ -import datetime, re -import itertools +import itertools, re from sqlalchemy import util, exceptions from sqlalchemy.sql import operators, visitors from sqlalchemy import types as sqltypes diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 2cd0a26fd..70a1dcc96 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -1,5 +1,5 @@ -from sqlalchemy import util, schema, topological -from sqlalchemy.sql import expression, visitors, operators +from sqlalchemy import exceptions, schema, topological, util +from sqlalchemy.sql import expression, operators, visitors from itertools import chain """Utility functions that build upon SQL and Schema constructs.""" @@ -30,16 +30,16 @@ def find_tables(clause, check_columns=False, include_aliases=False): def visit_alias(alias): tables.append(alias) kwargs['visit_alias'] = visit_alias - + if check_columns: def visit_column(column): tables.append(column.table) kwargs['visit_column'] = visit_column - + def visit_table(table): tables.append(table) kwargs['visit_table'] = visit_table - + visitors.traverse(clause, traverse_options= {'column_collections':False}, **kwargs) return tables @@ -49,26 +49,26 @@ def find_columns(clause): cols.add(col) visitors.traverse(clause, visit_column=visit_column) return cols - - + + def reduce_columns(columns, *clauses): """given a list of columns, return a 'reduced' set based on natural equivalents. the set is reduced to the smallest list of columns which have no natural equivalent present in the list. A "natural equivalent" means that two columns will ultimately represent the same value because they are related by a foreign key. - + \*clauses is an optional list of join clauses which will be traversed to further identify columns that are "equivalent". - + This function is primarily used to determine the most minimal "primary key" from a selectable, by reducing the set of primary key columns present in the the selectable to just those that are not repeated. - + """ - + columns = util.OrderedSet(columns) - + omit = util.Set() for col in columns: for fk in col.foreign_keys: @@ -78,7 +78,7 @@ def reduce_columns(columns, *clauses): if fk.column.shares_lineage(c): omit.add(col) break - + if clauses: def visit_binary(binary): if binary.operator == operators.eq: @@ -90,7 +90,7 @@ def reduce_columns(columns, *clauses): break for clause in clauses: visitors.traverse(clause, visit_binary=visit_binary) - + return expression.ColumnSet(columns.difference(omit)) def row_adapter(from_, to, equivalent_columns=None): @@ -133,7 +133,7 @@ def row_adapter(from_, to, equivalent_columns=None): return map.keys() AliasedRow.map = map return AliasedRow - + class ColumnsInClause(visitors.ClauseVisitor): """Given a selectable, visit clauses and determine if any columns from the clause are in the selectable. @@ -149,16 +149,16 @@ class ColumnsInClause(visitors.ClauseVisitor): class AbstractClauseProcessor(object): """Traverse and copy a ClauseElement, replacing selected elements based on rules. - + This class implements its own visit-and-copy strategy but maintains the same public interface as visitors.ClauseVisitor. """ - + __traverse_options__ = {'column_collections':False} - + def __init__(self, stop_on=None): self.stop_on = stop_on - + def convert_element(self, elem): """Define the *conversion* method for this ``AbstractClauseProcessor``.""" @@ -166,14 +166,14 @@ class AbstractClauseProcessor(object): def chain(self, visitor): # chaining AbstractClauseProcessor and other ClauseVisitor - # objects separately. All the ACP objects are chained on + # objects separately. All the ACP objects are chained on # their convert_element() method whereas regular visitors # chain on their visit_XXX methods. if isinstance(visitor, AbstractClauseProcessor): attr = '_next_acp' else: attr = '_next' - + tail = self while getattr(tail, attr, None) is not None: tail = getattr(tail, attr) @@ -182,7 +182,7 @@ class AbstractClauseProcessor(object): def copy_and_process(self, list_): """Copy the given list to a new list, with each element traversed individually.""" - + list_ = list(list_) stop_on = util.Set(self.stop_on or []) cloned = {} @@ -198,44 +198,44 @@ class AbstractClauseProcessor(object): stop_on.add(newelem) return newelem v = getattr(v, '_next_acp', None) - + if elem not in cloned: # the full traversal will only make a clone of a particular element # once. cloned[elem] = elem._clone() return cloned[elem] - + def traverse(self, elem, clone=True): if not clone: raise exceptions.ArgumentError("AbstractClauseProcessor 'clone' argument must be True") - + return self._traverse(elem, util.Set(self.stop_on or []), {}, _clone_toplevel=True) - + def _traverse(self, elem, stop_on, cloned, _clone_toplevel=False): if elem in stop_on: return elem - + if _clone_toplevel: elem = self._convert_element(elem, stop_on, cloned) if elem in stop_on: return elem - + def clone(element): return self._convert_element(element, stop_on, cloned) elem._copy_internals(clone=clone) - + v = getattr(self, '_next', None) while v is not None: meth = getattr(v, "visit_%s" % elem.__visit_name__, None) if meth: meth(elem) v = getattr(v, '_next', None) - + for e in elem.get_children(**self.__traverse_options__): if e not in stop_on: self._traverse(e, stop_on, cloned) return elem - + class ClauseAdapter(AbstractClauseProcessor): """Given a clause (like as in a WHERE criterion), locate columns which are embedded within a given selectable, and changes those @@ -273,23 +273,23 @@ class ClauseAdapter(AbstractClauseProcessor): def copy_and_chain(self, adapter): """create a copy of this adapter and chain to the given adapter. - + currently this adapter must be unchained to start, raises - an exception if it's already chained. - + an exception if it's already chained. + Does not modify the given adapter. """ - + if adapter is None: return self - + if hasattr(self, '_next_acp') or hasattr(self, '_next'): raise NotImplementedError("Can't chain_to on an already chained ClauseAdapter (yet)") - + ca = ClauseAdapter(self.selectable, self.include, self.exclude, self.equivalents) ca._next_acp = adapter return ca - + def convert_element(self, col): if isinstance(col, expression.FromClause): if self.selectable.is_derived_from(col): @@ -309,5 +309,3 @@ class ClauseAdapter(AbstractClauseProcessor): if newcol: return newcol return newcol - - |