summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2008-12-18 17:57:15 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2008-12-18 17:57:15 +0000
commitbe5d3263436b81fb179c8189f1064d477d5fb3e6 (patch)
tree7f99d53445ef85d4bce4fcf6b5e244779cbcde1c /lib/sqlalchemy
parent98d7d70674b443d1691971926af1b1db4d7101dc (diff)
downloadsqlalchemy-be5d3263436b81fb179c8189f1064d477d5fb3e6.tar.gz
merged -r5299:5438 of py3k warnings branch. this fixes some sqlite py2.6 testing issues,
and also addresses a significant chunk of py3k deprecations. It's mainly expicit __hash__ methods. Additionally, most usage of sets/dicts to store columns uses util-based placeholder names.
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/databases/mysql.py6
-rw-r--r--lib/sqlalchemy/engine/base.py20
-rw-r--r--lib/sqlalchemy/engine/url.py3
-rw-r--r--lib/sqlalchemy/ext/associationproxy.py6
-rw-r--r--lib/sqlalchemy/ext/orderinglist.py4
-rw-r--r--lib/sqlalchemy/orm/attributes.py8
-rw-r--r--lib/sqlalchemy/orm/collections.py27
-rw-r--r--lib/sqlalchemy/orm/mapper.py26
-rw-r--r--lib/sqlalchemy/orm/properties.py20
-rw-r--r--lib/sqlalchemy/orm/query.py4
-rw-r--r--lib/sqlalchemy/orm/strategies.py6
-rw-r--r--lib/sqlalchemy/orm/unitofwork.py8
-rw-r--r--lib/sqlalchemy/orm/util.py7
-rw-r--r--lib/sqlalchemy/pool.py4
-rw-r--r--lib/sqlalchemy/schema.py10
-rw-r--r--lib/sqlalchemy/sql/compiler.py13
-rw-r--r--lib/sqlalchemy/sql/expression.py20
-rw-r--r--lib/sqlalchemy/sql/util.py17
-rw-r--r--lib/sqlalchemy/sql/visitors.py6
-rw-r--r--lib/sqlalchemy/types.py2
-rw-r--r--lib/sqlalchemy/util.py83
21 files changed, 182 insertions, 118 deletions
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index 0fc7c8fbd..9c6c48e0f 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -1034,7 +1034,7 @@ class _BinaryType(sqltypes.Binary):
if value is None:
return None
else:
- return buffer(value)
+ return util.buffer(value)
return process
class MSVarBinary(_BinaryType):
@@ -1081,7 +1081,7 @@ class MSBinary(_BinaryType):
if value is None:
return None
else:
- return buffer(value)
+ return util.buffer(value)
return process
class MSBlob(_BinaryType):
@@ -1108,7 +1108,7 @@ class MSBlob(_BinaryType):
if value is None:
return None
else:
- return buffer(value)
+ return util.buffer(value)
return process
def __repr__(self):
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index c7101d10e..dbcd5b76b 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -1267,7 +1267,6 @@ class Engine(Connectable):
return self.pool.unique_connection()
-
def _proxy_connection_cls(cls, proxy):
class ProxyConnection(cls):
def execute(self, object, *multiparams, **params):
@@ -1319,6 +1318,8 @@ class RowProxy(object):
for i in xrange(len(self.__row)):
yield self.__parent._get_col(self.__row, i)
+ __hash__ = None
+
def __eq__(self, other):
return ((other is self) or
(other == tuple(self.__parent._get_col(self.__row, key)
@@ -1347,18 +1348,23 @@ class RowProxy(object):
def items(self):
"""Return a list of tuples, each tuple containing a key/value pair."""
- return [(key, getattr(self, key)) for key in self.keys()]
+ return [(key, getattr(self, key)) for key in self.iterkeys()]
def keys(self):
"""Return the list of keys as strings represented by this RowProxy."""
return self.__parent.keys
-
+
+ def iterkeys(self):
+ return iter(self.__parent.keys)
+
def values(self):
"""Return the values represented by this RowProxy as a list."""
return list(self)
-
+
+ def itervalues(self):
+ return iter(self)
class BufferedColumnRow(RowProxy):
def __init__(self, parent, row):
@@ -1425,7 +1431,7 @@ class ResultProxy(object):
return
self._rowcount = None
- self._props = util.PopulateDict(None)
+ self._props = util.populate_column_dict(None)
self._props.creator = self.__key_fallback()
self.keys = []
@@ -1848,7 +1854,7 @@ class DefaultRunner(schema.SchemaVisitor):
def visit_column_onupdate(self, onupdate):
if isinstance(onupdate.arg, expression.ClauseElement):
return self.exec_default_sql(onupdate)
- elif callable(onupdate.arg):
+ elif util.callable(onupdate.arg):
return onupdate.arg(self.context)
else:
return onupdate.arg
@@ -1856,7 +1862,7 @@ class DefaultRunner(schema.SchemaVisitor):
def visit_column_default(self, default):
if isinstance(default.arg, expression.ClauseElement):
return self.exec_default_sql(default)
- elif callable(default.arg):
+ elif util.callable(default.arg):
return default.arg(self.context)
else:
return default.arg
diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py
index 044d701ac..5c8e68ce4 100644
--- a/lib/sqlalchemy/engine/url.py
+++ b/lib/sqlalchemy/engine/url.py
@@ -71,6 +71,9 @@ class URL(object):
s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys)
return s
+ def __hash__(self):
+ return hash(str(self))
+
def __eq__(self, other):
return \
isinstance(other, URL) and \
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
index 33eb5d240..315142d8e 100644
--- a/lib/sqlalchemy/ext/associationproxy.py
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -487,7 +487,7 @@ class _AssociationList(object):
raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in locals().items():
- if (callable(func) and func.func_name == func_name and
+ if (util.callable(func) and func.func_name == func_name and
not func.__doc__ and hasattr(list, func_name)):
func.__doc__ = getattr(list, func_name).__doc__
del func_name, func
@@ -663,7 +663,7 @@ class _AssociationDict(object):
raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in locals().items():
- if (callable(func) and func.func_name == func_name and
+ if (util.callable(func) and func.func_name == func_name and
not func.__doc__ and hasattr(dict, func_name)):
func.__doc__ = getattr(dict, func_name).__doc__
del func_name, func
@@ -890,7 +890,7 @@ class _AssociationSet(object):
raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in locals().items():
- if (callable(func) and func.func_name == func_name and
+ if (util.callable(func) and func.func_name == func_name and
not func.__doc__ and hasattr(set, func_name)):
func.__doc__ = getattr(set, func_name).__doc__
del func_name, func
diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py
index e59b577e3..a5d60bf82 100644
--- a/lib/sqlalchemy/ext/orderinglist.py
+++ b/lib/sqlalchemy/ext/orderinglist.py
@@ -65,7 +65,7 @@ ORM-compatible constructor for `OrderingList` instances.
"""
from sqlalchemy.orm.collections import collection
-
+from sqlalchemy import util
__all__ = [ 'ordering_list' ]
@@ -272,7 +272,7 @@ class OrderingList(list):
self._reorder()
for func_name, func in locals().items():
- if (callable(func) and func.func_name == func_name and
+ if (util.callable(func) and func.func_name == func_name and
not func.__doc__ and hasattr(list, func_name)):
func.__doc__ = getattr(list, func_name).__doc__
del func_name, func
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 79be76c3a..f113a4eb9 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -175,7 +175,7 @@ def proxied_attribute_factory(descriptor):
@property
def comparator(self):
- if callable(self._comparator):
+ if util.callable(self._comparator):
self._comparator = self._comparator()
return self._comparator
@@ -838,7 +838,7 @@ class InstanceState(object):
@property
def sort_key(self):
- return self.key and self.key[1] or self.insert_order
+ return self.key and self.key[1] or (self.insert_order, )
def check_modified(self):
if self.modified:
@@ -958,7 +958,7 @@ class InstanceState(object):
"""a set of keys which have no uncommitted changes"""
return set(
- key for key in self.manager.keys()
+ key for key in self.manager.iterkeys()
if (key not in self.committed_state or
(key in self.manager.mutable_attributes and
not self.manager[key].impl.check_mutable_modified(self))))
@@ -972,7 +972,7 @@ class InstanceState(object):
"""
return set(
- key for key in self.manager.keys()
+ key for key in self.manager.iterkeys()
if key not in self.committed_state and key not in self.dict)
def expire_attributes(self, attribute_names):
diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py
index 2105a4fe6..3c1c16b7d 100644
--- a/lib/sqlalchemy/orm/collections.py
+++ b/lib/sqlalchemy/orm/collections.py
@@ -105,15 +105,14 @@ import weakref
import sqlalchemy.exceptions as sa_exc
from sqlalchemy.sql import expression
-from sqlalchemy import schema
-import sqlalchemy.util as sautil
+from sqlalchemy import schema, util
__all__ = ['collection', 'collection_adapter',
'mapped_collection', 'column_mapped_collection',
'attribute_mapped_collection']
-__instrumentation_mutex = sautil.threading.Lock()
+__instrumentation_mutex = util.threading.Lock()
def column_mapped_collection(mapping_spec):
@@ -131,7 +130,7 @@ def column_mapped_collection(mapping_spec):
from sqlalchemy.orm.util import _state_mapper
from sqlalchemy.orm.attributes import instance_state
- cols = [expression._no_literals(q) for q in sautil.to_list(mapping_spec)]
+ cols = [expression._no_literals(q) for q in util.to_list(mapping_spec)]
if len(cols) == 1:
def keyfunc(value):
state = instance_state(value)
@@ -511,8 +510,8 @@ class CollectionAdapter(object):
if converter is not None:
return converter(obj)
- setting_type = sautil.duck_type_collection(obj)
- receiving_type = sautil.duck_type_collection(self._data())
+ setting_type = util.duck_type_collection(obj)
+ receiving_type = util.duck_type_collection(self._data())
if obj is None or setting_type != receiving_type:
given = obj is None and 'None' or obj.__class__.__name__
@@ -637,7 +636,7 @@ def bulk_replace(values, existing_adapter, new_adapter):
if not isinstance(values, list):
values = list(values)
- idset = sautil.IdentitySet
+ idset = util.IdentitySet
constants = idset(existing_adapter or ()).intersection(values or ())
additions = idset(values or ()).difference(constants)
removals = idset(existing_adapter or ()).difference(constants)
@@ -739,7 +738,7 @@ def _instrument_class(cls):
"Can not instrument a built-in type. Use a "
"subclass, even a trivial one.")
- collection_type = sautil.duck_type_collection(cls)
+ collection_type = util.duck_type_collection(cls)
if collection_type in __interfaces:
roles = __interfaces[collection_type].copy()
decorators = roles.pop('_decorators', {})
@@ -753,7 +752,7 @@ def _instrument_class(cls):
for name in dir(cls):
method = getattr(cls, name, None)
- if not callable(method):
+ if not util.callable(method):
continue
# note role declarations
@@ -825,7 +824,7 @@ def _instrument_membership_mutator(method, before, argument, after):
"""Route method args and/or return value through the collection adapter."""
# This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))'
if before:
- fn_args = list(sautil.flatten_iterator(inspect.getargspec(method)[0]))
+ fn_args = list(util.flatten_iterator(inspect.getargspec(method)[0]))
if type(argument) is int:
pos_arg = argument
named_arg = len(fn_args) > argument and fn_args[argument] or None
@@ -1040,7 +1039,7 @@ def _dict_decorators():
setattr(fn, '_sa_instrumented', True)
fn.__doc__ = getattr(getattr(dict, fn.__name__), '__doc__')
- Unspecified = sautil.symbol('Unspecified')
+ Unspecified = util.symbol('Unspecified')
def __setitem__(fn):
def __setitem__(self, key, value, _sa_initiator=None):
@@ -1138,7 +1137,7 @@ def _set_binops_check_strict(self, obj):
def _set_binops_check_loose(self, obj):
"""Allow anything set-like to participate in set binops."""
return (isinstance(obj, _set_binop_bases + (self.__class__,)) or
- sautil.duck_type_collection(obj) == set)
+ util.duck_type_collection(obj) == set)
def _set_decorators():
@@ -1148,7 +1147,7 @@ def _set_decorators():
setattr(fn, '_sa_instrumented', True)
fn.__doc__ = getattr(getattr(set, fn.__name__), '__doc__')
- Unspecified = sautil.symbol('Unspecified')
+ Unspecified = util.symbol('Unspecified')
def add(fn):
def add(self, value, _sa_initiator=None):
@@ -1405,7 +1404,7 @@ class MappedCollection(dict):
have assigned for that value.
"""
- for incoming_key, value in sautil.dictlike_iteritems(dictlike):
+ for incoming_key, value in util.dictlike_iteritems(dictlike):
new_key = self.keyfunc(value)
if incoming_key != new_key:
raise TypeError(
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index cc3517f75..ca6dec689 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -409,8 +409,8 @@ class Mapper(object):
self._pks_by_table = {}
self._cols_by_table = {}
- all_cols = set(chain(*[col.proxy_set for col in self._columntoproperty]))
- pk_cols = set(c for c in all_cols if c.primary_key)
+ all_cols = util.column_set(chain(*[col.proxy_set for col in self._columntoproperty]))
+ pk_cols = util.column_set(c for c in all_cols if c.primary_key)
# identify primary key columns which are also mapped by this mapper.
tables = set(self.tables + [self.mapped_table])
@@ -418,8 +418,8 @@ class Mapper(object):
for t in tables:
if t.primary_key and pk_cols.issuperset(t.primary_key):
# ordering is important since it determines the ordering of mapper.primary_key (and therefore query.get())
- self._pks_by_table[t] = util.OrderedSet(t.primary_key).intersection(pk_cols)
- self._cols_by_table[t] = util.OrderedSet(t.c).intersection(all_cols)
+ self._pks_by_table[t] = util.ordered_column_set(t.primary_key).intersection(pk_cols)
+ self._cols_by_table[t] = util.ordered_column_set(t.c).intersection(all_cols)
# determine cols that aren't expressed within our tables; mark these
# as "read only" properties which are refreshed upon INSERT/UPDATE
@@ -470,7 +470,7 @@ class Mapper(object):
# table columns mapped to lists of MapperProperty objects
# using a list allows a single column to be defined as
# populating multiple object attributes
- self._columntoproperty = {}
+ self._columntoproperty = util.column_dict()
# load custom properties
if self._init_properties:
@@ -891,7 +891,7 @@ class Mapper(object):
"""
params = [(primary_key, sql.bindparam(None, type_=primary_key.type)) for primary_key in self.primary_key]
- return sql.and_(*[k==v for (k, v) in params]), dict(params)
+ return sql.and_(*[k==v for (k, v) in params]), util.column_dict(params)
@util.memoized_property
def _equivalent_columns(self):
@@ -915,17 +915,17 @@ class Mapper(object):
"""
- result = {}
+ result = util.column_dict()
def visit_binary(binary):
if binary.operator == operators.eq:
if binary.left in result:
result[binary.left].add(binary.right)
else:
- result[binary.left] = set((binary.right,))
+ result[binary.left] = util.column_set((binary.right,))
if binary.right in result:
result[binary.right].add(binary.left)
else:
- result[binary.right] = set((binary.left,))
+ result[binary.right] = util.column_set((binary.left,))
for mapper in self.base_mapper.polymorphic_iterator():
if mapper.inherit_condition:
visitors.traverse(mapper.inherit_condition, {}, {'binary':visit_binary})
@@ -1232,7 +1232,7 @@ class Mapper(object):
for t in mapper.tables:
table_to_mapper[t] = mapper
- for table in sqlutil.sort_tables(table_to_mapper.keys()):
+ for table in sqlutil.sort_tables(table_to_mapper.iterkeys()):
insert = []
update = []
@@ -1282,7 +1282,7 @@ class Mapper(object):
if col is mapper.version_id_col:
params[col._label] = mapper._get_state_attr_by_column(state, col)
params[col.key] = params[col._label] + 1
- for prop in mapper._columntoproperty.values():
+ for prop in mapper._columntoproperty.itervalues():
history = attributes.get_history(state, prop.key, passive=True)
if history.added:
hasdata = True
@@ -1432,7 +1432,7 @@ class Mapper(object):
for t in mapper.tables:
table_to_mapper[t] = mapper
- for table in reversed(sqlutil.sort_tables(table_to_mapper.keys())):
+ for table in reversed(sqlutil.sort_tables(table_to_mapper.iterkeys())):
delete = {}
for state, mapper, connection in tups:
if table not in mapper._pks_by_table:
@@ -1666,7 +1666,7 @@ class Mapper(object):
"""Produce a collection of attribute level row processor callables."""
new_populators, existing_populators = [], []
- for prop in self._props.values():
+ for prop in self._props.itervalues():
newpop, existingpop = prop.create_row_processor(context, path, self, row, adapter)
if newpop:
new_populators.append((prop.key, newpop))
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index ad42117e1..084a539d1 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -112,7 +112,7 @@ class CompositeProperty(ColumnProperty):
util.warn_deprecated("The 'comparator' argument to CompositeProperty is deprecated. Use comparator_factory.")
kwargs['comparator_factory'] = kwargs['comparator']
super(CompositeProperty, self).__init__(*columns, **kwargs)
- self._col_position_map = dict((c, i) for i, c in enumerate(columns))
+ self._col_position_map = util.column_dict((c, i) for i, c in enumerate(columns))
self.composite_class = class_
self.strategy_class = strategies.CompositeColumnLoader
@@ -159,7 +159,9 @@ class CompositeProperty(ColumnProperty):
return expression.ClauseList(*[self.adapter(x) for x in self.prop.columns])
else:
return expression.ClauseList(*self.prop.columns)
-
+
+ __hash__ = None
+
def __eq__(self, other):
if other is None:
values = [None] * len(self.prop.columns)
@@ -363,6 +365,8 @@ class RelationProperty(StrategizedProperty):
raise NotImplementedError("in_() not yet supported for relations. For a "
"simple many-to-one, use in_() against the set of foreign key values.")
+ __hash__ = None
+
def __eq__(self, other):
if other is None:
if self.prop.direction in [ONETOMANY, MANYTOMANY]:
@@ -583,7 +587,7 @@ class RelationProperty(StrategizedProperty):
self.mapper = mapper.class_mapper(self.argument, compile=False)
elif isinstance(self.argument, mapper.Mapper):
self.mapper = self.argument
- elif callable(self.argument):
+ elif util.callable(self.argument):
# accept a callable to suit various deferred-configurational schemes
self.mapper = mapper.class_mapper(self.argument(), compile=False)
else:
@@ -592,7 +596,7 @@ class RelationProperty(StrategizedProperty):
# accept callables for other attributes which may require deferred initialization
for attr in ('order_by', 'primaryjoin', 'secondaryjoin', 'secondary', '_foreign_keys', 'remote_side'):
- if callable(getattr(self, attr)):
+ if util.callable(getattr(self, attr)):
setattr(self, attr, getattr(self, attr)())
# in the case that InstrumentedAttributes were used to construct
@@ -607,8 +611,8 @@ class RelationProperty(StrategizedProperty):
if self.order_by:
self.order_by = [expression._literal_as_column(x) for x in util.to_list(self.order_by)]
- self._foreign_keys = set(expression._literal_as_column(x) for x in util.to_set(self._foreign_keys))
- self.remote_side = set(expression._literal_as_column(x) for x in util.to_set(self.remote_side))
+ self._foreign_keys = util.column_set(expression._literal_as_column(x) for x in util.to_column_set(self._foreign_keys))
+ self.remote_side = util.column_set(expression._literal_as_column(x) for x in util.to_column_set(self.remote_side))
if not self.parent.concrete:
for inheriting in self.parent.iterate_to_root():
@@ -727,7 +731,7 @@ class RelationProperty(StrategizedProperty):
else:
self.secondary_synchronize_pairs = None
- self._foreign_keys = set(r for l, r in self.synchronize_pairs)
+ self._foreign_keys = util.column_set(r for l, r in self.synchronize_pairs)
if self.secondary_synchronize_pairs:
self._foreign_keys.update(r for l, r in self.secondary_synchronize_pairs)
@@ -814,7 +818,7 @@ class RelationProperty(StrategizedProperty):
"Specify remote_side argument to indicate which column lazy "
"join condition should bind." % (r, self.mapper))
- self.local_side, self.remote_side = [util.OrderedSet(x) for x in zip(*list(self.local_remote_pairs))]
+ self.local_side, self.remote_side = [util.ordered_column_set(x) for x in zip(*list(self.local_remote_pairs))]
def _post_init(self):
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index da33eac41..5a0c3faff 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -753,7 +753,7 @@ class Query(object):
"""
aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False)
if kwargs:
- raise TypeError("unknown arguments: %s" % ','.join(kwargs.keys()))
+ raise TypeError("unknown arguments: %s" % ','.join(kwargs.iterkeys()))
return self.__join(props, outerjoin=False, create_aliases=aliased, from_joinpoint=from_joinpoint)
@util.accepts_a_list_as_starargs(list_deprecation='pending')
@@ -766,7 +766,7 @@ class Query(object):
"""
aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False)
if kwargs:
- raise TypeError("unknown arguments: %s" % ','.join(kwargs.keys()))
+ raise TypeError("unknown arguments: %s" % ','.join(kwargs.iterkeys()))
return self.__join(props, outerjoin=True, create_aliases=aliased, from_joinpoint=from_joinpoint)
@_generative(__no_statement_condition, __no_limit_offset)
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 58aa71c6a..a159e4bfa 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -451,9 +451,9 @@ class LazyLoader(AbstractRelationLoader):
return (new_execute, None)
def _create_lazy_clause(cls, prop, reverse_direction=False):
- binds = {}
- lookup = {}
- equated_columns = {}
+ binds = util.column_dict()
+ lookup = util.column_dict()
+ equated_columns = util.column_dict()
if reverse_direction and not prop.secondaryjoin:
for l, r in prop.local_remote_pairs:
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index 778bf0949..4efab88ae 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -269,7 +269,7 @@ class UOWTransaction(object):
def elements(self):
"""Iterate UOWTaskElements."""
- for task in self.tasks.values():
+ for task in self.tasks.itervalues():
for elem in task.elements:
yield elem
@@ -288,7 +288,7 @@ class UOWTransaction(object):
def _sort_dependencies(self):
nodes = topological.sort_with_cycles(self.dependencies,
- [t.mapper for t in self.tasks.values() if t.base_task is t]
+ [t.mapper for t in self.tasks.itervalues() if t.base_task is t]
)
ret = []
@@ -565,7 +565,7 @@ class UOWTask(object):
# as part of the topological sort itself, which would
# eliminate the need for this step (but may make the original
# topological sort more expensive)
- head = topological.sort_as_tree(tuples, object_to_original_task.keys())
+ head = topological.sort_as_tree(tuples, object_to_original_task.iterkeys())
if head is not None:
original_to_tasks = {}
stack = [(head, t)]
@@ -585,7 +585,7 @@ class UOWTask(object):
task.append(state, originating_task._objects[state].listonly, isdelete=originating_task._objects[state].isdelete)
if state in dependencies:
- task.cyclical_dependencies.update(dependencies[state].values())
+ task.cyclical_dependencies.update(dependencies[state].itervalues())
stack += [(n, task) for n in children]
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 541adf4e4..411c827c6 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -4,8 +4,6 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import new
-
import sqlalchemy.exceptions as sa_exc
from sqlalchemy import sql, util
from sqlalchemy.sql import expression, util as sql_util, operators
@@ -329,7 +327,7 @@ class AliasedClass(object):
if hasattr(attr, 'func_code'):
is_method = getattr(self.__target, key, None)
if is_method and is_method.im_self is not None:
- return new.instancemethod(attr.im_func, self, self)
+ return util.types.MethodType(attr.im_func, self, self)
else:
return None
elif hasattr(attr, '__get__'):
@@ -570,7 +568,8 @@ def _is_mapped_class(cls):
from sqlalchemy.orm import mapperlib as mapper
if isinstance(cls, (AliasedClass, mapper.Mapper)):
return True
-
+ if isinstance(cls, expression.ClauseElement):
+ return False
manager = attributes.manager_of_class(cls)
return manager and _INSTRUMENTOR in manager.info
diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py
index 1d99215dc..6aa8b0395 100644
--- a/lib/sqlalchemy/pool.py
+++ b/lib/sqlalchemy/pool.py
@@ -814,14 +814,14 @@ class AssertionPool(Pool):
return "AssertionPool"
def create_connection(self):
- raise "Invalid"
+ raise AssertionError("Invalid")
def do_return_conn(self, conn):
assert conn is self._conn and self.connection is None
self.connection = conn
def do_return_invalid(self, conn):
- raise "Invalid"
+ raise AssertionError("Invalid")
def do_get(self):
assert self.connection is not None
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index dc523b36c..5fa84063f 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -574,7 +574,7 @@ class Column(SchemaItem, expression.ColumnClause):
coltype = args[0]
# adjust for partials
- if callable(coltype):
+ if util.callable(coltype):
coltype = args[0]()
if (isinstance(coltype, types.AbstractType) or
@@ -963,7 +963,7 @@ class ColumnDefault(DefaultGenerator):
if isinstance(arg, FetchedValue):
raise exc.ArgumentError(
"ColumnDefault may not be a server-side default type.")
- if callable(arg):
+ if util.callable(arg):
arg = self._maybe_wrap_callable(arg)
self.arg = arg
@@ -1320,6 +1320,8 @@ class PrimaryKeyConstraint(Constraint):
def copy(self, **kw):
return PrimaryKeyConstraint(name=self.name, *[c.key for c in self])
+ __hash__ = Constraint.__hash__
+
def __eq__(self, other):
return self.columns == other
@@ -1663,7 +1665,7 @@ class MetaData(SchemaItem):
if only is None:
load = [name for name in available if name not in current]
- elif callable(only):
+ elif util.callable(only):
load = [name for name in available
if name not in current and only(name, self)]
else:
@@ -1940,7 +1942,7 @@ class DDL(object):
"Expected a string or unicode SQL statement, got '%r'" %
statement)
if (on is not None and
- (not isinstance(on, basestring) and not callable(on))):
+ (not isinstance(on, basestring) and not util.callable(on))):
raise exc.ArgumentError(
"Expected the name of a database dialect or a callable for "
"'on' criteria, got type '%s'." % type(on).__name__)
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 747978e76..921d932d2 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -162,7 +162,7 @@ class DefaultCompiler(engine.Compiled):
# a dictionary of _BindParamClause instances to "compiled" names that are
# actually present in the generated SQL
- self.bind_names = {}
+ self.bind_names = util.column_dict()
# stack which keeps track of nested SELECT statements
self.stack = []
@@ -205,6 +205,7 @@ class DefaultCompiler(engine.Compiled):
"""return a dictionary of bind parameter keys and values"""
if params:
+ params = util.column_dict(params)
pd = {}
for bindparam, name in self.bind_names.iteritems():
for paramname in (bindparam, bindparam.key, bindparam.shortname, name):
@@ -212,7 +213,7 @@ class DefaultCompiler(engine.Compiled):
pd[name] = params[paramname]
break
else:
- if callable(bindparam.value):
+ if util.callable(bindparam.value):
pd[name] = bindparam.value()
else:
pd[name] = bindparam.value
@@ -220,7 +221,7 @@ class DefaultCompiler(engine.Compiled):
else:
pd = {}
for bindparam in self.bind_names:
- if callable(bindparam.value):
+ if util.callable(bindparam.value):
pd[self.bind_names[bindparam]] = bindparam.value()
else:
pd[self.bind_names[bindparam]] = bindparam.value
@@ -317,7 +318,7 @@ class DefaultCompiler(engine.Compiled):
sep = clauselist.operator
if sep is None:
sep = " "
- elif sep == operators.comma_op:
+ elif sep is operators.comma_op:
sep = ', '
else:
sep = " " + self.operator_string(clauselist.operator) + " "
@@ -336,7 +337,7 @@ class DefaultCompiler(engine.Compiled):
name = self.function_string(func)
- if callable(name):
+ if util.callable(name):
return name(*[self.process(x) for x in func.clauses])
else:
return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)}
@@ -377,7 +378,7 @@ class DefaultCompiler(engine.Compiled):
def visit_binary(self, binary, **kwargs):
op = self.operator_string(binary.operator)
- if callable(op):
+ if util.callable(op):
return op(self.process(binary.left), self.process(binary.right), **binary.modifiers)
else:
return self.process(binary.left) + " " + op + " " + self.process(binary.right)
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 0f7f62e74..a4ff72b1a 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -997,7 +997,7 @@ class ClauseElement(Visitable):
of transformative operations.
"""
- s = set()
+ s = util.column_set()
f = self
while f is not None:
s.add(f)
@@ -1258,6 +1258,8 @@ class ColumnOperators(Operators):
def __le__(self, other):
return self.operate(operators.le, other)
+ __hash__ = Operators.__hash__
+
def __eq__(self, other):
return self.operate(operators.eq, other)
@@ -1580,12 +1582,12 @@ class ColumnElement(ClauseElement, _CompareMixin):
@util.memoized_property
def base_columns(self):
- return set(c for c in self.proxy_set
+ return util.column_set(c for c in self.proxy_set
if not hasattr(c, 'proxies'))
@util.memoized_property
def proxy_set(self):
- s = set([self])
+ s = util.column_set([self])
if hasattr(self, 'proxies'):
for c in self.proxies:
s.update(c.proxy_set)
@@ -1694,6 +1696,8 @@ class ColumnCollection(util.OrderedProperties):
for c in iter:
self.add(c)
+ __hash__ = None
+
def __eq__(self, other):
l = []
for c in other:
@@ -1711,9 +1715,9 @@ class ColumnCollection(util.OrderedProperties):
# have to use a Set here, because it will compare the identity
# of the column, not just using "==" for comparison which will always return a
# "True" value (i.e. a BinaryClause...)
- return col in set(self)
+ return col in util.column_set(self)
-class ColumnSet(util.OrderedSet):
+class ColumnSet(util.ordered_column_set):
def contains_column(self, col):
return col in self
@@ -1733,7 +1737,7 @@ class ColumnSet(util.OrderedSet):
return and_(*l)
def __hash__(self):
- return hash(tuple(self._list))
+ return hash(tuple(x for x in self))
class Selectable(ClauseElement):
"""mark a class as being selectable"""
@@ -1985,7 +1989,7 @@ class _BindParamClause(ColumnElement):
d = self.__dict__.copy()
v = self.value
- if callable(v):
+ if util.callable(v):
v = v()
d['value'] = v
return d
@@ -2369,7 +2373,7 @@ class _BinaryExpression(ColumnElement):
def self_group(self, against=None):
# use small/large defaults for comparison so that unknown
# operators are always parenthesized
- if self.operator != against and operators.is_precedent(self.operator, against):
+ if self.operator is not against and operators.is_precedent(self.operator, against):
return _Grouping(self)
else:
return self
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 9b9b9ec09..d0ca0b01f 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -7,6 +7,7 @@ from itertools import chain
def sort_tables(tables):
"""sort a collection of Table objects in order of their foreign-key dependency."""
+ tables = list(tables)
tuples = []
def visit_foreign_key(fkey):
if fkey.use_alter:
@@ -60,7 +61,7 @@ def find_tables(clause, check_columns=False, include_aliases=False, include_join
def find_columns(clause):
"""locate Column objects within the given expression."""
- cols = set()
+ cols = util.column_set()
def visit_column(col):
cols.add(col)
visitors.traverse(clause, {}, {'column':visit_column})
@@ -182,7 +183,7 @@ class Annotated(object):
# to this object's __dict__.
clone.__dict__.update(self.__dict__)
return Annotated(clone, self._annotations)
-
+
def __hash__(self):
return hash(self.__element)
@@ -279,9 +280,9 @@ def reduce_columns(columns, *clauses, **kw):
"""
ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False)
- columns = util.OrderedSet(columns)
+ columns = util.column_set(columns)
- omit = set()
+ omit = util.column_set()
for col in columns:
for fk in col.foreign_keys:
for c in columns:
@@ -301,7 +302,7 @@ def reduce_columns(columns, *clauses, **kw):
if clauses:
def visit_binary(binary):
if binary.operator == operators.eq:
- cols = set(chain(*[c.proxy_set for c in columns.difference(omit)]))
+ cols = util.column_set(chain(*[c.proxy_set for c in columns.difference(omit)]))
if binary.left in cols and binary.right in cols:
for c in columns:
if c.shares_lineage(binary.right):
@@ -444,7 +445,7 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
self.selectable = selectable
self.include = include
self.exclude = exclude
- self.equivalents = equivalents or {}
+ self.equivalents = util.column_dict(equivalents or {})
def _corresponding_column(self, col, require_embedded, _seen=util.EMPTY_SET):
newcol = self.selectable.corresponding_column(col, require_embedded=require_embedded)
@@ -484,7 +485,7 @@ class ColumnAdapter(ClauseAdapter):
ClauseAdapter.__init__(self, selectable, equivalents, include, exclude)
if chain_to:
self.chain(chain_to)
- self.columns = util.PopulateDict(self._locate_col)
+ self.columns = util.populate_column_dict(self._locate_col)
def wrap(self, adapter):
ac = self.__class__.__new__(self.__class__)
@@ -492,7 +493,7 @@ class ColumnAdapter(ClauseAdapter):
ac._locate_col = ac._wrap(ac._locate_col, adapter._locate_col)
ac.adapt_clause = ac._wrap(ac.adapt_clause, adapter.adapt_clause)
ac.adapt_list = ac._wrap(ac.adapt_list, adapter.adapt_list)
- ac.columns = util.PopulateDict(ac._locate_col)
+ ac.columns = util.populate_column_dict(ac._locate_col)
return ac
adapt_clause = ClauseAdapter.traverse
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 17b9c59d5..a5bd497ae 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -207,7 +207,7 @@ def traverse_depthfirst(obj, opts, visitors):
def cloned_traverse(obj, opts, visitors):
"""clone the given expression structure, allowing modifications by visitors."""
- cloned = {}
+ cloned = util.column_dict()
def clone(element):
if element not in cloned:
@@ -234,8 +234,8 @@ def cloned_traverse(obj, opts, visitors):
def replacement_traverse(obj, opts, replace):
"""clone the given expression structure, allowing element replacement by a given replacement function."""
- cloned = {}
- stop_on = set(opts.get('stop_on', []))
+ cloned = util.column_dict()
+ stop_on = util.column_set(opts.get('stop_on', []))
def clone(element):
newelem = replace(element)
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index 2604f4e8f..7eda27f4f 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -382,7 +382,7 @@ class Concatenable(object):
def adapt_operator(self, op):
"""Converts an add operator to concat."""
from sqlalchemy.sql import operators
- if op == operators.add:
+ if op is operators.add:
return operators.concat_op
else:
return op
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index 8b68fb108..1356fa324 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -4,7 +4,7 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import inspect, itertools, new, operator, sys, warnings, weakref
+import inspect, itertools, operator, sys, warnings, weakref
import __builtin__
types = __import__('types')
@@ -18,8 +18,13 @@ except ImportError:
import dummy_threading as threading
from dummy_threading import local as ThreadLocal
-if sys.version_info < (2, 6):
+py3k = getattr(sys, 'py3kwarning', False) or sys.version_info >= (3, 0)
+
+if py3k:
+ set_types = set
+elif sys.version_info < (2, 6):
import sets
+ set_types = set, sets.Set
else:
# 2.6 deprecates sets.Set, but we still need to be able to detect them
# in user code and as return values from DB-APIs
@@ -32,15 +37,24 @@ else:
import sets
warnings.filters.remove(ignore)
-set_types = set, sets.Set
+ set_types = set, sets.Set
EMPTY_SET = frozenset()
-try:
- import cPickle as pickle
-except ImportError:
+if py3k:
import pickle
+else:
+ try:
+ import cPickle as pickle
+ except ImportError:
+ import pickle
+if py3k:
+ def buffer(x):
+ return x # no-op until we figure out what MySQLdb is going to use
+else:
+ buffer = __builtin__.buffer
+
if sys.version_info >= (2, 5):
class PopulateDict(dict):
"""A dict which populates missing values via a creation function.
@@ -70,6 +84,17 @@ else:
self[key] = value = self.creator(key)
return value
+if py3k:
+ def callable(fn):
+ return hasattr(fn, '__call__')
+else:
+ callable = __builtin__.callable
+
+if py3k:
+ from functools import reduce
+else:
+ reduce = __builtin__.reduce
+
try:
from collections import defaultdict
except ImportError:
@@ -125,6 +150,14 @@ def to_set(x):
else:
return x
+def to_column_set(x):
+ if x is None:
+ return column_set()
+ if not isinstance(x, column_set):
+ return column_set(to_list(x))
+ else:
+ return x
+
try:
from functools import update_wrapper
@@ -823,10 +856,11 @@ class IdentitySet(object):
This strategy has edge cases for builtin types- it's possible to have
two 'foo' strings in one of these sets, for example. Use sparingly.
+
"""
_working_set = set
-
+
def __init__(self, iterable=None):
self._members = dict()
if iterable:
@@ -918,7 +952,7 @@ class IdentitySet(object):
result = type(self)()
# testlib.pragma exempt:__hash__
result._members.update(
- self._working_set(self._members.iteritems()).union(_iter_id(iterable)))
+ self._working_set(self._member_id_tuples()).union(_iter_id(iterable)))
return result
def __or__(self, other):
@@ -939,7 +973,7 @@ class IdentitySet(object):
result = type(self)()
# testlib.pragma exempt:__hash__
result._members.update(
- self._working_set(self._members.iteritems()).difference(_iter_id(iterable)))
+ self._working_set(self._member_id_tuples()).difference(_iter_id(iterable)))
return result
def __sub__(self, other):
@@ -960,7 +994,7 @@ class IdentitySet(object):
result = type(self)()
# testlib.pragma exempt:__hash__
result._members.update(
- self._working_set(self._members.iteritems()).intersection(_iter_id(iterable)))
+ self._working_set(self._member_id_tuples()).intersection(_iter_id(iterable)))
return result
def __and__(self, other):
@@ -981,9 +1015,12 @@ class IdentitySet(object):
result = type(self)()
# testlib.pragma exempt:__hash__
result._members.update(
- self._working_set(self._members.iteritems()).symmetric_difference(_iter_id(iterable)))
+ self._working_set(self._member_id_tuples()).symmetric_difference(_iter_id(iterable)))
return result
-
+
+ def _member_id_tuples(self):
+ return ((id(v), v) for v in self._members.itervalues())
+
def __xor__(self, other):
if not isinstance(other, IdentitySet):
return NotImplemented
@@ -1016,11 +1053,6 @@ class IdentitySet(object):
return '%s(%r)' % (type(self).__name__, self._members.values())
-def _iter_id(iterable):
- """Generator: ((id(o), o) for o in iterable)."""
- for item in iterable:
- yield id(item), item
-
class OrderedIdentitySet(IdentitySet):
class _working_set(OrderedSet):
# a testing pragma: exempt the OIDS working set from the test suite's
@@ -1028,7 +1060,7 @@ class OrderedIdentitySet(IdentitySet):
# but it's safe here: IDS operates on (id, instance) tuples in the
# working set.
__sa_hash_exempt__ = True
-
+
def __init__(self, iterable=None):
IdentitySet.__init__(self)
self._members = OrderedDict()
@@ -1036,6 +1068,19 @@ class OrderedIdentitySet(IdentitySet):
for o in iterable:
self.add(o)
+def _iter_id(iterable):
+ """Generator: ((id(o), o) for o in iterable)."""
+
+ for item in iterable:
+ yield id(item), item
+
+# define collections that are capable of storing
+# ColumnElement objects as hashable keys/elements.
+column_set = set
+column_dict = dict
+ordered_column_set = OrderedSet
+populate_column_dict = PopulateDict
+
def unique_list(seq, compare_with=set):
seen = compare_with()
return [x for x in seq if x not in seen and not seen.add(x)]
@@ -1296,7 +1341,7 @@ def function_named(fn, name):
try:
fn.__name__ = name
except TypeError:
- fn = new.function(fn.func_code, fn.func_globals, name,
+ fn = types.FunctionType(fn.func_code, fn.func_globals, name,
fn.func_defaults, fn.func_closure)
return fn