summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES7
-rw-r--r--lib/sqlalchemy/orm/mapper.py12
-rw-r--r--lib/sqlalchemy/orm/strategies.py6
-rw-r--r--lib/sqlalchemy/orm/util.py52
-rw-r--r--test/orm/test_mapper.py64
5 files changed, 112 insertions, 29 deletions
diff --git a/CHANGES b/CHANGES
index f4d46f673..161668e82 100644
--- a/CHANGES
+++ b/CHANGES
@@ -12,6 +12,13 @@ CHANGES
directives in statements. Courtesy
Diana Clarke [ticket:2443]
+ - [feature] Added new flag to @validates
+ include_removes. When True, collection
+ remove and attribute del events
+ will also be sent to the validation function,
+ which accepts an additional argument
+ "is_remove" when this flag is used.
+
- [bug] Fixed bug whereby polymorphic_on
column that's not otherwise mapped on the
class would be incorrectly included
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index e96b7549a..afabac05a 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -678,9 +678,10 @@ class Mapper(object):
self._reconstructor = method
event.listen(manager, 'load', _event_on_load, raw=True)
elif hasattr(method, '__sa_validators__'):
+ include_removes = getattr(method, "__sa_include_removes__", False)
for name in method.__sa_validators__:
self.validators = self.validators.union(
- {name : method}
+ {name : (method, include_removes)}
)
manager.info[_INSTRUMENTOR] = self
@@ -2291,7 +2292,7 @@ def reconstructor(fn):
fn.__sa_reconstructor__ = True
return fn
-def validates(*names):
+def validates(*names, **kw):
"""Decorate a method as a 'validator' for one or more named properties.
Designates a method as a validator, a method which receives the
@@ -2307,9 +2308,16 @@ def validates(*names):
an assertion to avoid recursion overflows. This is a reentrant
condition which is not supported.
+ :param \*names: list of attribute names to be validated.
+ :param include_removes: if True, "remove" events will be
+ sent as well - the validation function must accept an additional
+ argument "is_remove" which will be a boolean. New in 0.7.7.
+
"""
+ include_removes = kw.pop('include_removes', False)
def wrap(fn):
fn.__sa_validators__ = names
+ fn.__sa_include_removes__ = include_removes
return fn
return wrap
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 5f4b182d0..37980e111 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -45,11 +45,11 @@ def _register_attribute(strategy, mapper, useobject,
listen_hooks.append(single_parent_validator)
if prop.key in prop.parent.validators:
+ fn, include_removes = prop.parent.validators[prop.key]
listen_hooks.append(
lambda desc, prop: mapperutil._validator_events(desc,
- prop.key,
- prop.parent.validators[prop.key])
- )
+ prop.key, fn, include_removes)
+ )
if useobject:
listen_hooks.append(unitofwork.track_cascade_events)
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 0c5f203a7..197c0c4c1 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -68,24 +68,36 @@ class CascadeOptions(frozenset):
",".join([x for x in sorted(self)])
)
-def _validator_events(desc, key, validator):
+def _validator_events(desc, key, validator, include_removes):
"""Runs a validation method on an attribute value to be set or appended."""
- def append(state, value, initiator):
- return validator(state.obj(), key, value)
+ if include_removes:
+ def append(state, value, initiator):
+ return validator(state.obj(), key, value, False)
- def set_(state, value, oldvalue, initiator):
- return validator(state.obj(), key, value)
+ def set_(state, value, oldvalue, initiator):
+ return validator(state.obj(), key, value, False)
+
+ def remove(state, value, initiator):
+ validator(state.obj(), key, value, True)
+ else:
+ def append(state, value, initiator):
+ return validator(state.obj(), key, value)
+
+ def set_(state, value, oldvalue, initiator):
+ return validator(state.obj(), key, value)
event.listen(desc, 'append', append, raw=True, retval=True)
event.listen(desc, 'set', set_, raw=True, retval=True)
+ if include_removes:
+ event.listen(desc, "remove", remove, raw=True, retval=True)
def polymorphic_union(table_map, typecolname, aliasname='p_union', cast_nulls=True):
"""Create a ``UNION`` statement used by a polymorphic mapper.
See :ref:`concrete_inheritance` for an example of how
this is used.
-
+
:param table_map: mapping of polymorphic identities to
:class:`.Table` objects.
:param typecolname: string name of a "discriminator" column, which will be
@@ -236,7 +248,7 @@ class AliasedClass(object):
session.query(User, user_alias).\\
join((user_alias, User.id > user_alias.id)).\\
filter(User.name==user_alias.name)
-
+
The resulting object is an instance of :class:`.AliasedClass`, however
it implements a ``__getattribute__()`` scheme which will proxy attribute
access to that of the ORM class being aliased. All classmethods
@@ -244,7 +256,7 @@ class AliasedClass(object):
hybrids created with the :ref:`hybrids_toplevel` extension,
which will receive the :class:`.AliasedClass` as the "class" argument
when classmethods are called.
-
+
:param cls: ORM mapped entity which will be "wrapped" around an alias.
:param alias: a selectable, such as an :func:`.alias` or :func:`.select`
construct, which will be rendered in place of the mapped table of the
@@ -259,28 +271,28 @@ class AliasedClass(object):
otherwise have a column that corresponds to one on the entity. The
use case for this is when associating an entity with some derived
selectable such as one that uses aggregate functions::
-
+
class UnitPrice(Base):
__tablename__ = 'unit_price'
...
unit_id = Column(Integer)
price = Column(Numeric)
-
+
aggregated_unit_price = Session.query(
func.sum(UnitPrice.price).label('price')
).group_by(UnitPrice.unit_id).subquery()
-
+
aggregated_unit_price = aliased(UnitPrice, alias=aggregated_unit_price, adapt_on_names=True)
-
+
Above, functions on ``aggregated_unit_price`` which
refer to ``.price`` will return the
``fund.sum(UnitPrice.price).label('price')`` column,
as it is matched on the name "price". Ordinarily, the "price" function wouldn't
have any "column correspondence" to the actual ``UnitPrice.price`` column
as it is not a proxy of the original.
-
+
``adapt_on_names`` is new in 0.7.3.
-
+
"""
def __init__(self, cls, alias=None, name=None, adapt_on_names=False):
self.__mapper = _class_to_mapper(cls)
@@ -447,7 +459,7 @@ class _ORMJoin(expression.Join):
def join(left, right, onclause=None, isouter=False, join_to_left=True):
"""Produce an inner join between left and right clauses.
-
+
:func:`.orm.join` is an extension to the core join interface
provided by :func:`.sql.expression.join()`, where the
left and right selectables may be not only core selectable
@@ -460,7 +472,7 @@ def join(left, right, onclause=None, isouter=False, join_to_left=True):
in whatever form it is passed, to the selectable
passed as the left side. If False, the onclause
is used as is.
-
+
:func:`.orm.join` is not commonly needed in modern usage,
as its functionality is encapsulated within that of the
:meth:`.Query.join` method, which features a
@@ -468,22 +480,22 @@ def join(left, right, onclause=None, isouter=False, join_to_left=True):
by itself. Explicit usage of :func:`.orm.join`
with :class:`.Query` involves usage of the
:meth:`.Query.select_from` method, as in::
-
+
from sqlalchemy.orm import join
session.query(User).\\
select_from(join(User, Address, User.addresses)).\\
filter(Address.email_address=='foo@bar.com')
-
+
In modern SQLAlchemy the above join can be written more
succinctly as::
-
+
session.query(User).\\
join(User.addresses).\\
filter(Address.email_address=='foo@bar.com')
See :meth:`.Query.join` for information on modern usage
of ORM level joins.
-
+
"""
return _ORMJoin(left, right, onclause, isouter, join_to_left)
diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py
index 1c5f29b71..79ae7ff59 100644
--- a/test/orm/test_mapper.py
+++ b/test/orm/test_mapper.py
@@ -1950,10 +1950,11 @@ class DeepOptionsTest(_fixtures.FixtureTest):
class ValidatorTest(_fixtures.FixtureTest):
def test_scalar(self):
users = self.tables.users
-
+ canary = []
class User(fixtures.ComparableEntity):
@validates('name')
def validate_name(self, key, name):
+ canary.append((key, name))
assert name != 'fred'
return name + ' modified'
@@ -1963,6 +1964,7 @@ class ValidatorTest(_fixtures.FixtureTest):
eq_(u1.name, 'ed modified')
assert_raises(AssertionError, setattr, u1, "name", "fred")
eq_(u1.name, 'ed modified')
+ eq_(canary, [('name', 'ed'), ('name', 'fred')])
sess.add(u1)
sess.flush()
sess.expunge_all()
@@ -1973,9 +1975,11 @@ class ValidatorTest(_fixtures.FixtureTest):
self.tables.addresses,
self.classes.Address)
+ canary = []
class User(fixtures.ComparableEntity):
@validates('addresses')
def validate_address(self, key, ad):
+ canary.append((key, ad))
assert '@' in ad.email_address
return ad
@@ -1983,8 +1987,11 @@ class ValidatorTest(_fixtures.FixtureTest):
mapper(Address, addresses)
sess = create_session()
u1 = User(name='edward')
- assert_raises(AssertionError, u1.addresses.append, Address(email_address='noemail'))
- u1.addresses.append(Address(id=15, email_address='foo@bar.com'))
+ a0 = Address(email_address='noemail')
+ assert_raises(AssertionError, u1.addresses.append, a0)
+ a1 = Address(id=15, email_address='foo@bar.com')
+ u1.addresses.append(a1)
+ eq_(canary, [('addresses', a0), ('addresses', a1)])
sess.add(u1)
sess.flush()
sess.expunge_all()
@@ -2019,11 +2026,60 @@ class ValidatorTest(_fixtures.FixtureTest):
mapper(Address, addresses)
eq_(
- dict((k, v.__name__) for k, v in u_m.validators.items()),
+ dict((k, v[0].__name__) for k, v in u_m.validators.items()),
{'name':'validate_name',
'addresses':'validate_address'}
)
+ def test_validator_w_removes(self):
+ users, addresses, Address = (self.tables.users,
+ self.tables.addresses,
+ self.classes.Address)
+ canary = []
+ class User(fixtures.ComparableEntity):
+
+ @validates('name', include_removes=True)
+ def validate_name(self, key, item, remove):
+ canary.append((key, item, remove))
+ return item
+
+ @validates('addresses', include_removes=True)
+ def validate_address(self, key, item, remove):
+ canary.append((key, item, remove))
+ return item
+
+ mapper(User,
+ users,
+ properties={'addresses':relationship(Address)})
+ mapper(Address, addresses)
+
+ u1 = User()
+ u1.name = "ed"
+ u1.name = "mary"
+ del u1.name
+ a1, a2, a3 = Address(), Address(), Address()
+ u1.addresses.append(a1)
+ u1.addresses.remove(a1)
+ u1.addresses = [a1, a2]
+ u1.addresses = [a2, a3]
+
+ eq_(canary, [
+ ('name', 'ed', False),
+ ('name', 'mary', False),
+ ('name', 'mary', True),
+ # append a1
+ ('addresses', a1, False),
+ # remove a1
+ ('addresses', a1, True),
+ # set to [a1, a2] - this is two appends
+ ('addresses', a1, False), ('addresses', a2, False),
+ # set to [a2, a3] - this is a remove of a1,
+ # append of a3. the appends are first.
+ ('addresses', a3, False),
+ ('addresses', a1, True),
+ ]
+ )
+
class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL):
def test_kwarg_accepted(self):
users, Address = self.tables.users, self.classes.Address