diff options
Diffstat (limited to 'lib/sqlalchemy/ext/sqlsoup.py')
-rw-r--r-- | lib/sqlalchemy/ext/sqlsoup.py | 35 |
1 files changed, 31 insertions, 4 deletions
diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py index 592878acd..b2790c56e 100644 --- a/lib/sqlalchemy/ext/sqlsoup.py +++ b/lib/sqlalchemy/ext/sqlsoup.py @@ -307,7 +307,10 @@ default schema is used. from sqlalchemy import Table, MetaData, join from sqlalchemy import schema, sql from sqlalchemy.engine.base import Engine -from sqlalchemy.orm import scoped_session, sessionmaker, mapper, class_mapper, relation, session +from sqlalchemy.orm import scoped_session, sessionmaker, mapper, \ + class_mapper, relation, session,\ + object_session +from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE from sqlalchemy.exceptions import SQLAlchemyError, InvalidRequestError, ArgumentError from sqlalchemy.sql import expression @@ -316,6 +319,30 @@ __all__ = ['PKNotFoundError', 'SqlSoup'] Session = scoped_session(sessionmaker(autoflush=True, autocommit=False)) +class AutoAdd(MapperExtension): + def __init__(self, scoped_session): + self.scoped_session = scoped_session + + def instrument_class(self, mapper, class_): + class_.__init__ = self._default__init__(mapper) + + def _default__init__(ext, mapper): + def __init__(self, **kwargs): + for key, value in kwargs.iteritems(): + setattr(self, key, value) + return __init__ + + def init_instance(self, mapper, class_, oldinit, instance, args, kwargs): + session = self.scoped_session() + session._save_without_cascade(instance) + return EXT_CONTINUE + + def init_failed(self, mapper, class_, oldinit, instance, args, kwargs): + sess = object_session(instance) + if sess: + sess.expunge(instance) + return EXT_CONTINUE + class PKNotFoundError(SQLAlchemyError): pass @@ -395,19 +422,19 @@ def class_for_table(selectable, **mapper_kwargs): L = ["%s=%r" % (key, getattr(self, key, '')) for key in self.__class__.c.keys()] return '%s(%s)' % (self.__class__.__name__, ','.join(L)) - + for m in ['__cmp__', '__repr__']: setattr(klass, m, eval(m)) klass._table = selectable klass.c = expression.ColumnCollection() mappr = mapper(klass, selectable, - extension=Session.extension, + extension=AutoAdd(Session), **mapper_kwargs) for k in mappr.iterate_properties: klass.c[k.key] = k.columns[0] - + klass._query = Session.query_property() return klass |