diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2017-03-22 12:56:23 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2017-03-22 17:44:56 -0400 |
commit | 1fcbc17b7dd5a5cad71ee79441aa3293c00b8877 (patch) | |
tree | c87cb20c6fd7a9b6220b1cb3f27faac7d6d03261 /lib/sqlalchemy | |
parent | 28edc2604a96d5ecd8318232c95a034433aa07d1 (diff) | |
download | sqlalchemy-1fcbc17b7dd5a5cad71ee79441aa3293c00b8877.tar.gz |
Support hybrids/composites with bulk updates
The :meth:`.Query.update` method can now accommodate both
hybrid attributes as well as composite attributes as a source
of the key to be placed in the SET clause. For hybrids, an
additional decorator :meth:`.hybrid_property.update_expression`
is supplied for which the user supplies a tuple-returning function.
Change-Id: I15e97b02381d553f30b3301308155e19128d2cfb
Fixes: #3229
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/ext/hybrid.py | 109 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/attributes.py | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/descriptor_props.py | 26 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 109 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 4 |
6 files changed, 200 insertions, 56 deletions
diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index 17049d995..141a64599 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -183,6 +183,62 @@ The ``length(self, value)`` method is now called upon set:: >>> i1.end 17 +.. _hybrid_bulk_update: + +Allowing Bulk ORM Update +------------------------ + +A hybrid can define a custom "UPDATE" handler for when using the +:meth:`.Query.update` method, allowing the hybrid to be used in the +SET clause of the update. + +Normally, when using a hybrid with :meth:`.Query.update`, the SQL +expression is used as the column that's the target of the SET. If our +``Interval`` class had a hybrid ``start_point`` that linked to +``Interval.start``, this could be substituted directly:: + + session.query(Interval).update({Interval.start_point: 10}) + +However, when using a composite hybrid like ``Interval.length``, this +hybrid represents more than one column. We can set up a handler that will +accommodate a value passed to :meth:`.Query.update` which can affect +this, using the :meth:`.hybrid_propery.update_expression` decorator. +A handler that works similarly to our setter would be:: + + class Interval(object): + # ... + + @hybrid_property + def length(self): + return self.end - self.start + + @length.setter + def length(self, value): + self.end = self.start + value + + @length.update_expression + def length(cls, value): + return [ + (cls.end, cls.start + value) + ] + +Above, if we use ``Interval.length`` in an UPDATE expression as:: + + session.query(Interval).update( + {Interval.length: 25}, synchronize_session='fetch') + +We'll get an UPDATE statement along the lines of:: + + UPDATE interval SET end=start + :value + +In some cases, the default "evaluate" strategy can't perform the SET +expression in Python; while the addition operator we're using above +is supported, for more complex SET expressions it will usually be necessary +to use either the "fetch" or False synchronization strategy as illustrated +above. + +.. versionadded:: 1.2 added support for bulk updates to hybrid properties. + Working with Relationships -------------------------- @@ -777,7 +833,7 @@ class hybrid_property(interfaces.InspectionAttrInfo): def __init__( self, fget, fset=None, fdel=None, - expr=None, custom_comparator=None): + expr=None, custom_comparator=None, update_expr=None): """Create a new :class:`.hybrid_property`. Usage is typically via decorator:: @@ -799,7 +855,7 @@ class hybrid_property(interfaces.InspectionAttrInfo): self.fdel = fdel self.expr = expr self.custom_comparator = custom_comparator - + self.update_expr = update_expr util.update_wrapper(self, fget) def __get__(self, instance, owner): @@ -940,6 +996,42 @@ class hybrid_property(interfaces.InspectionAttrInfo): """ return self._copy(custom_comparator=comparator) + def update_expression(self, meth): + """Provide a modifying decorator that defines an UPDATE tuple + producing method. + + The method accepts a single value, which is the value to be + rendered into the SET clause of an UPDATE statement. The method + should then process this value into individual column expressions + that fit into the ultimate SET clause, and return them as a + sequence of 2-tuples. Each tuple + contains a column expression as the key and a value to be rendered. + + E.g.:: + + class Person(Base): + # ... + + first_name = Column(String) + last_name = Column(String) + + @hybrid_property + def fullname(self): + return first_name + " " + last_name + + @fullname.update_expression + def fullname(cls, value): + fname, lname = value.split(" ", 1) + return [ + (cls.first_name, fname), + (cls.last_name, lname) + ] + + .. versionadded:: 1.2 + + """ + return self._copy(update_expr=meth) + @util.memoized_property def _expr_comparator(self): if self.custom_comparator is not None: @@ -952,7 +1044,7 @@ class hybrid_property(interfaces.InspectionAttrInfo): def _get_expr(self, expr): def _expr(cls): - return ExprComparator(expr(cls), self) + return ExprComparator(cls, expr(cls), self) util.update_wrapper(_expr, expr) return self._get_comparator(_expr) @@ -990,7 +1082,8 @@ class Comparator(interfaces.PropComparator): class ExprComparator(Comparator): - def __init__(self, expression, hybrid): + def __init__(self, cls, expression, hybrid): + self.cls = cls self.expression = expression self.hybrid = hybrid @@ -1001,6 +1094,14 @@ class ExprComparator(Comparator): def info(self): return self.hybrid.info + def _bulk_update_tuples(self, value): + if isinstance(self.expression, attributes.QueryableAttribute): + return self.expression._bulk_update_tuples(value) + elif self.hybrid.update_expr is not None: + return self.hybrid.update_expr(self.cls, value) + else: + return [(self.expression, value)] + @property def property(self): return self.expression.property diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index a387e7d76..23a9f1a8c 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -151,6 +151,11 @@ class QueryableAttribute(interfaces._MappedAttribute, return self.comparator._query_clause_element() + def _bulk_update_tuples(self, value): + """Return setter tuples for a bulk UPDATE.""" + + return self.comparator._bulk_update_tuples(value) + def adapt_to_entity(self, adapt_to_entity): assert not self._of_type return self.__class__(adapt_to_entity.entity, diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 0792ff2e2..9afdbf693 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -278,9 +278,15 @@ class CompositeProperty(DescriptorProperty): """Establish events that populate/expire the composite attribute.""" def load_handler(state, *args): + _load_refresh_handler(state, args, is_refresh=False) + + def refresh_handler(state, *args): + _load_refresh_handler(state, args, is_refresh=True) + + def _load_refresh_handler(state, args, is_refresh): dict_ = state.dict - if self.key in dict_: + if not is_refresh and self.key in dict_: return # if column elements aren't loaded, skip. @@ -290,7 +296,6 @@ class CompositeProperty(DescriptorProperty): if k not in dict_: return - # assert self.key not in dict_ dict_[self.key] = self.composite_class( *[state.dict[key] for key in self._attribute_keys] @@ -317,7 +322,7 @@ class CompositeProperty(DescriptorProperty): event.listen(self.parent, 'load', load_handler, raw=True, propagate=True) event.listen(self.parent, 'refresh', - load_handler, raw=True, propagate=True) + refresh_handler, raw=True, propagate=True) event.listen(self.parent, 'expire', expire_handler, raw=True, propagate=True) @@ -411,6 +416,21 @@ class CompositeProperty(DescriptorProperty): return CompositeProperty.CompositeBundle( self.prop, self.__clause_element__()) + def _bulk_update_tuples(self, value): + if value is None: + values = [None for key in self.prop._attribute_keys] + elif isinstance(value, self.prop.composite_class): + values = value.__composite_values__() + else: + raise sa_exc.ArgumentError( + "Can't UPDATE composite attribute %s to %r" % + (self.prop, value)) + + return zip( + self._comparable_elements, + values + ) + @util.memoized_property def _comparable_elements(self): if self._adapt_to_entity: diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index fbe8f503e..1b14acefb 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -348,6 +348,9 @@ class PropComparator(operators.ColumnOperators): def _query_clause_element(self): return self.__clause_element__() + def _bulk_update_tuples(self, value): + return [(self.__clause_element__(), value)] + def adapt_to_entity(self, adapt_to_entity): """Return a copy of this PropComparator which will use the given :class:`.AliasedInsp` to produce corresponding expressions. diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 8e91dd6c7..5dc5a90b1 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -18,7 +18,7 @@ import operator from itertools import groupby, chain from .. import sql, util, exc as sa_exc from . import attributes, sync, exc as orm_exc, evaluator -from .base import state_str, _attr_as_key, _entity_descriptor +from .base import state_str, _entity_descriptor from ..sql import expression from ..sql.base import _from_objects from . import loading @@ -1180,6 +1180,12 @@ class BulkUD(object): self._do_post_synchronize() self._do_post() + def _execute_stmt(self, stmt): + self.result = self.query.session.execute( + stmt, params=self.query._params, + mapper=self.mapper) + self.rowcount = self.result.rowcount + @util.dependencies("sqlalchemy.orm.query") def _do_pre(self, querylib): query = self.query @@ -1287,41 +1293,49 @@ class BulkUpdate(BulkUD): False: BulkUpdate }, synchronize_session, query, values, update_kwargs) - def _resolve_string_to_expr(self, key): - if self.mapper and isinstance(key, util.string_types): - attr = _entity_descriptor(self.mapper, key) - return attr.__clause_element__() - else: - return key - - def _resolve_key_to_attrname(self, key): - if self.mapper and isinstance(key, util.string_types): - attr = _entity_descriptor(self.mapper, key) - return attr.property.key - elif isinstance(key, attributes.InstrumentedAttribute): - return key.key - elif hasattr(key, '__clause_element__'): - key = key.__clause_element__() - - if self.mapper and isinstance(key, expression.ColumnElement): - try: - attr = self.mapper._columntoproperty[key] - except orm_exc.UnmappedColumnError: - return None + @property + def _resolved_values(self): + values = [] + for k, v in ( + self.values.items() if hasattr(self.values, 'items') + else self.values): + if self.mapper: + if isinstance(k, util.string_types): + desc = _entity_descriptor(self.mapper, k) + values.extend(desc._bulk_update_tuples(v)) + elif isinstance(k, attributes.QueryableAttribute): + values.extend(k._bulk_update_tuples(v)) + else: + values.append((k, v)) else: - return attr.key - else: - raise sa_exc.InvalidRequestError( - "Invalid expression type: %r" % key) + values.append((k, v)) + return values + + @property + def _resolved_values_keys_as_propnames(self): + values = [] + for k, v in self._resolved_values: + if isinstance(k, attributes.QueryableAttribute): + values.append((k.key, v)) + continue + elif hasattr(k, '__clause_element__'): + k = k.__clause_element__() + + if self.mapper and isinstance(k, expression.ColumnElement): + try: + attr = self.mapper._columntoproperty[k] + except orm_exc.UnmappedColumnError: + pass + else: + values.append((attr.key, v)) + else: + raise sa_exc.InvalidRequestError( + "Invalid expression type: %r" % k) + return values def _do_exec(self): + values = self._resolved_values - values = [ - (self._resolve_string_to_expr(k), v) - for k, v in ( - self.values.items() if hasattr(self.values, 'items') - else self.values) - ] if not self.update_kwargs.get('preserve_parameter_order', False): values = dict(values) @@ -1329,10 +1343,7 @@ class BulkUpdate(BulkUD): self.context.whereclause, values, **self.update_kwargs) - self.result = self.query.session.execute( - update_stmt, params=self.query._params, - mapper=self.mapper) - self.rowcount = self.result.rowcount + self._execute_stmt(update_stmt) def _do_post(self): session = self.query.session @@ -1357,11 +1368,7 @@ class BulkDelete(BulkUD): delete_stmt = sql.delete(self.primary_table, self.context.whereclause) - self.result = self.query.session.execute( - delete_stmt, - params=self.query._params, - mapper=self.mapper) - self.rowcount = self.result.rowcount + self._execute_stmt(delete_stmt) def _do_post(self): session = self.query.session @@ -1374,13 +1381,10 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): def _additional_evaluators(self, evaluator_compiler): self.value_evaluators = {} - values = (self.values.items() if hasattr(self.values, 'items') - else self.values) + values = self._resolved_values_keys_as_propnames for key, value in values: - key = self._resolve_key_to_attrname(key) - if key is not None: - self.value_evaluators[key] = evaluator_compiler.process( - expression._literal_as_binds(value)) + self.value_evaluators[key] = evaluator_compiler.process( + expression._literal_as_binds(value)) def _do_post_synchronize(self): session = self.query.session @@ -1396,6 +1400,9 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): for key in to_evaluate: dict_[key] = self.value_evaluators[key](obj) + state.manager.dispatch.refresh( + state, None, to_evaluate) + state._commit(dict_, list(to_evaluate)) # expire attributes with pending changes @@ -1434,9 +1441,13 @@ class BulkUpdateFetch(BulkFetch, BulkUpdate): ] if identity_key in session.identity_map ]) - attrib = [_attr_as_key(k) for k in self.values] + + values = self._resolved_values_keys_as_propnames + attrib = set(k for k, v in values) for state in states: - session._expire_state(state, attrib) + to_expire = attrib.intersection(state.dict) + if to_expire: + session._expire_state(state, to_expire) session._register_altered(states) diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index c8525f2f6..0244f18a9 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -330,6 +330,10 @@ class AssertsCompiledSQL(object): context = clause._compile_context() context.statement.use_labels = True clause = context.statement + elif isinstance(clause, orm.persistence.BulkUD): + with mock.patch.object(clause, "_execute_stmt") as stmt_mock: + clause.exec_() + clause = stmt_mock.mock_calls[0][1][0] if compile_kwargs: kw['compile_kwargs'] = compile_kwargs |