diff options
Diffstat (limited to 'lib/sqlalchemy/ext')
-rw-r--r-- | lib/sqlalchemy/ext/activemapper.py | 11 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/assignmapper.py | 59 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/associationproxy.py | 88 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/proxy.py | 113 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/selectresults.py | 218 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/sessioncontext.py | 28 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/sqlsoup.py | 13 |
7 files changed, 145 insertions, 385 deletions
diff --git a/lib/sqlalchemy/ext/activemapper.py b/lib/sqlalchemy/ext/activemapper.py index 2fcf44f61..fa32a5fc3 100644 --- a/lib/sqlalchemy/ext/activemapper.py +++ b/lib/sqlalchemy/ext/activemapper.py @@ -1,11 +1,10 @@ -from sqlalchemy import create_session, relation, mapper, \ - join, ThreadLocalMetaData, class_mapper, \ - util, Integer -from sqlalchemy import and_, or_ +from sqlalchemy import ThreadLocalMetaData, util, Integer from sqlalchemy import Table, Column, ForeignKey +from sqlalchemy.orm import class_mapper, relation, create_session + from sqlalchemy.ext.sessioncontext import SessionContext from sqlalchemy.ext.assignmapper import assign_mapper -from sqlalchemy import backref as create_backref +from sqlalchemy.orm import backref as create_backref import sqlalchemy import inspect @@ -14,7 +13,7 @@ import sys # # the "proxy" to the database engine... this can be swapped out at runtime # -metadata = ThreadLocalMetaData("activemapper") +metadata = ThreadLocalMetaData() try: objectstore = sqlalchemy.objectstore diff --git a/lib/sqlalchemy/ext/assignmapper.py b/lib/sqlalchemy/ext/assignmapper.py index 4708afd8d..238041702 100644 --- a/lib/sqlalchemy/ext/assignmapper.py +++ b/lib/sqlalchemy/ext/assignmapper.py @@ -1,51 +1,50 @@ -from sqlalchemy import mapper, util, Query, exceptions +from sqlalchemy import util, exceptions import types - -def monkeypatch_query_method(ctx, class_, name): - def do(self, *args, **kwargs): - query = Query(class_, session=ctx.current) - return getattr(query, name)(*args, **kwargs) - try: - do.__name__ = name - except: - pass - setattr(class_, name, classmethod(do)) - -def monkeypatch_objectstore_method(ctx, class_, name): +from sqlalchemy.orm import mapper + +def _monkeypatch_session_method(name, ctx, class_): def do(self, *args, **kwargs): session = ctx.current - if name == "flush": - # flush expects a list of objects - self = [self] return getattr(session, name)(self, *args, **kwargs) try: do.__name__ = name except: pass - setattr(class_, name, do) - + if not hasattr(class_, name): + setattr(class_, name, do) + def assign_mapper(ctx, class_, *args, **kwargs): + extension = kwargs.pop('extension', None) + if extension is not None: + extension = util.to_list(extension) + extension.append(ctx.mapper_extension) + else: + extension = ctx.mapper_extension + validate = kwargs.pop('validate', False) + if not isinstance(getattr(class_, '__init__'), types.MethodType): def __init__(self, **kwargs): for key, value in kwargs.items(): if validate: - if not key in self.mapper.props: + if not self.mapper.get_property(key, resolve_synonyms=False, raiseerr=False): raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key) setattr(self, key, value) class_.__init__ = __init__ - extension = kwargs.pop('extension', None) - if extension is not None: - extension = util.to_list(extension) - extension.append(ctx.mapper_extension) - else: - extension = ctx.mapper_extension + + class query(object): + def __getattr__(self, key): + return getattr(ctx.current.query(class_), key) + def __call__(self): + return ctx.current.query(class_) + + if not hasattr(class_, 'query'): + class_.query = query() + + for name in ['refresh', 'expire', 'delete', 'expunge', 'update']: + _monkeypatch_session_method(name, ctx, class_) + m = mapper(class_, extension=extension, *args, **kwargs) class_.mapper = m - class_.query = classmethod(lambda cls: Query(class_, session=ctx.current)) - for name in ['get', 'filter', 'filter_by', 'select', 'select_by', 'selectfirst', 'selectfirst_by', 'selectone', 'selectone_by', 'get_by', 'join_to', 'join_via', 'count', 'count_by', 'options', 'instances']: - monkeypatch_query_method(ctx, class_, name) - for name in ['flush', 'delete', 'expire', 'refresh', 'expunge', 'merge', 'save', 'update', 'save_or_update']: - monkeypatch_objectstore_method(ctx, class_, name) return m diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index cdb814702..2dd807222 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -6,11 +6,10 @@ transparent proxied access to the endpoint of an association object. See the example ``examples/association/proxied_association.py``. """ -from sqlalchemy.orm.attributes import InstrumentedList +import weakref, itertools import sqlalchemy.exceptions as exceptions import sqlalchemy.orm as orm import sqlalchemy.util as util -import weakref def association_proxy(targetcollection, attr, **kw): """Convenience function for use in mapped classes. Implements a Python @@ -109,7 +108,7 @@ class AssociationProxy(object): self.collection_class = None def _get_property(self): - return orm.class_mapper(self.owning_class).props[self.target_collection] + return orm.class_mapper(self.owning_class).get_property(self.target_collection) def _target_class(self): return self._get_property().mapper.class_ @@ -168,15 +167,7 @@ class AssociationProxy(object): def _new(self, lazy_collection): creator = self.creator and self.creator or self.target_class - - # Prefer class typing here to spot dicts with the required append() - # method. - collection = lazy_collection() - if isinstance(collection.data, dict): - self.collection_class = dict - else: - self.collection_class = util.duck_type_collection(collection.data) - del collection + self.collection_class = util.duck_type_collection(lazy_collection()) if self.proxy_factory: return self.proxy_factory(lazy_collection, creator, self.value_attr) @@ -269,7 +260,33 @@ class _AssociationList(object): return self._get(self.col[index]) def __setitem__(self, index, value): - self._set(self.col[index], value) + if not isinstance(index, slice): + self._set(self.col[index], value) + else: + if index.stop is None: + stop = len(self) + elif index.stop < 0: + stop = len(self) + index.stop + else: + stop = index.stop + step = index.step or 1 + + rng = range(index.start or 0, stop, step) + if step == 1: + for i in rng: + del self[index.start] + i = index.start + for item in value: + self.insert(i, item) + i += 1 + else: + if len(value) != len(rng): + raise ValueError( + "attempt to assign sequence of size %s to " + "extended slice of size %s" % (len(value), + len(rng))) + for i, item in zip(rng, value): + self._set(self.col[i], item) def __delitem__(self, index): del self.col[index] @@ -291,9 +308,13 @@ class _AssociationList(object): del self.col[start:end] def __iter__(self): - """Iterate over proxied values. For the actual domain objects, - iterate over .col instead or just use the underlying collection - directly from its property on the parent.""" + """Iterate over proxied values. + + For the actual domain objects, iterate over .col instead or + just use the underlying collection directly from its property + on the parent. + """ + for member in self.col: yield self._get(member) raise StopIteration @@ -304,6 +325,10 @@ class _AssociationList(object): item = self._create(value, **kw) self.col.append(item) + def count(self, value): + return sum([1 for _ in + itertools.ifilter(lambda v: v == value, iter(self))]) + def extend(self, values): for v in values: self.append(v) @@ -311,6 +336,26 @@ class _AssociationList(object): def insert(self, index, value): self.col[index:index] = [self._create(value)] + def pop(self, index=-1): + return self.getter(self.col.pop(index)) + + def remove(self, value): + for i, val in enumerate(self): + if val == value: + del self.col[i] + return + raise ValueError("value not in list") + + def reverse(self): + """Not supported, use reversed(mylist)""" + + raise NotImplementedError + + def sort(self): + """Not supported, use sorted(mylist)""" + + raise NotImplementedError + def clear(self): del self.col[0:len(self.col)] @@ -545,9 +590,7 @@ class _AssociationSet(object): def add(self, value): if value not in self: - # must shove this through InstrumentedList.append() which will - # eventually call the collection_class .add() - self.col.append(self._create(value)) + self.col.add(self._create(value)) # for discard and remove, choosing a more expensive check strategy rather # than call self.creator() @@ -567,12 +610,7 @@ class _AssociationSet(object): def pop(self): if not self.col: raise KeyError('pop from an empty set') - # grumble, pop() is borked on InstrumentedList (#548) - if isinstance(self.col, InstrumentedList): - member = list(self.col)[0] - self.col.remove(member) - else: - member = self.col.pop() + member = self.col.pop() return self._get(member) def update(self, other): diff --git a/lib/sqlalchemy/ext/proxy.py b/lib/sqlalchemy/ext/proxy.py deleted file mode 100644 index b81702fc4..000000000 --- a/lib/sqlalchemy/ext/proxy.py +++ /dev/null @@ -1,113 +0,0 @@ -try: - from threading import local -except ImportError: - from sqlalchemy.util import ThreadLocal as local - -from sqlalchemy import sql -from sqlalchemy.engine import create_engine, Engine - -__all__ = ['BaseProxyEngine', 'AutoConnectEngine', 'ProxyEngine'] - -class BaseProxyEngine(sql.Executor): - """Basis for all proxy engines.""" - - def get_engine(self): - raise NotImplementedError - - def set_engine(self, engine): - raise NotImplementedError - - engine = property(lambda s:s.get_engine(), lambda s,e:s.set_engine(e)) - - def execute_compiled(self, *args, **kwargs): - """Override superclass behaviour. - - This method is required to be present as it overrides the - `execute_compiled` present in ``sql.Engine``. - """ - - return self.get_engine().execute_compiled(*args, **kwargs) - - def compiler(self, *args, **kwargs): - """Override superclass behaviour. - - This method is required to be present as it overrides the - `compiler` method present in ``sql.Engine``. - """ - - return self.get_engine().compiler(*args, **kwargs) - - def __getattr__(self, attr): - """Provide proxying for methods that are not otherwise present on this ``BaseProxyEngine``. - - Note that methods which are present on the base class - ``sql.Engine`` will **not** be proxied through this, and must - be explicit on this class. - """ - - # call get_engine() to give subclasses a chance to change - # connection establishment behavior - e = self.get_engine() - if e is not None: - return getattr(e, attr) - raise AttributeError("No connection established in ProxyEngine: " - " no access to %s" % attr) - -class AutoConnectEngine(BaseProxyEngine): - """An SQLEngine proxy that automatically connects when necessary.""" - - def __init__(self, dburi, **kwargs): - BaseProxyEngine.__init__(self) - self.dburi = dburi - self.kwargs = kwargs - self._engine = None - - def get_engine(self): - if self._engine is None: - if callable(self.dburi): - dburi = self.dburi() - else: - dburi = self.dburi - self._engine = create_engine(dburi, **self.kwargs) - return self._engine - - -class ProxyEngine(BaseProxyEngine): - """Engine proxy for lazy and late initialization. - - This engine will delegate access to a real engine set with connect(). - """ - - def __init__(self, **kwargs): - BaseProxyEngine.__init__(self) - # create the local storage for uri->engine map and current engine - self.storage = local() - self.kwargs = kwargs - - def connect(self, *args, **kwargs): - """Establish connection to a real engine.""" - - kwargs.update(self.kwargs) - if not kwargs: - key = repr(args) - else: - key = "%s, %s" % (repr(args), repr(sorted(kwargs.items()))) - try: - map = self.storage.connection - except AttributeError: - self.storage.connection = {} - self.storage.engine = None - map = self.storage.connection - try: - self.storage.engine = map[key] - except KeyError: - map[key] = create_engine(*args, **kwargs) - self.storage.engine = map[key] - - def get_engine(self): - if not hasattr(self.storage, 'engine') or self.storage.engine is None: - raise AttributeError("No connection established") - return self.storage.engine - - def set_engine(self, engine): - self.storage.engine = engine diff --git a/lib/sqlalchemy/ext/selectresults.py b/lib/sqlalchemy/ext/selectresults.py index 68538f3cb..1920b6f92 100644 --- a/lib/sqlalchemy/ext/selectresults.py +++ b/lib/sqlalchemy/ext/selectresults.py @@ -1,212 +1,28 @@ +"""SelectResults has been rolled into Query. This class is now just a placeholder.""" + import sqlalchemy.sql as sql import sqlalchemy.orm as orm class SelectResultsExt(orm.MapperExtension): """a MapperExtension that provides SelectResults functionality for the results of query.select_by() and query.select()""" + def select_by(self, query, *args, **params): - return SelectResults(query, query.join_by(*args, **params)) + q = query + for a in args: + q = q.filter(a) + return q.filter_by(**params) + def select(self, query, arg=None, **kwargs): if isinstance(arg, sql.FromClause) and arg.supports_execution(): return orm.EXT_PASS else: - return SelectResults(query, arg, ops=kwargs) - -class SelectResults(object): - """Build a query one component at a time via separate method - calls, each call transforming the previous ``SelectResults`` - instance into a new ``SelectResults`` instance with further - limiting criterion added. When interpreted in an iterator context - (such as via calling ``list(selectresults)``), executes the query. - """ - - def __init__(self, query, clause=None, ops={}, joinpoint=None): - """Construct a new ``SelectResults`` using the given ``Query`` - object and optional ``WHERE`` clause. `ops` is an optional - dictionary of bind parameter values. - """ - - self._query = query - self._clause = clause - self._ops = {} - self._ops.update(ops) - self._joinpoint = joinpoint or (self._query.table, self._query.mapper) - - def options(self,*args, **kwargs): - """Apply mapper options to the underlying query. - - See also ``Query.options``. - """ - - new = self.clone() - new._query = new._query.options(*args, **kwargs) - return new - - def count(self): - """Execute the SQL ``count()`` function against the ``SelectResults`` criterion.""" - - return self._query.count(self._clause, **self._ops) - - def _col_aggregate(self, col, func): - """Execute ``func()`` function against the given column. - - For performance, only use subselect if `order_by` attribute is set. - """ - - if self._ops.get('order_by'): - s1 = sql.select([col], self._clause, **self._ops).alias('u') - return sql.select([func(s1.corresponding_column(col))]).scalar() - else: - return sql.select([func(col)], self._clause, **self._ops).scalar() - - def min(self, col): - """Execute the SQL ``min()`` function against the given column.""" - - return self._col_aggregate(col, sql.func.min) - - def max(self, col): - """Execute the SQL ``max()`` function against the given column.""" - - return self._col_aggregate(col, sql.func.max) - - def sum(self, col): - """Execute the SQL ``sum()`` function against the given column.""" - - return self._col_aggregate(col, sql.func.sum) - - def avg(self, col): - """Execute the SQL ``avg()`` function against the given column.""" - - return self._col_aggregate(col, sql.func.avg) - - def clone(self): - """Create a copy of this ``SelectResults``.""" - - return SelectResults(self._query, self._clause, self._ops.copy(), self._joinpoint) - - def filter(self, clause): - """Apply an additional ``WHERE`` clause against the query.""" - - new = self.clone() - new._clause = sql.and_(self._clause, clause) - return new - - def select(self, clause): - return self.filter(clause) - - def select_by(self, *args, **kwargs): - return self.filter(self._query._join_by(args, kwargs, start=self._joinpoint[1])) - - def order_by(self, order_by): - """Apply an ``ORDER BY`` to the query.""" - - new = self.clone() - new._ops['order_by'] = order_by - return new - - def limit(self, limit): - """Apply a ``LIMIT`` to the query.""" - - return self[:limit] - - def offset(self, offset): - """Apply an ``OFFSET`` to the query.""" - - return self[offset:] - - def distinct(self): - """Apply a ``DISTINCT`` to the query.""" - - new = self.clone() - new._ops['distinct'] = True - return new - - def list(self): - """Return the results represented by this ``SelectResults`` as a list. - - This results in an execution of the underlying query. - """ - - return list(self) - - def select_from(self, from_obj): - """Set the `from_obj` parameter of the query. - - `from_obj` is a list of one or more tables. - """ - - new = self.clone() - new._ops['from_obj'] = from_obj - return new - - def join_to(self, prop): - """Join the table of this ``SelectResults`` to the table located against the given property name. - - Subsequent calls to join_to or outerjoin_to will join against - the rightmost table located from the previous `join_to` or - `outerjoin_to` call, searching for the property starting with - the rightmost mapper last located. - """ - - new = self.clone() - (clause, mapper) = self._join_to(prop, outerjoin=False) - new._ops['from_obj'] = [clause] - new._joinpoint = (clause, mapper) - return new - - def outerjoin_to(self, prop): - """Outer join the table of this ``SelectResults`` to the - table located against the given property name. - - Subsequent calls to join_to or outerjoin_to will join against - the rightmost table located from the previous ``join_to`` or - ``outerjoin_to`` call, searching for the property starting with - the rightmost mapper last located. - """ - - new = self.clone() - (clause, mapper) = self._join_to(prop, outerjoin=True) - new._ops['from_obj'] = [clause] - new._joinpoint = (clause, mapper) - return new - - def _join_to(self, prop, outerjoin=False): - [keys,p] = self._query._locate_prop(prop, start=self._joinpoint[1]) - clause = self._joinpoint[0] - mapper = self._joinpoint[1] - for key in keys: - prop = mapper.props[key] - if outerjoin: - clause = clause.outerjoin(prop.select_table, prop.get_join(mapper)) - else: - clause = clause.join(prop.select_table, prop.get_join(mapper)) - mapper = prop.mapper - return (clause, mapper) - - def compile(self): - return self._query.compile(self._clause, **self._ops) - - def __getitem__(self, item): - if isinstance(item, slice): - start = item.start - stop = item.stop - if (isinstance(start, int) and start < 0) or \ - (isinstance(stop, int) and stop < 0): - return list(self)[item] - else: - res = self.clone() - if start is not None and stop is not None: - res._ops.update(dict(offset=self._ops.get('offset', 0)+start, limit=stop-start)) - elif start is None and stop is not None: - res._ops.update(dict(limit=stop)) - elif start is not None and stop is None: - res._ops.update(dict(offset=self._ops.get('offset', 0)+start)) - if item.step is not None: - return list(res)[None:None:item.step] - else: - return res - else: - return list(self[item:item+1])[0] - - def __iter__(self): - return iter(self._query.select_whereclause(self._clause, **self._ops)) + if arg is not None: + query = query.filter(arg) + return query._legacy_select_kwargs(**kwargs) + +def SelectResults(query, clause=None, ops={}): + if clause is not None: + query = query.filter(clause) + query = query.options(orm.extension(SelectResultsExt())) + return query._legacy_select_kwargs(**ops) diff --git a/lib/sqlalchemy/ext/sessioncontext.py b/lib/sqlalchemy/ext/sessioncontext.py index 2f81e55d2..fcbf29c3f 100644 --- a/lib/sqlalchemy/ext/sessioncontext.py +++ b/lib/sqlalchemy/ext/sessioncontext.py @@ -1,5 +1,5 @@ from sqlalchemy.util import ScopedRegistry -from sqlalchemy.orm.mapper import MapperExtension +from sqlalchemy.orm import create_session, object_session, MapperExtension, EXT_PASS __all__ = ['SessionContext', 'SessionContextExt'] @@ -15,16 +15,18 @@ class SessionContext(object): engine = create_engine(...) def session_factory(): - return Session(bind_to=engine) + return Session(bind=engine) context = SessionContext(session_factory) s = context.current # get thread-local session - context.current = Session(bind_to=other_engine) # set current session + context.current = Session(bind=other_engine) # set current session del context.current # discard the thread-local session (a new one will # be created on the next call to context.current) """ - def __init__(self, session_factory, scopefunc=None): + def __init__(self, session_factory=None, scopefunc=None): + if session_factory is None: + session_factory = create_session self.registry = ScopedRegistry(session_factory, scopefunc) super(SessionContext, self).__init__() @@ -60,3 +62,21 @@ class SessionContextExt(MapperExtension): def get_session(self): return self.context.current + + def init_instance(self, mapper, class_, instance, args, kwargs): + session = kwargs.pop('_sa_session', self.context.current) + session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None)) + return EXT_PASS + + def init_failed(self, mapper, class_, instance, args, kwargs): + object_session(instance).expunge(instance) + return EXT_PASS + + def dispose_class(self, mapper, class_): + if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'): + if class_.__init__._oldinit is not None: + class_.__init__ = class_.__init__._oldinit + else: + delattr(class_, '__init__') + + diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py index 04e5b49f7..756b5e1e7 100644 --- a/lib/sqlalchemy/ext/sqlsoup.py +++ b/lib/sqlalchemy/ext/sqlsoup.py @@ -310,8 +310,8 @@ Boring tests here. Nothing of real expository value. """ from sqlalchemy import * +from sqlalchemy.orm import * from sqlalchemy.ext.sessioncontext import SessionContext -from sqlalchemy.ext.assignmapper import assign_mapper from sqlalchemy.exceptions import * @@ -392,7 +392,7 @@ class SelectableClassType(type): def update(cls, whereclause=None, values=None, **kwargs): _ddl_error(cls) - def _selectable(cls): + def __selectable__(cls): return cls._table def __getattr__(cls, attr): @@ -434,9 +434,7 @@ def _selectable_name(selectable): return x def class_for_table(selectable, **mapper_kwargs): - if not hasattr(selectable, '_selectable') \ - or selectable._selectable() != selectable: - raise ArgumentError('class_for_table requires a selectable as its argument') + selectable = sql._selectable(selectable) mapname = 'Mapped' + _selectable_name(selectable) if isinstance(selectable, Table): klass = TableClassType(mapname, (object,), {}) @@ -520,7 +518,7 @@ class SqlSoup: def with_labels(self, item): # TODO give meaningful aliases - return self.map(item._selectable().select(use_labels=True).alias('foo')) + return self.map(sql._selectable(item).select(use_labels=True).alias('foo')) def join(self, *args, **kwargs): j = join(*args, **kwargs) @@ -539,6 +537,9 @@ class SqlSoup: t = None self._cache[attr] = t return t + + def __repr__(self): + return 'SqlSoup(%r)' % self._metadata if __name__ == '__main__': import doctest |