diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-12-18 17:57:15 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-12-18 17:57:15 +0000 |
commit | be5d3263436b81fb179c8189f1064d477d5fb3e6 (patch) | |
tree | 7f99d53445ef85d4bce4fcf6b5e244779cbcde1c /lib/sqlalchemy | |
parent | 98d7d70674b443d1691971926af1b1db4d7101dc (diff) | |
download | sqlalchemy-be5d3263436b81fb179c8189f1064d477d5fb3e6.tar.gz |
merged -r5299:5438 of py3k warnings branch. this fixes some sqlite py2.6 testing issues,
and also addresses a significant chunk of py3k deprecations. It's mainly
expicit __hash__ methods. Additionally, most usage of sets/dicts to store columns uses
util-based placeholder names.
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 20 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/url.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/associationproxy.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/orderinglist.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/attributes.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/collections.py | 27 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 26 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/properties.py | 20 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/query.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/unitofwork.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/util.py | 7 | ||||
-rw-r--r-- | lib/sqlalchemy/pool.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/schema.py | 10 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 13 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 20 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 17 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/types.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/util.py | 83 |
21 files changed, 182 insertions, 118 deletions
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 0fc7c8fbd..9c6c48e0f 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -1034,7 +1034,7 @@ class _BinaryType(sqltypes.Binary): if value is None: return None else: - return buffer(value) + return util.buffer(value) return process class MSVarBinary(_BinaryType): @@ -1081,7 +1081,7 @@ class MSBinary(_BinaryType): if value is None: return None else: - return buffer(value) + return util.buffer(value) return process class MSBlob(_BinaryType): @@ -1108,7 +1108,7 @@ class MSBlob(_BinaryType): if value is None: return None else: - return buffer(value) + return util.buffer(value) return process def __repr__(self): diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index c7101d10e..dbcd5b76b 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1267,7 +1267,6 @@ class Engine(Connectable): return self.pool.unique_connection() - def _proxy_connection_cls(cls, proxy): class ProxyConnection(cls): def execute(self, object, *multiparams, **params): @@ -1319,6 +1318,8 @@ class RowProxy(object): for i in xrange(len(self.__row)): yield self.__parent._get_col(self.__row, i) + __hash__ = None + def __eq__(self, other): return ((other is self) or (other == tuple(self.__parent._get_col(self.__row, key) @@ -1347,18 +1348,23 @@ class RowProxy(object): def items(self): """Return a list of tuples, each tuple containing a key/value pair.""" - return [(key, getattr(self, key)) for key in self.keys()] + return [(key, getattr(self, key)) for key in self.iterkeys()] def keys(self): """Return the list of keys as strings represented by this RowProxy.""" return self.__parent.keys - + + def iterkeys(self): + return iter(self.__parent.keys) + def values(self): """Return the values represented by this RowProxy as a list.""" return list(self) - + + def itervalues(self): + return iter(self) class BufferedColumnRow(RowProxy): def __init__(self, parent, row): @@ -1425,7 +1431,7 @@ class ResultProxy(object): return self._rowcount = None - self._props = util.PopulateDict(None) + self._props = util.populate_column_dict(None) self._props.creator = self.__key_fallback() self.keys = [] @@ -1848,7 +1854,7 @@ class DefaultRunner(schema.SchemaVisitor): def visit_column_onupdate(self, onupdate): if isinstance(onupdate.arg, expression.ClauseElement): return self.exec_default_sql(onupdate) - elif callable(onupdate.arg): + elif util.callable(onupdate.arg): return onupdate.arg(self.context) else: return onupdate.arg @@ -1856,7 +1862,7 @@ class DefaultRunner(schema.SchemaVisitor): def visit_column_default(self, default): if isinstance(default.arg, expression.ClauseElement): return self.exec_default_sql(default) - elif callable(default.arg): + elif util.callable(default.arg): return default.arg(self.context) else: return default.arg diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index 044d701ac..5c8e68ce4 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -71,6 +71,9 @@ class URL(object): s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys) return s + def __hash__(self): + return hash(str(self)) + def __eq__(self, other): return \ isinstance(other, URL) and \ diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 33eb5d240..315142d8e 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -487,7 +487,7 @@ class _AssociationList(object): raise TypeError("%s objects are unhashable" % type(self).__name__) for func_name, func in locals().items(): - if (callable(func) and func.func_name == func_name and + if (util.callable(func) and func.func_name == func_name and not func.__doc__ and hasattr(list, func_name)): func.__doc__ = getattr(list, func_name).__doc__ del func_name, func @@ -663,7 +663,7 @@ class _AssociationDict(object): raise TypeError("%s objects are unhashable" % type(self).__name__) for func_name, func in locals().items(): - if (callable(func) and func.func_name == func_name and + if (util.callable(func) and func.func_name == func_name and not func.__doc__ and hasattr(dict, func_name)): func.__doc__ = getattr(dict, func_name).__doc__ del func_name, func @@ -890,7 +890,7 @@ class _AssociationSet(object): raise TypeError("%s objects are unhashable" % type(self).__name__) for func_name, func in locals().items(): - if (callable(func) and func.func_name == func_name and + if (util.callable(func) and func.func_name == func_name and not func.__doc__ and hasattr(set, func_name)): func.__doc__ = getattr(set, func_name).__doc__ del func_name, func diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index e59b577e3..a5d60bf82 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -65,7 +65,7 @@ ORM-compatible constructor for `OrderingList` instances. """ from sqlalchemy.orm.collections import collection - +from sqlalchemy import util __all__ = [ 'ordering_list' ] @@ -272,7 +272,7 @@ class OrderingList(list): self._reorder() for func_name, func in locals().items(): - if (callable(func) and func.func_name == func_name and + if (util.callable(func) and func.func_name == func_name and not func.__doc__ and hasattr(list, func_name)): func.__doc__ = getattr(list, func_name).__doc__ del func_name, func diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 79be76c3a..f113a4eb9 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -175,7 +175,7 @@ def proxied_attribute_factory(descriptor): @property def comparator(self): - if callable(self._comparator): + if util.callable(self._comparator): self._comparator = self._comparator() return self._comparator @@ -838,7 +838,7 @@ class InstanceState(object): @property def sort_key(self): - return self.key and self.key[1] or self.insert_order + return self.key and self.key[1] or (self.insert_order, ) def check_modified(self): if self.modified: @@ -958,7 +958,7 @@ class InstanceState(object): """a set of keys which have no uncommitted changes""" return set( - key for key in self.manager.keys() + key for key in self.manager.iterkeys() if (key not in self.committed_state or (key in self.manager.mutable_attributes and not self.manager[key].impl.check_mutable_modified(self)))) @@ -972,7 +972,7 @@ class InstanceState(object): """ return set( - key for key in self.manager.keys() + key for key in self.manager.iterkeys() if key not in self.committed_state and key not in self.dict) def expire_attributes(self, attribute_names): diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 2105a4fe6..3c1c16b7d 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -105,15 +105,14 @@ import weakref import sqlalchemy.exceptions as sa_exc from sqlalchemy.sql import expression -from sqlalchemy import schema -import sqlalchemy.util as sautil +from sqlalchemy import schema, util __all__ = ['collection', 'collection_adapter', 'mapped_collection', 'column_mapped_collection', 'attribute_mapped_collection'] -__instrumentation_mutex = sautil.threading.Lock() +__instrumentation_mutex = util.threading.Lock() def column_mapped_collection(mapping_spec): @@ -131,7 +130,7 @@ def column_mapped_collection(mapping_spec): from sqlalchemy.orm.util import _state_mapper from sqlalchemy.orm.attributes import instance_state - cols = [expression._no_literals(q) for q in sautil.to_list(mapping_spec)] + cols = [expression._no_literals(q) for q in util.to_list(mapping_spec)] if len(cols) == 1: def keyfunc(value): state = instance_state(value) @@ -511,8 +510,8 @@ class CollectionAdapter(object): if converter is not None: return converter(obj) - setting_type = sautil.duck_type_collection(obj) - receiving_type = sautil.duck_type_collection(self._data()) + setting_type = util.duck_type_collection(obj) + receiving_type = util.duck_type_collection(self._data()) if obj is None or setting_type != receiving_type: given = obj is None and 'None' or obj.__class__.__name__ @@ -637,7 +636,7 @@ def bulk_replace(values, existing_adapter, new_adapter): if not isinstance(values, list): values = list(values) - idset = sautil.IdentitySet + idset = util.IdentitySet constants = idset(existing_adapter or ()).intersection(values or ()) additions = idset(values or ()).difference(constants) removals = idset(existing_adapter or ()).difference(constants) @@ -739,7 +738,7 @@ def _instrument_class(cls): "Can not instrument a built-in type. Use a " "subclass, even a trivial one.") - collection_type = sautil.duck_type_collection(cls) + collection_type = util.duck_type_collection(cls) if collection_type in __interfaces: roles = __interfaces[collection_type].copy() decorators = roles.pop('_decorators', {}) @@ -753,7 +752,7 @@ def _instrument_class(cls): for name in dir(cls): method = getattr(cls, name, None) - if not callable(method): + if not util.callable(method): continue # note role declarations @@ -825,7 +824,7 @@ def _instrument_membership_mutator(method, before, argument, after): """Route method args and/or return value through the collection adapter.""" # This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))' if before: - fn_args = list(sautil.flatten_iterator(inspect.getargspec(method)[0])) + fn_args = list(util.flatten_iterator(inspect.getargspec(method)[0])) if type(argument) is int: pos_arg = argument named_arg = len(fn_args) > argument and fn_args[argument] or None @@ -1040,7 +1039,7 @@ def _dict_decorators(): setattr(fn, '_sa_instrumented', True) fn.__doc__ = getattr(getattr(dict, fn.__name__), '__doc__') - Unspecified = sautil.symbol('Unspecified') + Unspecified = util.symbol('Unspecified') def __setitem__(fn): def __setitem__(self, key, value, _sa_initiator=None): @@ -1138,7 +1137,7 @@ def _set_binops_check_strict(self, obj): def _set_binops_check_loose(self, obj): """Allow anything set-like to participate in set binops.""" return (isinstance(obj, _set_binop_bases + (self.__class__,)) or - sautil.duck_type_collection(obj) == set) + util.duck_type_collection(obj) == set) def _set_decorators(): @@ -1148,7 +1147,7 @@ def _set_decorators(): setattr(fn, '_sa_instrumented', True) fn.__doc__ = getattr(getattr(set, fn.__name__), '__doc__') - Unspecified = sautil.symbol('Unspecified') + Unspecified = util.symbol('Unspecified') def add(fn): def add(self, value, _sa_initiator=None): @@ -1405,7 +1404,7 @@ class MappedCollection(dict): have assigned for that value. """ - for incoming_key, value in sautil.dictlike_iteritems(dictlike): + for incoming_key, value in util.dictlike_iteritems(dictlike): new_key = self.keyfunc(value) if incoming_key != new_key: raise TypeError( diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index cc3517f75..ca6dec689 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -409,8 +409,8 @@ class Mapper(object): self._pks_by_table = {} self._cols_by_table = {} - all_cols = set(chain(*[col.proxy_set for col in self._columntoproperty])) - pk_cols = set(c for c in all_cols if c.primary_key) + all_cols = util.column_set(chain(*[col.proxy_set for col in self._columntoproperty])) + pk_cols = util.column_set(c for c in all_cols if c.primary_key) # identify primary key columns which are also mapped by this mapper. tables = set(self.tables + [self.mapped_table]) @@ -418,8 +418,8 @@ class Mapper(object): for t in tables: if t.primary_key and pk_cols.issuperset(t.primary_key): # ordering is important since it determines the ordering of mapper.primary_key (and therefore query.get()) - self._pks_by_table[t] = util.OrderedSet(t.primary_key).intersection(pk_cols) - self._cols_by_table[t] = util.OrderedSet(t.c).intersection(all_cols) + self._pks_by_table[t] = util.ordered_column_set(t.primary_key).intersection(pk_cols) + self._cols_by_table[t] = util.ordered_column_set(t.c).intersection(all_cols) # determine cols that aren't expressed within our tables; mark these # as "read only" properties which are refreshed upon INSERT/UPDATE @@ -470,7 +470,7 @@ class Mapper(object): # table columns mapped to lists of MapperProperty objects # using a list allows a single column to be defined as # populating multiple object attributes - self._columntoproperty = {} + self._columntoproperty = util.column_dict() # load custom properties if self._init_properties: @@ -891,7 +891,7 @@ class Mapper(object): """ params = [(primary_key, sql.bindparam(None, type_=primary_key.type)) for primary_key in self.primary_key] - return sql.and_(*[k==v for (k, v) in params]), dict(params) + return sql.and_(*[k==v for (k, v) in params]), util.column_dict(params) @util.memoized_property def _equivalent_columns(self): @@ -915,17 +915,17 @@ class Mapper(object): """ - result = {} + result = util.column_dict() def visit_binary(binary): if binary.operator == operators.eq: if binary.left in result: result[binary.left].add(binary.right) else: - result[binary.left] = set((binary.right,)) + result[binary.left] = util.column_set((binary.right,)) if binary.right in result: result[binary.right].add(binary.left) else: - result[binary.right] = set((binary.left,)) + result[binary.right] = util.column_set((binary.left,)) for mapper in self.base_mapper.polymorphic_iterator(): if mapper.inherit_condition: visitors.traverse(mapper.inherit_condition, {}, {'binary':visit_binary}) @@ -1232,7 +1232,7 @@ class Mapper(object): for t in mapper.tables: table_to_mapper[t] = mapper - for table in sqlutil.sort_tables(table_to_mapper.keys()): + for table in sqlutil.sort_tables(table_to_mapper.iterkeys()): insert = [] update = [] @@ -1282,7 +1282,7 @@ class Mapper(object): if col is mapper.version_id_col: params[col._label] = mapper._get_state_attr_by_column(state, col) params[col.key] = params[col._label] + 1 - for prop in mapper._columntoproperty.values(): + for prop in mapper._columntoproperty.itervalues(): history = attributes.get_history(state, prop.key, passive=True) if history.added: hasdata = True @@ -1432,7 +1432,7 @@ class Mapper(object): for t in mapper.tables: table_to_mapper[t] = mapper - for table in reversed(sqlutil.sort_tables(table_to_mapper.keys())): + for table in reversed(sqlutil.sort_tables(table_to_mapper.iterkeys())): delete = {} for state, mapper, connection in tups: if table not in mapper._pks_by_table: @@ -1666,7 +1666,7 @@ class Mapper(object): """Produce a collection of attribute level row processor callables.""" new_populators, existing_populators = [], [] - for prop in self._props.values(): + for prop in self._props.itervalues(): newpop, existingpop = prop.create_row_processor(context, path, self, row, adapter) if newpop: new_populators.append((prop.key, newpop)) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index ad42117e1..084a539d1 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -112,7 +112,7 @@ class CompositeProperty(ColumnProperty): util.warn_deprecated("The 'comparator' argument to CompositeProperty is deprecated. Use comparator_factory.") kwargs['comparator_factory'] = kwargs['comparator'] super(CompositeProperty, self).__init__(*columns, **kwargs) - self._col_position_map = dict((c, i) for i, c in enumerate(columns)) + self._col_position_map = util.column_dict((c, i) for i, c in enumerate(columns)) self.composite_class = class_ self.strategy_class = strategies.CompositeColumnLoader @@ -159,7 +159,9 @@ class CompositeProperty(ColumnProperty): return expression.ClauseList(*[self.adapter(x) for x in self.prop.columns]) else: return expression.ClauseList(*self.prop.columns) - + + __hash__ = None + def __eq__(self, other): if other is None: values = [None] * len(self.prop.columns) @@ -363,6 +365,8 @@ class RelationProperty(StrategizedProperty): raise NotImplementedError("in_() not yet supported for relations. For a " "simple many-to-one, use in_() against the set of foreign key values.") + __hash__ = None + def __eq__(self, other): if other is None: if self.prop.direction in [ONETOMANY, MANYTOMANY]: @@ -583,7 +587,7 @@ class RelationProperty(StrategizedProperty): self.mapper = mapper.class_mapper(self.argument, compile=False) elif isinstance(self.argument, mapper.Mapper): self.mapper = self.argument - elif callable(self.argument): + elif util.callable(self.argument): # accept a callable to suit various deferred-configurational schemes self.mapper = mapper.class_mapper(self.argument(), compile=False) else: @@ -592,7 +596,7 @@ class RelationProperty(StrategizedProperty): # accept callables for other attributes which may require deferred initialization for attr in ('order_by', 'primaryjoin', 'secondaryjoin', 'secondary', '_foreign_keys', 'remote_side'): - if callable(getattr(self, attr)): + if util.callable(getattr(self, attr)): setattr(self, attr, getattr(self, attr)()) # in the case that InstrumentedAttributes were used to construct @@ -607,8 +611,8 @@ class RelationProperty(StrategizedProperty): if self.order_by: self.order_by = [expression._literal_as_column(x) for x in util.to_list(self.order_by)] - self._foreign_keys = set(expression._literal_as_column(x) for x in util.to_set(self._foreign_keys)) - self.remote_side = set(expression._literal_as_column(x) for x in util.to_set(self.remote_side)) + self._foreign_keys = util.column_set(expression._literal_as_column(x) for x in util.to_column_set(self._foreign_keys)) + self.remote_side = util.column_set(expression._literal_as_column(x) for x in util.to_column_set(self.remote_side)) if not self.parent.concrete: for inheriting in self.parent.iterate_to_root(): @@ -727,7 +731,7 @@ class RelationProperty(StrategizedProperty): else: self.secondary_synchronize_pairs = None - self._foreign_keys = set(r for l, r in self.synchronize_pairs) + self._foreign_keys = util.column_set(r for l, r in self.synchronize_pairs) if self.secondary_synchronize_pairs: self._foreign_keys.update(r for l, r in self.secondary_synchronize_pairs) @@ -814,7 +818,7 @@ class RelationProperty(StrategizedProperty): "Specify remote_side argument to indicate which column lazy " "join condition should bind." % (r, self.mapper)) - self.local_side, self.remote_side = [util.OrderedSet(x) for x in zip(*list(self.local_remote_pairs))] + self.local_side, self.remote_side = [util.ordered_column_set(x) for x in zip(*list(self.local_remote_pairs))] def _post_init(self): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index da33eac41..5a0c3faff 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -753,7 +753,7 @@ class Query(object): """ aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False) if kwargs: - raise TypeError("unknown arguments: %s" % ','.join(kwargs.keys())) + raise TypeError("unknown arguments: %s" % ','.join(kwargs.iterkeys())) return self.__join(props, outerjoin=False, create_aliases=aliased, from_joinpoint=from_joinpoint) @util.accepts_a_list_as_starargs(list_deprecation='pending') @@ -766,7 +766,7 @@ class Query(object): """ aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False) if kwargs: - raise TypeError("unknown arguments: %s" % ','.join(kwargs.keys())) + raise TypeError("unknown arguments: %s" % ','.join(kwargs.iterkeys())) return self.__join(props, outerjoin=True, create_aliases=aliased, from_joinpoint=from_joinpoint) @_generative(__no_statement_condition, __no_limit_offset) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 58aa71c6a..a159e4bfa 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -451,9 +451,9 @@ class LazyLoader(AbstractRelationLoader): return (new_execute, None) def _create_lazy_clause(cls, prop, reverse_direction=False): - binds = {} - lookup = {} - equated_columns = {} + binds = util.column_dict() + lookup = util.column_dict() + equated_columns = util.column_dict() if reverse_direction and not prop.secondaryjoin: for l, r in prop.local_remote_pairs: diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 778bf0949..4efab88ae 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -269,7 +269,7 @@ class UOWTransaction(object): def elements(self): """Iterate UOWTaskElements.""" - for task in self.tasks.values(): + for task in self.tasks.itervalues(): for elem in task.elements: yield elem @@ -288,7 +288,7 @@ class UOWTransaction(object): def _sort_dependencies(self): nodes = topological.sort_with_cycles(self.dependencies, - [t.mapper for t in self.tasks.values() if t.base_task is t] + [t.mapper for t in self.tasks.itervalues() if t.base_task is t] ) ret = [] @@ -565,7 +565,7 @@ class UOWTask(object): # as part of the topological sort itself, which would # eliminate the need for this step (but may make the original # topological sort more expensive) - head = topological.sort_as_tree(tuples, object_to_original_task.keys()) + head = topological.sort_as_tree(tuples, object_to_original_task.iterkeys()) if head is not None: original_to_tasks = {} stack = [(head, t)] @@ -585,7 +585,7 @@ class UOWTask(object): task.append(state, originating_task._objects[state].listonly, isdelete=originating_task._objects[state].isdelete) if state in dependencies: - task.cyclical_dependencies.update(dependencies[state].values()) + task.cyclical_dependencies.update(dependencies[state].itervalues()) stack += [(n, task) for n in children] diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 541adf4e4..411c827c6 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -4,8 +4,6 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import new - import sqlalchemy.exceptions as sa_exc from sqlalchemy import sql, util from sqlalchemy.sql import expression, util as sql_util, operators @@ -329,7 +327,7 @@ class AliasedClass(object): if hasattr(attr, 'func_code'): is_method = getattr(self.__target, key, None) if is_method and is_method.im_self is not None: - return new.instancemethod(attr.im_func, self, self) + return util.types.MethodType(attr.im_func, self, self) else: return None elif hasattr(attr, '__get__'): @@ -570,7 +568,8 @@ def _is_mapped_class(cls): from sqlalchemy.orm import mapperlib as mapper if isinstance(cls, (AliasedClass, mapper.Mapper)): return True - + if isinstance(cls, expression.ClauseElement): + return False manager = attributes.manager_of_class(cls) return manager and _INSTRUMENTOR in manager.info diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index 1d99215dc..6aa8b0395 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -814,14 +814,14 @@ class AssertionPool(Pool): return "AssertionPool" def create_connection(self): - raise "Invalid" + raise AssertionError("Invalid") def do_return_conn(self, conn): assert conn is self._conn and self.connection is None self.connection = conn def do_return_invalid(self, conn): - raise "Invalid" + raise AssertionError("Invalid") def do_get(self): assert self.connection is not None diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index dc523b36c..5fa84063f 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -574,7 +574,7 @@ class Column(SchemaItem, expression.ColumnClause): coltype = args[0] # adjust for partials - if callable(coltype): + if util.callable(coltype): coltype = args[0]() if (isinstance(coltype, types.AbstractType) or @@ -963,7 +963,7 @@ class ColumnDefault(DefaultGenerator): if isinstance(arg, FetchedValue): raise exc.ArgumentError( "ColumnDefault may not be a server-side default type.") - if callable(arg): + if util.callable(arg): arg = self._maybe_wrap_callable(arg) self.arg = arg @@ -1320,6 +1320,8 @@ class PrimaryKeyConstraint(Constraint): def copy(self, **kw): return PrimaryKeyConstraint(name=self.name, *[c.key for c in self]) + __hash__ = Constraint.__hash__ + def __eq__(self, other): return self.columns == other @@ -1663,7 +1665,7 @@ class MetaData(SchemaItem): if only is None: load = [name for name in available if name not in current] - elif callable(only): + elif util.callable(only): load = [name for name in available if name not in current and only(name, self)] else: @@ -1940,7 +1942,7 @@ class DDL(object): "Expected a string or unicode SQL statement, got '%r'" % statement) if (on is not None and - (not isinstance(on, basestring) and not callable(on))): + (not isinstance(on, basestring) and not util.callable(on))): raise exc.ArgumentError( "Expected the name of a database dialect or a callable for " "'on' criteria, got type '%s'." % type(on).__name__) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 747978e76..921d932d2 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -162,7 +162,7 @@ class DefaultCompiler(engine.Compiled): # a dictionary of _BindParamClause instances to "compiled" names that are # actually present in the generated SQL - self.bind_names = {} + self.bind_names = util.column_dict() # stack which keeps track of nested SELECT statements self.stack = [] @@ -205,6 +205,7 @@ class DefaultCompiler(engine.Compiled): """return a dictionary of bind parameter keys and values""" if params: + params = util.column_dict(params) pd = {} for bindparam, name in self.bind_names.iteritems(): for paramname in (bindparam, bindparam.key, bindparam.shortname, name): @@ -212,7 +213,7 @@ class DefaultCompiler(engine.Compiled): pd[name] = params[paramname] break else: - if callable(bindparam.value): + if util.callable(bindparam.value): pd[name] = bindparam.value() else: pd[name] = bindparam.value @@ -220,7 +221,7 @@ class DefaultCompiler(engine.Compiled): else: pd = {} for bindparam in self.bind_names: - if callable(bindparam.value): + if util.callable(bindparam.value): pd[self.bind_names[bindparam]] = bindparam.value() else: pd[self.bind_names[bindparam]] = bindparam.value @@ -317,7 +318,7 @@ class DefaultCompiler(engine.Compiled): sep = clauselist.operator if sep is None: sep = " " - elif sep == operators.comma_op: + elif sep is operators.comma_op: sep = ', ' else: sep = " " + self.operator_string(clauselist.operator) + " " @@ -336,7 +337,7 @@ class DefaultCompiler(engine.Compiled): name = self.function_string(func) - if callable(name): + if util.callable(name): return name(*[self.process(x) for x in func.clauses]) else: return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)} @@ -377,7 +378,7 @@ class DefaultCompiler(engine.Compiled): def visit_binary(self, binary, **kwargs): op = self.operator_string(binary.operator) - if callable(op): + if util.callable(op): return op(self.process(binary.left), self.process(binary.right), **binary.modifiers) else: return self.process(binary.left) + " " + op + " " + self.process(binary.right) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 0f7f62e74..a4ff72b1a 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -997,7 +997,7 @@ class ClauseElement(Visitable): of transformative operations. """ - s = set() + s = util.column_set() f = self while f is not None: s.add(f) @@ -1258,6 +1258,8 @@ class ColumnOperators(Operators): def __le__(self, other): return self.operate(operators.le, other) + __hash__ = Operators.__hash__ + def __eq__(self, other): return self.operate(operators.eq, other) @@ -1580,12 +1582,12 @@ class ColumnElement(ClauseElement, _CompareMixin): @util.memoized_property def base_columns(self): - return set(c for c in self.proxy_set + return util.column_set(c for c in self.proxy_set if not hasattr(c, 'proxies')) @util.memoized_property def proxy_set(self): - s = set([self]) + s = util.column_set([self]) if hasattr(self, 'proxies'): for c in self.proxies: s.update(c.proxy_set) @@ -1694,6 +1696,8 @@ class ColumnCollection(util.OrderedProperties): for c in iter: self.add(c) + __hash__ = None + def __eq__(self, other): l = [] for c in other: @@ -1711,9 +1715,9 @@ class ColumnCollection(util.OrderedProperties): # have to use a Set here, because it will compare the identity # of the column, not just using "==" for comparison which will always return a # "True" value (i.e. a BinaryClause...) - return col in set(self) + return col in util.column_set(self) -class ColumnSet(util.OrderedSet): +class ColumnSet(util.ordered_column_set): def contains_column(self, col): return col in self @@ -1733,7 +1737,7 @@ class ColumnSet(util.OrderedSet): return and_(*l) def __hash__(self): - return hash(tuple(self._list)) + return hash(tuple(x for x in self)) class Selectable(ClauseElement): """mark a class as being selectable""" @@ -1985,7 +1989,7 @@ class _BindParamClause(ColumnElement): d = self.__dict__.copy() v = self.value - if callable(v): + if util.callable(v): v = v() d['value'] = v return d @@ -2369,7 +2373,7 @@ class _BinaryExpression(ColumnElement): def self_group(self, against=None): # use small/large defaults for comparison so that unknown # operators are always parenthesized - if self.operator != against and operators.is_precedent(self.operator, against): + if self.operator is not against and operators.is_precedent(self.operator, against): return _Grouping(self) else: return self diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 9b9b9ec09..d0ca0b01f 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -7,6 +7,7 @@ from itertools import chain def sort_tables(tables): """sort a collection of Table objects in order of their foreign-key dependency.""" + tables = list(tables) tuples = [] def visit_foreign_key(fkey): if fkey.use_alter: @@ -60,7 +61,7 @@ def find_tables(clause, check_columns=False, include_aliases=False, include_join def find_columns(clause): """locate Column objects within the given expression.""" - cols = set() + cols = util.column_set() def visit_column(col): cols.add(col) visitors.traverse(clause, {}, {'column':visit_column}) @@ -182,7 +183,7 @@ class Annotated(object): # to this object's __dict__. clone.__dict__.update(self.__dict__) return Annotated(clone, self._annotations) - + def __hash__(self): return hash(self.__element) @@ -279,9 +280,9 @@ def reduce_columns(columns, *clauses, **kw): """ ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False) - columns = util.OrderedSet(columns) + columns = util.column_set(columns) - omit = set() + omit = util.column_set() for col in columns: for fk in col.foreign_keys: for c in columns: @@ -301,7 +302,7 @@ def reduce_columns(columns, *clauses, **kw): if clauses: def visit_binary(binary): if binary.operator == operators.eq: - cols = set(chain(*[c.proxy_set for c in columns.difference(omit)])) + cols = util.column_set(chain(*[c.proxy_set for c in columns.difference(omit)])) if binary.left in cols and binary.right in cols: for c in columns: if c.shares_lineage(binary.right): @@ -444,7 +445,7 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): self.selectable = selectable self.include = include self.exclude = exclude - self.equivalents = equivalents or {} + self.equivalents = util.column_dict(equivalents or {}) def _corresponding_column(self, col, require_embedded, _seen=util.EMPTY_SET): newcol = self.selectable.corresponding_column(col, require_embedded=require_embedded) @@ -484,7 +485,7 @@ class ColumnAdapter(ClauseAdapter): ClauseAdapter.__init__(self, selectable, equivalents, include, exclude) if chain_to: self.chain(chain_to) - self.columns = util.PopulateDict(self._locate_col) + self.columns = util.populate_column_dict(self._locate_col) def wrap(self, adapter): ac = self.__class__.__new__(self.__class__) @@ -492,7 +493,7 @@ class ColumnAdapter(ClauseAdapter): ac._locate_col = ac._wrap(ac._locate_col, adapter._locate_col) ac.adapt_clause = ac._wrap(ac.adapt_clause, adapter.adapt_clause) ac.adapt_list = ac._wrap(ac.adapt_list, adapter.adapt_list) - ac.columns = util.PopulateDict(ac._locate_col) + ac.columns = util.populate_column_dict(ac._locate_col) return ac adapt_clause = ClauseAdapter.traverse diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 17b9c59d5..a5bd497ae 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -207,7 +207,7 @@ def traverse_depthfirst(obj, opts, visitors): def cloned_traverse(obj, opts, visitors): """clone the given expression structure, allowing modifications by visitors.""" - cloned = {} + cloned = util.column_dict() def clone(element): if element not in cloned: @@ -234,8 +234,8 @@ def cloned_traverse(obj, opts, visitors): def replacement_traverse(obj, opts, replace): """clone the given expression structure, allowing element replacement by a given replacement function.""" - cloned = {} - stop_on = set(opts.get('stop_on', [])) + cloned = util.column_dict() + stop_on = util.column_set(opts.get('stop_on', [])) def clone(element): newelem = replace(element) diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 2604f4e8f..7eda27f4f 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -382,7 +382,7 @@ class Concatenable(object): def adapt_operator(self, op): """Converts an add operator to concat.""" from sqlalchemy.sql import operators - if op == operators.add: + if op is operators.add: return operators.concat_op else: return op diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 8b68fb108..1356fa324 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -4,7 +4,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import inspect, itertools, new, operator, sys, warnings, weakref +import inspect, itertools, operator, sys, warnings, weakref import __builtin__ types = __import__('types') @@ -18,8 +18,13 @@ except ImportError: import dummy_threading as threading from dummy_threading import local as ThreadLocal -if sys.version_info < (2, 6): +py3k = getattr(sys, 'py3kwarning', False) or sys.version_info >= (3, 0) + +if py3k: + set_types = set +elif sys.version_info < (2, 6): import sets + set_types = set, sets.Set else: # 2.6 deprecates sets.Set, but we still need to be able to detect them # in user code and as return values from DB-APIs @@ -32,15 +37,24 @@ else: import sets warnings.filters.remove(ignore) -set_types = set, sets.Set + set_types = set, sets.Set EMPTY_SET = frozenset() -try: - import cPickle as pickle -except ImportError: +if py3k: import pickle +else: + try: + import cPickle as pickle + except ImportError: + import pickle +if py3k: + def buffer(x): + return x # no-op until we figure out what MySQLdb is going to use +else: + buffer = __builtin__.buffer + if sys.version_info >= (2, 5): class PopulateDict(dict): """A dict which populates missing values via a creation function. @@ -70,6 +84,17 @@ else: self[key] = value = self.creator(key) return value +if py3k: + def callable(fn): + return hasattr(fn, '__call__') +else: + callable = __builtin__.callable + +if py3k: + from functools import reduce +else: + reduce = __builtin__.reduce + try: from collections import defaultdict except ImportError: @@ -125,6 +150,14 @@ def to_set(x): else: return x +def to_column_set(x): + if x is None: + return column_set() + if not isinstance(x, column_set): + return column_set(to_list(x)) + else: + return x + try: from functools import update_wrapper @@ -823,10 +856,11 @@ class IdentitySet(object): This strategy has edge cases for builtin types- it's possible to have two 'foo' strings in one of these sets, for example. Use sparingly. + """ _working_set = set - + def __init__(self, iterable=None): self._members = dict() if iterable: @@ -918,7 +952,7 @@ class IdentitySet(object): result = type(self)() # testlib.pragma exempt:__hash__ result._members.update( - self._working_set(self._members.iteritems()).union(_iter_id(iterable))) + self._working_set(self._member_id_tuples()).union(_iter_id(iterable))) return result def __or__(self, other): @@ -939,7 +973,7 @@ class IdentitySet(object): result = type(self)() # testlib.pragma exempt:__hash__ result._members.update( - self._working_set(self._members.iteritems()).difference(_iter_id(iterable))) + self._working_set(self._member_id_tuples()).difference(_iter_id(iterable))) return result def __sub__(self, other): @@ -960,7 +994,7 @@ class IdentitySet(object): result = type(self)() # testlib.pragma exempt:__hash__ result._members.update( - self._working_set(self._members.iteritems()).intersection(_iter_id(iterable))) + self._working_set(self._member_id_tuples()).intersection(_iter_id(iterable))) return result def __and__(self, other): @@ -981,9 +1015,12 @@ class IdentitySet(object): result = type(self)() # testlib.pragma exempt:__hash__ result._members.update( - self._working_set(self._members.iteritems()).symmetric_difference(_iter_id(iterable))) + self._working_set(self._member_id_tuples()).symmetric_difference(_iter_id(iterable))) return result - + + def _member_id_tuples(self): + return ((id(v), v) for v in self._members.itervalues()) + def __xor__(self, other): if not isinstance(other, IdentitySet): return NotImplemented @@ -1016,11 +1053,6 @@ class IdentitySet(object): return '%s(%r)' % (type(self).__name__, self._members.values()) -def _iter_id(iterable): - """Generator: ((id(o), o) for o in iterable).""" - for item in iterable: - yield id(item), item - class OrderedIdentitySet(IdentitySet): class _working_set(OrderedSet): # a testing pragma: exempt the OIDS working set from the test suite's @@ -1028,7 +1060,7 @@ class OrderedIdentitySet(IdentitySet): # but it's safe here: IDS operates on (id, instance) tuples in the # working set. __sa_hash_exempt__ = True - + def __init__(self, iterable=None): IdentitySet.__init__(self) self._members = OrderedDict() @@ -1036,6 +1068,19 @@ class OrderedIdentitySet(IdentitySet): for o in iterable: self.add(o) +def _iter_id(iterable): + """Generator: ((id(o), o) for o in iterable).""" + + for item in iterable: + yield id(item), item + +# define collections that are capable of storing +# ColumnElement objects as hashable keys/elements. +column_set = set +column_dict = dict +ordered_column_set = OrderedSet +populate_column_dict = PopulateDict + def unique_list(seq, compare_with=set): seen = compare_with() return [x for x in seq if x not in seen and not seen.add(x)] @@ -1296,7 +1341,7 @@ def function_named(fn, name): try: fn.__name__ = name except TypeError: - fn = new.function(fn.func_code, fn.func_globals, name, + fn = types.FunctionType(fn.func_code, fn.func_globals, name, fn.func_defaults, fn.func_closure) return fn |