diff options
-rw-r--r-- | CHANGES | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/properties.py | 65 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/session.py | 33 | ||||
-rw-r--r-- | test/aaa_profiling/test_orm.py | 85 | ||||
-rw-r--r-- | test/orm/test_merge.py | 40 | ||||
-rw-r--r-- | test/orm/test_query.py | 1 |
6 files changed, 190 insertions, 39 deletions
@@ -121,6 +121,11 @@ CHANGES - the "dont_load=True" flag on Session.merge() is deprecated and is now "load=False". + + - Session.merge() is performance optimized, using half the + call counts for "load=False" mode compared to 0.5 and + significantly fewer SQL queries in the case of collections + for "load=True" mode. - `expression.null()` is fully understood the same way None is when comparing an object/collection-referencing diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 1bb850488..e63c2b867 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -105,13 +105,17 @@ class ColumnProperty(StrategizedProperty): def setattr(self, state, value, column): state.get_impl(self.key).set(state, state.dict, value, None) - def merge(self, session, source, dest, load, _recursive): - value = attributes.instance_state(source).value_as_iterable( - self.key, passive=True) - if value: - setattr(dest, self.key, value[0]) + def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive): + if self.key in source_dict: + value = source_dict[self.key] + + if not load: + dest_dict[self.key] = value + else: + impl = dest_state.get_impl(self.key) + impl.set(dest_state, dest_dict, value, None) else: - attributes.instance_state(dest).expire_attributes([self.key]) + dest_state.expire_attributes([self.key]) def get_col_value(self, column, value): return value @@ -301,7 +305,7 @@ class SynonymProperty(MapperProperty): proxy_property=self.descriptor ) - def merge(self, session, source, dest, load, _recursive): + def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive): pass log.class_logger(SynonymProperty) @@ -334,7 +338,7 @@ class ComparableProperty(MapperProperty): def create_row_processor(self, selectcontext, path, mapper, row, adapter): return (None, None) - def merge(self, session, source, dest, load, _recursive): + def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive): pass @@ -624,50 +628,61 @@ class RelationProperty(StrategizedProperty): def __str__(self): return str(self.parent.class_.__name__) + "." + self.key - def merge(self, session, source, dest, load, _recursive): + def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive): if load: # TODO: no test coverage for recursive check for r in self._reverse_property: - if (source, r) in _recursive: + if (source_state, r) in _recursive: return - source_state = attributes.instance_state(source) - dest_state, dest_dict = attributes.instance_state(dest), attributes.instance_dict(dest) - if not "merge" in self.cascade: dest_state.expire_attributes([self.key]) return - instances = source_state.value_as_iterable(self.key, passive=True) - - if not instances: + if self.key not in source_dict: return if self.uselist: + instances = source_state.get_impl(self.key).\ + get(source_state, source_dict) + + if load: + # for a full merge, pre-load the destination collection, + # so that individual _merge of each item pulls from identity + # map for those already present. + # also assumes CollectionAttrbiuteImpl behavior of loading + # "old" list in any case + dest_state.get_impl(self.key).get(dest_state, dest_dict) + dest_list = [] for current in instances: - _recursive[(current, self)] = True - obj = session._merge(current, load=load, _recursive=_recursive) + current_state = attributes.instance_state(current) + current_dict = attributes.instance_dict(current) + _recursive[(current_state, self)] = True + obj = session._merge(current_state, current_dict, load=load, _recursive=_recursive) if obj is not None: dest_list.append(obj) + if not load: coll = attributes.init_state_collection(dest_state, dest_dict, self.key) for c in dest_list: coll.append_without_event(c) else: - getattr(dest.__class__, self.key).impl._set_iterable(dest_state, dest_dict, dest_list) + dest_state.get_impl(self.key)._set_iterable(dest_state, dest_dict, dest_list) else: - current = instances[0] + current = source_dict[self.key] if current is not None: - _recursive[(current, self)] = True - obj = session._merge(current, load=load, _recursive=_recursive) + current_state = attributes.instance_state(current) + current_dict = attributes.instance_dict(current) + _recursive[(current_state, self)] = True + obj = session._merge(current_state, current_dict, load=load, _recursive=_recursive) else: obj = None - + if not load: - dest_state.dict[self.key] = obj + dest_dict[self.key] = obj else: - setattr(dest, self.key, obj) + dest_state.get_impl(self.key).set(dest_state, dest_dict, obj, None) def cascade_iterator(self, type_, state, visited_instances, halt_on=None): if not type_ in self.cascade: diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index f171622d4..ee4286c67 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1103,24 +1103,25 @@ class Session(object): load = not kw['dont_load'] util.warn_deprecated("dont_load=True has been renamed to load=False.") - # TODO: this should be an IdentityDict for instances, but will - # need a separate dict for PropertyLoader tuples _recursive = {} self._autoflush() + _object_mapper(instance) # verify mapped autoflush = self.autoflush try: self.autoflush = False - return self._merge(instance, load=load, _recursive=_recursive) + return self._merge( + attributes.instance_state(instance), + attributes.instance_dict(instance), + load=load, _recursive=_recursive) finally: self.autoflush = autoflush - def _merge(self, instance, load=True, _recursive=None): - mapper = _object_mapper(instance) - if instance in _recursive: - return _recursive[instance] + def _merge(self, state, state_dict, load=True, _recursive=None): + mapper = _state_mapper(state) + if state in _recursive: + return _recursive[state] new_instance = False - state = attributes.instance_state(instance) key = state.key if key is None: @@ -1134,6 +1135,7 @@ class Session(object): if key in self.identity_map: merged = self.identity_map[key] + elif not load: if state.modified: raise sa_exc.InvalidRequestError( @@ -1154,16 +1156,21 @@ class Session(object): if merged is None: merged = mapper.class_manager.new_instance() merged_state = attributes.instance_state(merged) + merged_dict = attributes.instance_dict(merged) new_instance = True - self.add(merged) - - _recursive[instance] = merged + self._save_or_update_state(merged_state) + else: + merged_state = attributes.instance_state(merged) + merged_dict = attributes.instance_dict(merged) + + _recursive[state] = merged for prop in mapper.iterate_properties: - prop.merge(self, instance, merged, load, _recursive) + prop.merge(self, state, state_dict, merged_state, merged_dict, load, _recursive) if not load: - attributes.instance_state(merged).commit_all(attributes.instance_dict(merged), self.identity_map) # remove any history + # remove any history + merged_state.commit_all(merged_dict, self.identity_map) if new_instance: merged_state._run_on_load(merged) diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py new file mode 100644 index 000000000..d88e73c67 --- /dev/null +++ b/test/aaa_profiling/test_orm.py @@ -0,0 +1,85 @@ +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +from sqlalchemy import exc as sa_exc, util, Integer, String, ForeignKey +from sqlalchemy.orm import exc as orm_exc, mapper, relation, sessionmaker + +from sqlalchemy.test import testing, profiling +from test.orm import _base +from sqlalchemy.test.schema import Table, Column + + +class MergeTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + parent = Table('parent', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('data', String(20)) + ) + + child = Table('child', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('data', String(20)), + Column('parent_id', Integer, ForeignKey('parent.id'), nullable=False) + ) + + + @classmethod + def setup_classes(cls): + class Parent(_base.BasicEntity): + pass + class Child(_base.BasicEntity): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Parent, parent, properties={ + 'children':relation(Child, backref='parent') + }) + mapper(Child, child) + + @classmethod + @testing.resolve_artifact_names + def insert_data(cls): + parent.insert().execute( + {'id':1, 'data':'p1'}, + ) + child.insert().execute( + {'id':1, 'data':'p1c1', 'parent_id':1}, + ) + + @testing.resolve_artifact_names + def test_merge_no_load(self): + sess = sessionmaker()() + sess2 = sessionmaker()() + + p1 = sess.query(Parent).get(1) + p1.children + + # down from 185 on this + # this is a small slice of a usually bigger + # operation so using a small variance + @profiling.function_call_count(106, variance=0.001) + def go(): + p2 = sess2.merge(p1, load=False) + + go() + + @testing.resolve_artifact_names + def test_merge_load(self): + sess = sessionmaker()() + sess2 = sessionmaker()() + + p1 = sess.query(Parent).get(1) + p1.children + + # preloading of collection took this down from 1728 + # to 1192 using sqlite3 + @profiling.function_call_count(1192) + def go(): + p2 = sess2.merge(p1) + go() + + # one more time, count the SQL + sess2 = sessionmaker()() + self.assert_sql_count(testing.db, go, 2) + diff --git a/test/orm/test_merge.py b/test/orm/test_merge.py index 533c3ea5d..c3b28386d 100644 --- a/test/orm/test_merge.py +++ b/test/orm/test_merge.py @@ -4,7 +4,8 @@ from sqlalchemy import Integer, PickleType import operator from sqlalchemy.test import testing from sqlalchemy.util import OrderedSet -from sqlalchemy.orm import mapper, relation, create_session, PropComparator, synonym, comparable_property, sessionmaker +from sqlalchemy.orm import mapper, relation, create_session, PropComparator, \ + synonym, comparable_property, sessionmaker, attributes from sqlalchemy.test.testing import eq_, ne_ from test.orm import _base, _fixtures from sqlalchemy.test.schema import Table, Column @@ -379,6 +380,43 @@ class MergeTest(_fixtures.FixtureTest): eq_(u3.name, 'also fred') @testing.resolve_artifact_names + def test_many_to_one_cascade(self): + mapper(Address, addresses, properties={ + 'user':relation(User) + }) + mapper(User, users) + + u1 = User(id=1, name="u1") + a1 =Address(id=1, email_address="a1", user=u1) + u2 = User(id=2, name="u2") + + sess = create_session() + sess.add_all([a1, u2]) + sess.flush() + + a1.user = u2 + + sess2 = create_session() + a2 = sess2.merge(a1) + eq_( + attributes.get_history(a2, 'user'), + ([u2], (), [attributes.PASSIVE_NO_RESULT]) + ) + assert a2 in sess2.dirty + + sess.refresh(a1) + + sess2 = create_session() + a2 = sess2.merge(a1, load=False) + eq_( + attributes.get_history(a2, 'user'), + ((), [u1], ()) + ) + assert a2 not in sess2.dirty + + + + @testing.resolve_artifact_names def test_many_to_many_cascade(self): mapper(Order, orders, properties={ diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 967d1ac6c..bc3b9e26d 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -929,6 +929,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): sess.query(User.id).from_self().\ add_column(func.count().label('foo')).\ group_by(User.id).\ + order_by(User.id).\ from_self().all(), [ (7,1), (8, 1), (9, 1), (10, 1) |