import sqlalchemy.sql as sql
import sqlalchemy.orm as orm


class SelectResultsExt(orm.MapperExtension):
    """a MapperExtension that provides SelectResults functionality for the
    results of query.select_by() and query.select()"""

    def select_by(self, query, *args, **params):
        """wrap the criterion built by query.join_by() in a SelectResults."""
        return SelectResults(query, query.join_by(*args, **params))

    def select(self, query, arg=None, **kwargs):
        """wrap a WHERE clause and keyword options in a SelectResults.

        When ``arg`` has a ``_selectable`` attribute, orm.EXT_PASS is
        returned instead (presumably deferring to the query's default
        select() handling -- confirm against orm.MapperExtension)."""
        if hasattr(arg, '_selectable'):
            return orm.EXT_PASS
        return SelectResults(query, arg, ops=kwargs)


class SelectResults(object):
    """Builds a query one component at a time via separate method calls,
    each call transforming the previous SelectResults instance into a new
    SelectResults instance with further limiting criterion added.

    When interpreted in an iterator context (such as via calling
    list(selectresults)), executes the query."""

    def __init__(self, query, clause=None, ops=None, joinpoint=None):
        """constructs a new SelectResults using the given Query object and
        optional WHERE clause.

        ops is an optional dictionary of select options forwarded to the
        query (e.g. order_by, limit, offset, distinct, from_obj); it is
        copied, never mutated.  joinpoint is a (selectable, mapper) pair
        from which join_to(), outerjoin_to() and select_by() continue; it
        defaults to the query's own table and mapper."""
        self._query = query
        self._clause = clause
        # Copy so derived SelectResults never share or mutate the caller's
        # dictionary (and no shared mutable default is involved).
        self._ops = {}
        if ops:
            self._ops.update(ops)
        self._joinpoint = joinpoint or (self._query.table, self._query.mapper)

    def count(self):
        """executes the SQL count() function against the SelectResults
        criterion, via Query.count() with the current options.

        NOTE(review): whether limit/offset are honored here depends on
        Query.count(), which is not visible from this module."""
        return self._query.count(self._clause, **self._ops)

    def _col_aggregate(self, col, func):
        """executes func() against the given column and returns the scalar.

        When order_by, limit or offset is in effect, the criterion is first
        wrapped in a subselect and func() is applied to the subselect's
        column.  LIMIT/OFFSET on a plain aggregate SELECT would restrict the
        single aggregate row rather than the rows being aggregated.  With
        none of those options, the aggregate is applied directly, skipping
        the subselect for performance."""
        ops = self._ops
        needs_subselect = (ops.get('order_by') or
                           ops.get('limit') is not None or
                           ops.get('offset') is not None)
        if needs_subselect:
            s1 = sql.select([col], self._clause, **ops).alias('u')
            return sql.select([func(s1.corresponding_column(col))]).scalar()
        return sql.select([func(col)], self._clause, **ops).scalar()

    def min(self, col):
        """executes the SQL min() function against the given column"""
        return self._col_aggregate(col, sql.func.min)

    def max(self, col):
        """executes the SQL max() function against the given column"""
        return self._col_aggregate(col, sql.func.max)

    def sum(self, col):
        """executes the SQL sum() function against the given column"""
        return self._col_aggregate(col, sql.func.sum)

    def avg(self, col):
        """executes the SQL avg() function against the given column"""
        return self._col_aggregate(col, sql.func.avg)

    def clone(self):
        """creates a copy of this SelectResults (options dict is copied)."""
        return SelectResults(self._query, self._clause, self._ops.copy(),
                             self._joinpoint)

    def filter(self, clause):
        """applies an additional WHERE clause against the query, AND-ed with
        the existing criterion, returning a new SelectResults."""
        new = self.clone()
        new._clause = sql.and_(self._clause, clause)
        return new

    def select(self, clause):
        """synonym for filter()."""
        return self.filter(clause)

    def select_by(self, *args, **kwargs):
        """filters by criterion built with the query's _join_by(), searching
        from the mapper at the current join point."""
        return self.filter(self._query._join_by(args, kwargs,
                                                start=self._joinpoint[1]))

    def order_by(self, order_by):
        """apply an ORDER BY to the query."""
        new = self.clone()
        new._ops['order_by'] = order_by
        return new

    def limit(self, limit):
        """apply a LIMIT to the query."""
        return self[:limit]

    def offset(self, offset):
        """apply an OFFSET to the query."""
        return self[offset:]

    def distinct(self):
        """applies a DISTINCT to the query"""
        new = self.clone()
        new._ops['distinct'] = True
        return new

    def list(self):
        """return the results represented by this SelectResults as a list.

        this results in an execution of the underlying query."""
        return list(self)

    def select_from(self, from_obj):
        """set the from_obj parameter of the query to a specific table or
        set of tables.

        from_obj is a list."""
        new = self.clone()
        new._ops['from_obj'] = from_obj
        return new

    def join_to(self, prop):
        """join the table of this SelectResults to the table located against
        the given property name.

        subsequent calls to join_to or outerjoin_to will join against the
        rightmost table located from the previous join_to or outerjoin_to
        call, searching for the property starting with the rightmost mapper
        last located."""
        return self._with_join(prop, outerjoin=False)

    def outerjoin_to(self, prop):
        """outer join the table of this SelectResults to the table located
        against the given property name.

        subsequent calls to join_to or outerjoin_to will join against the
        rightmost table located from the previous join_to or outerjoin_to
        call, searching for the property starting with the rightmost mapper
        last located."""
        return self._with_join(prop, outerjoin=True)

    def _with_join(self, prop, outerjoin):
        """return a copy whose from_obj and join point are the join built by
        _join_to() for the given property."""
        new = self.clone()
        clause, mapper = self._join_to(prop, outerjoin=outerjoin)
        new._ops['from_obj'] = [clause]
        new._joinpoint = (clause, mapper)
        return new

    def _join_to(self, prop, outerjoin=False):
        """build a (join clause, rightmost mapper) pair by walking the
        property path that Query._locate_prop() finds for ``prop``, starting
        from the current join point."""
        keys, _ = self._query._locate_prop(prop, start=self._joinpoint[1])
        clause, mapper = self._joinpoint
        for key in keys:
            step = mapper.props[key]
            if outerjoin:
                clause = clause.outerjoin(step.select_table,
                                          step.get_join(mapper))
            else:
                clause = clause.join(step.select_table, step.get_join(mapper))
            mapper = step.mapper
        return (clause, mapper)

    def compile(self):
        """compile the query for the current criterion and options via
        Query.compile()."""
        return self._query.compile(self._clause, **self._ops)

    def __getitem__(self, item):
        """slice or index the results.

        Non-negative slices become OFFSET/LIMIT on a new SelectResults, with
        offsets accumulating onto any existing offset.  A slice with a step
        executes the limited query and returns a list with the step applied.
        Negative indices or slice bounds cannot be expressed as OFFSET/LIMIT,
        so the full result is loaded and indexed as a list.  An integer index
        returns a single instance and raises IndexError when out of range."""
        if isinstance(item, slice):
            start, stop = item.start, item.stop
            if (isinstance(start, int) and start < 0) or \
               (isinstance(stop, int) and stop < 0):
                return list(self)[item]
            res = self.clone()
            if start is not None:
                res._ops['offset'] = self._ops.get('offset', 0) + start
            if stop is not None:
                # NOTE(review): a new limit replaces rather than narrows an
                # existing one, e.g. sr[0:10][5:20] gives offset 5, limit 15.
                res._ops['limit'] = stop - (start or 0)
            if item.step is not None:
                return list(res)[::item.step]
            return res
        if item < 0:
            # A one-row OFFSET/LIMIT window cannot address rows from the end;
            # previously self[-1:0] produced an empty list and [0] raised.
            return list(self)[item]
        return list(self[item:item + 1])[0]

    def __iter__(self):
        """execute the query via Query.select_whereclause() and iterate
        over its results."""
        return iter(self._query.select_whereclause(self._clause, **self._ops))