summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2010-11-16 15:53:14 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2010-11-16 15:53:14 -0500
commit041a329e69f6aa60bdd2f3fb87b5172481806c4a (patch)
tree30df0342ad2499d30964c70c6f2bddb9dfede714
parent40d5a32e59a49075129211358f00e857dac73885 (diff)
downloadsqlalchemy-041a329e69f6aa60bdd2f3fb87b5172481806c4a.tar.gz
- adapt initial patch from [ticket:1917] to current tip
- raise TypeError for immutability
-rw-r--r--lib/sqlalchemy/orm/mapper.py4
-rw-r--r--lib/sqlalchemy/orm/query.py8
-rw-r--r--lib/sqlalchemy/schema.py19
-rw-r--r--lib/sqlalchemy/sql/expression.py123
-rw-r--r--lib/sqlalchemy/sql/operators.py2
-rw-r--r--lib/sqlalchemy/util.py92
-rw-r--r--test/engine/test_metadata.py27
-rw-r--r--test/engine/test_reflection.py1
-rw-r--r--test/sql/test_compiler.py2
9 files changed, 172 insertions, 106 deletions
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index e9da4f533..8abb26fb6 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -162,7 +162,7 @@ class Mapper(object):
else:
self.with_polymorphic = None
- if isinstance(self.local_table, expression._SelectBaseMixin):
+ if isinstance(self.local_table, expression._SelectBase):
raise sa_exc.InvalidRequestError(
"When mapping against a select() construct, map against "
"an alias() of the construct instead."
@@ -172,7 +172,7 @@ class Mapper(object):
if self.with_polymorphic and \
isinstance(self.with_polymorphic[1],
- expression._SelectBaseMixin):
+ expression._SelectBase):
self.with_polymorphic = (self.with_polymorphic[0],
self.with_polymorphic[1].alias())
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 2bccb8f73..2f482537d 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -162,7 +162,7 @@ class Query(object):
fa = []
for from_obj in obj:
- if isinstance(from_obj, expression._SelectBaseMixin):
+ if isinstance(from_obj, expression._SelectBase):
from_obj = from_obj.alias()
fa.append(from_obj)
@@ -1597,7 +1597,7 @@ class Query(object):
if not isinstance(statement,
(expression._TextClause,
- expression._SelectBaseMixin)):
+ expression._SelectBase)):
raise sa_exc.ArgumentError(
"from_statement accepts text(), select(), "
"and union() objects only.")
@@ -2468,7 +2468,7 @@ class Query(object):
for hint in self._with_hints:
statement = statement.with_hint(*hint)
-
+
if self._execution_options:
statement = statement.execution_options(
**self._execution_options)
@@ -2803,7 +2803,7 @@ class QueryContext(object):
def __init__(self, query):
if query._statement is not None:
- if isinstance(query._statement, expression._SelectBaseMixin) and \
+ if isinstance(query._statement, expression._SelectBase) and \
not query._statement.use_labels:
self.statement = query._statement.apply_labels()
else:
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index a332cec36..e7a5d6e46 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -227,7 +227,7 @@ class Table(SchemaItem, expression.TableClause):
self.constraints = set()
self._columns = expression.ColumnCollection()
self._set_primary_key(PrimaryKeyConstraint())
- self._foreign_keys = util.OrderedSet()
+ self.foreign_keys = util.OrderedSet()
self._extra_dependencies = set()
self.ddl_listeners = util.defaultdict(list)
self.kwargs = {}
@@ -283,7 +283,7 @@ class Table(SchemaItem, expression.TableClause):
if include_columns:
for c in self.c:
if c.name not in include_columns:
- self.c.remove(c)
+ self._columns.remove(c)
for key in ('quote', 'quote_schema'):
if key in kwargs:
@@ -307,10 +307,13 @@ class Table(SchemaItem, expression.TableClause):
"Invalid argument(s) for Table: %r" % kwargs.keys())
self.kwargs.update(kwargs)
+ def _init_collections(self):
+ pass
+
def _set_primary_key(self, pk):
- if getattr(self, '_primary_key', None) in self.constraints:
- self.constraints.remove(self._primary_key)
- self._primary_key = pk
+ if self.primary_key in self.constraints:
+ self.constraints.remove(self.primary_key)
+ self.primary_key = pk
self.constraints.add(pk)
for c in pk.columns:
@@ -330,10 +333,6 @@ class Table(SchemaItem, expression.TableClause):
def key(self):
return _get_table_key(self.name, self.schema)
- @property
- def primary_key(self):
- return self._primary_key
-
def __repr__(self):
return "Table(%s)" % ', '.join(
[repr(self.name)] + [repr(self.metadata)] +
@@ -937,7 +936,7 @@ class Column(SchemaItem, expression.ColumnClause):
nullable = self.nullable,
quote=self.quote, _proxies=[self], *fk)
c.table = selectable
- selectable.columns.add(c)
+ selectable._columns.add(c)
if self.primary_key:
selectable.primary_key.add(c)
for fn in c._table_events:
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index c3dc339a5..0f93643dc 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -1658,7 +1658,7 @@ class _CompareMixin(ColumnOperators):
if isinstance(seq_or_selectable, _ScalarSelect):
return self.__compare(op, seq_or_selectable,
negate=negate_op)
- elif isinstance(seq_or_selectable, _SelectBaseMixin):
+ elif isinstance(seq_or_selectable, _SelectBase):
# TODO: if we ever want to support (x, y, z) IN (select x,
# y, z from table), we would need a multi-column version of
@@ -1830,7 +1830,7 @@ class _CompareMixin(ColumnOperators):
return other.__clause_element__()
elif not isinstance(other, ClauseElement):
return self._bind_param(operator, other)
- elif isinstance(other, (_SelectBaseMixin, Alias)):
+ elif isinstance(other, (_SelectBase, Alias)):
return other.as_scalar()
else:
return other
@@ -1905,7 +1905,7 @@ class ColumnElement(ClauseElement, _CompareMixin):
co = ColumnClause(name, selectable, type_=getattr(self,
'type', None))
co.proxies = [self]
- selectable.columns[key] = co
+ selectable._columns[key] = co
return co
def compare(self, other, use_proxies=False, equivalents=None, **kw):
@@ -2044,6 +2044,16 @@ class ColumnCollection(util.OrderedProperties):
# always return a "True" value (i.e. a BinaryClause...)
return col in util.column_set(self)
+
+ def as_immutable(self):
+ return ImmutableColumnCollection(self._data)
+
+class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection):
+ def __init__(self, data):
+ util.ImmutableProperties.__init__(self, data)
+
+ extend = remove = util.ImmutableProperties._immutable
+
class ColumnSet(util.ordered_column_set):
def contains_column(self, col):
@@ -2239,44 +2249,50 @@ class FromClause(Selectable):
def _reset_exported(self):
"""delete memoized collections when a FromClause is cloned."""
- for attr in '_columns', '_primary_key', '_foreign_keys', \
- 'locate_all_froms':
- self.__dict__.pop(attr, None)
+ for name in 'primary_key', '_columns', 'columns', \
+ 'foreign_keys', 'locate_all_froms':
+ self.__dict__.pop(name, None)
@util.memoized_property
- def _columns(self):
+ def columns(self):
"""Return the collection of Column objects contained by this
FromClause."""
-
- self._export_columns()
- return self._columns
-
+
+ if '_columns' not in self.__dict__:
+ self._init_collections()
+ self._populate_column_collection()
+ return self._columns.as_immutable()
+
@util.memoized_property
- def _primary_key(self):
+ def primary_key(self):
"""Return the collection of Column objects which comprise the
primary key of this FromClause."""
-
- self._export_columns()
- return self._primary_key
-
+
+ self._init_collections()
+ self._populate_column_collection()
+ return self.primary_key
+
@util.memoized_property
- def _foreign_keys(self):
+ def foreign_keys(self):
"""Return the collection of ForeignKey objects which this
FromClause references."""
+
+ self._init_collections()
+ self._populate_column_collection()
+ return self.foreign_keys
- self._export_columns()
- return self._foreign_keys
- columns = property(attrgetter('_columns'), doc=_columns.__doc__)
- primary_key = property(attrgetter('_primary_key'),
- doc=_primary_key.__doc__)
- foreign_keys = property(attrgetter('_foreign_keys'),
- doc=_foreign_keys.__doc__)
-
- # synonyms for 'columns'
-
- c = _select_iterable = property(attrgetter('columns'),
- doc=_columns.__doc__)
-
+ c = property(attrgetter('columns'))
+ _select_iterable = property(attrgetter('columns'))
+
+ def _init_collections(self):
+ assert '_columns' not in self.__dict__
+ assert 'primary_key' not in self.__dict__
+ assert 'foreign_keys' not in self.__dict__
+
+ self._columns = ColumnCollection()
+ self.primary_key = ColumnSet()
+ self.foreign_keys = set()
+
def _export_columns(self):
"""Initialize column collections."""
@@ -3009,7 +3025,7 @@ class _Exists(_UnaryExpression):
_from_objects = []
def __init__(self, *args, **kwargs):
- if args and isinstance(args[0], (_SelectBaseMixin, _ScalarSelect)):
+ if args and isinstance(args[0], (_SelectBase, _ScalarSelect)):
s = args[0]
else:
if not args:
@@ -3088,10 +3104,10 @@ class Join(FromClause):
columns = [c for c in self.left.columns] + \
[c for c in self.right.columns]
- self._primary_key.extend(sqlutil.reduce_columns(
+ self.primary_key.extend(sqlutil.reduce_columns(
(c for c in columns if c.primary_key), self.onclause))
self._columns.update((col._label, col) for col in columns)
- self._foreign_keys.update(itertools.chain(
+ self.foreign_keys.update(itertools.chain(
*[col.foreign_keys for col in columns]))
def _copy_internals(self, clone=_clone):
@@ -3281,12 +3297,26 @@ class _FromGrouping(FromClause):
def __init__(self, element):
self.element = element
-
+
+ def _init_collections(self):
+ pass
+
@property
def columns(self):
return self.element.columns
@property
+ def primary_key(self):
+ return self.element.primary_key
+
+ @property
+ def foreign_keys(self):
+ # this could be
+ # self.element.foreign_keys
+ # see SelectableTest.test_join_condition
+ return set()
+
+ @property
def _hide_froms(self):
return self.element._hide_froms
@@ -3476,7 +3506,7 @@ class ColumnClause(_Immutable, ColumnElement):
)
c.proxies = [self]
if attach:
- selectable.columns[c.name] = c
+ selectable._columns[c.name] = c
return c
class TableClause(_Immutable, FromClause):
@@ -3496,11 +3526,14 @@ class TableClause(_Immutable, FromClause):
super(TableClause, self).__init__()
self.name = self.fullname = name
self._columns = ColumnCollection()
- self._primary_key = ColumnSet()
- self._foreign_keys = set()
+ self.primary_key = ColumnSet()
+ self.foreign_keys = set()
for c in columns:
self.append_column(c)
-
+
+ def _init_collections(self):
+ pass
+
def _export_columns(self):
raise NotImplementedError()
@@ -3556,7 +3589,7 @@ class TableClause(_Immutable, FromClause):
def _from_objects(self):
return [self]
-class _SelectBaseMixin(Executable):
+class _SelectBase(Executable, FromClause):
"""Base class for :class:`Select` and ``CompoundSelects``."""
def __init__(self,
@@ -3583,7 +3616,7 @@ class _SelectBaseMixin(Executable):
self._order_by_clause = ClauseList(*util.to_list(order_by) or [])
self._group_by_clause = ClauseList(*util.to_list(group_by) or [])
-
+
def as_scalar(self):
"""return a 'scalar' representation of this selectable, which can be
used as a column expression.
@@ -3729,7 +3762,7 @@ class _ScalarSelect(_Grouping):
def _make_proxy(self, selectable, name):
return list(self.inner_columns)[0]._make_proxy(selectable, name)
-class CompoundSelect(_SelectBaseMixin, FromClause):
+class CompoundSelect(_SelectBase):
"""Forms the basis of ``UNION``, ``UNION ALL``, and other
SELECT-based set operations."""
@@ -3764,7 +3797,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
self.selects.append(s.self_group(self))
- _SelectBaseMixin.__init__(self, **kwargs)
+ _SelectBase.__init__(self, **kwargs)
def _scalar_type(self):
return self.selects[0]._scalar_type()
@@ -3830,7 +3863,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
self._bind = bind
bind = property(bind, _set_bind)
-class Select(_SelectBaseMixin, FromClause):
+class Select(_SelectBase):
"""Represents a ``SELECT`` statement.
Select statements support appendable clauses, as well as the
@@ -3859,7 +3892,7 @@ class Select(_SelectBaseMixin, FromClause):
argument descriptions.
Additional generative and mutator methods are available on the
- :class:`_SelectBaseMixin` superclass.
+ :class:`_SelectBase` superclass.
"""
self._should_correlate = correlate
@@ -3907,7 +3940,7 @@ class Select(_SelectBaseMixin, FromClause):
if prefixes:
self._prefixes = tuple([_literal_as_text(p) for p in prefixes])
- _SelectBaseMixin.__init__(self, **kwargs)
+ _SelectBase.__init__(self, **kwargs)
def _get_display_froms(self, existing_froms=None):
"""Return the full list of 'from' clauses to be displayed.
diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py
index 6f70b1778..67830f7cf 100644
--- a/lib/sqlalchemy/sql/operators.py
+++ b/lib/sqlalchemy/sql/operators.py
@@ -83,7 +83,7 @@ def desc_op(a):
def asc_op(a):
return a.asc()
-_commutative = set([eq, ne, add, mul])
+_commutative = set([eq, ne, add, mul, and_])
def is_commutative(op):
return op in _commutative
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index 8665cd0d4..cfeb38f54 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -146,36 +146,6 @@ except ImportError:
return 'defaultdict(%s, %s)' % (self.default_factory,
dict.__repr__(self))
-class frozendict(dict):
- @property
- def _blocked_attribute(obj):
- raise AttributeError, "A frozendict cannot be modified."
-
- __delitem__ = __setitem__ = clear = _blocked_attribute
- pop = popitem = setdefault = update = _blocked_attribute
-
- def __new__(cls, *args):
- new = dict.__new__(cls)
- dict.__init__(new, *args)
- return new
-
- def __init__(self, *args):
- pass
-
- def __reduce__(self):
- return frozendict, (dict(self), )
-
- def union(self, d):
- if not self:
- return frozendict(d)
- else:
- d2 = self.copy()
- d2.update(d)
- return frozendict(d2)
-
- def __repr__(self):
- return "frozendict(%s)" % dict.__repr__(self)
-
# find or create a dict implementation that supports __missing__
class _probe(dict):
@@ -759,20 +729,44 @@ class NamedTuple(tuple):
def keys(self):
return [l for l in self._labels if l is not None]
+class ImmutableContainer(object):
+ def _immutable(self, *arg, **kw):
+ raise TypeError("%s object is immutable" % self.__class__.__name__)
-class OrderedProperties(object):
- """An object that maintains the order in which attributes are set upon it.
+ __delitem__ = __setitem__ = __setattr__ = _immutable
- Also provides an iterator and a very basic getitem/setitem
- interface to those attributes.
+class frozendict(ImmutableContainer, dict):
+
+ clear = pop = popitem = setdefault = \
+ update = ImmutableContainer._immutable
- (Not really a dict, since it iterates over values, not keys. Not really
- a list, either, since each value must have a key associated; hence there is
- no append or extend.)
- """
+ def __new__(cls, *args):
+ new = dict.__new__(cls)
+ dict.__init__(new, *args)
+ return new
- def __init__(self):
- self.__dict__['_data'] = OrderedDict()
+ def __init__(self, *args):
+ pass
+
+ def __reduce__(self):
+ return frozendict, (dict(self), )
+
+ def union(self, d):
+ if not self:
+ return frozendict(d)
+ else:
+ d2 = self.copy()
+ d2.update(d)
+ return frozendict(d2)
+
+ def __repr__(self):
+ return "frozendict(%s)" % dict.__repr__(self)
+
+class Properties(object):
+ """Provide a __getattr__/__setattr__ interface over a dict."""
+
+ def __init__(self, data):
+ self.__dict__['_data'] = data
def __len__(self):
return len(self._data)
@@ -809,7 +803,12 @@ class OrderedProperties(object):
def __contains__(self, key):
return key in self._data
-
+
+ def as_immutable(self):
+ """Return an immutable proxy for this :class:`.Properties`."""
+
+ return ImmutableProperties(self._data)
+
def update(self, value):
self._data.update(value)
@@ -828,6 +827,17 @@ class OrderedProperties(object):
def clear(self):
self._data.clear()
+class OrderedProperties(Properties):
+ """Provide a __getattr__/__setattr__ interface with an OrderedDict
+ as backing store."""
+ def __init__(self):
+ Properties.__init__(self, OrderedDict())
+
+
+class ImmutableProperties(ImmutableContainer, Properties):
+ """Provide immutable dict/object attribute to an underlying dictionary."""
+
+
class OrderedDict(dict):
"""A dict that returns keys/values/items in the order they were added."""
diff --git a/test/engine/test_metadata.py b/test/engine/test_metadata.py
index b2250c808..b3a9cef2e 100644
--- a/test/engine/test_metadata.py
+++ b/test/engine/test_metadata.py
@@ -353,7 +353,6 @@ class MetaDataTest(TestBase, ComparesTables):
[d, b, a, c, e]
)
-
def test_tometadata_strip_schema(self):
meta = MetaData()
@@ -387,7 +386,7 @@ class MetaDataTest(TestBase, ComparesTables):
MetaData(testing.db), autoload=True)
-class TableOptionsTest(TestBase, AssertsCompiledSQL):
+class TableTest(TestBase, AssertsCompiledSQL):
def test_prefixes(self):
table1 = Table("temporary_table_1", MetaData(),
Column("col1", Integer),
@@ -418,3 +417,27 @@ class TableOptionsTest(TestBase, AssertsCompiledSQL):
t.info['bar'] = 'zip'
assert t.info['bar'] == 'zip'
+ def test_c_immutable(self):
+ m = MetaData()
+ t1 = Table('t', m, Column('x', Integer), Column('y', Integer))
+ assert_raises(
+ TypeError,
+ t1.c.extend, [Column('z', Integer)]
+ )
+
+ def assign():
+ t1.c['z'] = Column('z', Integer)
+ assert_raises(
+ TypeError,
+ assign
+ )
+
+ def assign():
+ t1.c.z = Column('z', Integer)
+ assert_raises(
+ TypeError,
+ assign
+ )
+
+
+ \ No newline at end of file
diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py
index d0d6e31e1..91e73d4f2 100644
--- a/test/engine/test_reflection.py
+++ b/test/engine/test_reflection.py
@@ -611,6 +611,7 @@ class ReflectionTest(TestBase, ComparesTables):
self.assert_tables_equal(multi, table)
self.assert_tables_equal(multi2, table2)
j = sa.join(table, table2)
+
self.assert_(sa.and_(table.c.multi_id == table2.c.foo,
table.c.multi_rev == table2.c.bar,
table.c.multi_hoho
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index 338a5491e..2d6f0e104 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -1553,7 +1553,7 @@ class SelectTest(TestBase, AssertsCompiledSQL):
"SELECT foo, bar FROM bat UNION SELECT foo, bar "
"FROM bat UNION SELECT foo, bar FROM bat UNION SELECT foo, bar FROM bat"
)
-
+
self.assert_compile(
union(s, union(s, union(s, s))),
"SELECT foo, bar FROM bat UNION (SELECT foo, bar FROM bat "