diff options
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r-- | lib/sqlalchemy/sql.py | 127 |
1 files changed, 91 insertions, 36 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index b9e9896a8..87cbdaf0c 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -249,23 +249,29 @@ def case(whens, value=None, else_=None): for c in cc.clauses: c.parens = False return cc - + def cast(clause, totype, **kwargs): - """return CAST function CAST(clause AS totype) - - Use with a sqlalchemy.types.TypeEngine object, i.e - ``cast(table.c.unit_price * table.c.qty, Numeric(10,4))`` - or ``cast(table.c.timestamp, DATE)`` - + """Return ``CAST`` function. + + Equivalent of SQL ``CAST(clause AS totype)``. + + Use with a ``sqlalchemy.types.TypeEngine`` object, i.e:: + + cast(table.c.unit_price * table.c.qty, Numeric(10,4)) + + or:: + + cast(table.c.timestamp, DATE) """ + return _Cast(clause, totype, **kwargs) def extract(field, expr): - """return extract(field FROM expr)""" + """Return ``extract(field FROM expr)``.""" + expr = _BinaryClause(text(field), expr, "FROM") return func.extract(expr) - def exists(*args, **kwargs): return _Exists(*args, **kwargs) @@ -1030,12 +1036,12 @@ class ColumnElement(Selectable, _CompareMixin): with Selectable objects. """) - def _one_fkey(self): if len(self._foreign_keys): return list(self._foreign_keys)[0] else: return None + foreign_key = property(_one_fkey) def _get_orig_set(self): @@ -1049,6 +1055,7 @@ class ColumnElement(Selectable, _CompareMixin): if len(s) == 0: s.add(self) self.__orig_set = s + orig_set = property(_get_orig_set, _set_orig_set, doc=\ """A Set containing TableClause-bound, non-proxied ColumnElements @@ -1058,7 +1065,9 @@ class ColumnElement(Selectable, _CompareMixin): """) def shares_lineage(self, othercolumn): - """Return True if the given ``ColumnElement`` has a common ancestor to this ``ColumnElement``.""" + """Return True if the given ``ColumnElement`` has a common + ancestor to this ``ColumnElement``. + """ for c in self.orig_set: if c in othercolumn.orig_set: @@ -1398,15 +1407,18 @@ class _TextClause(ClauseElement): if typemap is not None: for key in typemap.keys(): typemap[key] = sqltypes.to_instance(typemap[key]) + def repl(m): self.bindparams[m.group(1)] = bindparam(m.group(1)) return ":%s" % m.group(1) - # scan the string and search for bind parameter names, add them + + # scan the string and search for bind parameter names, add them # to the list of bindparams self.text = re.compile(r'(?<!:):([\w_]+)', re.S).sub(repl, text) if bindparams is not None: for b in bindparams: self.bindparams[b.key] = b + columns = property(lambda s:[]) def get_children(self, **kwargs): @@ -1417,49 +1429,67 @@ class _TextClause(ClauseElement): def _get_from_objects(self): return [] + def supports_execution(self): return True - + class _Null(ColumnElement): - """represents the NULL keyword in a SQL statement. public contstructor is the - null() function.""" + """Represent the NULL keyword in a SQL statement. + + Public constructor is the ``null()`` function. + """ + def __init__(self): self.type = sqltypes.NULLTYPE + def accept_visitor(self, visitor): visitor.visit_null(self) + def _get_from_objects(self): return [] class ClauseList(ClauseElement): - """describes a list of clauses. by default, is comma-separated, - such as a column listing.""" + """Describe a list of clauses. + + By default, is comma-separated, such as a column listing. + """ + def __init__(self, *clauses, **kwargs): self.clauses = [] for c in clauses: if c is None: continue self.append(c) self.parens = kwargs.get('parens', False) + def __iter__(self): return iter(self.clauses) + def copy_container(self): clauses = [clause.copy_container() for clause in self.clauses] return ClauseList(parens=self.parens, *clauses) + def append(self, clause): if _is_literal(clause): clause = _TextClause(str(clause)) self.clauses.append(clause) + def get_children(self, **kwargs): return self.clauses + def accept_visitor(self, visitor): visitor.visit_clauselist(self) + def _get_from_objects(self): f = [] for c in self.clauses: f += c._get_from_objects() return f + def compare(self, other): - """compares this ClauseList to the given ClauseList, including - a comparison of all the clause items.""" + """Compare this ``ClauseList`` to the given ``ClauseList``, + including a comparison of all the clause items. + """ + if isinstance(other, ClauseList) and len(self.clauses) == len(other.clauses): for i in range(0, len(self.clauses)): if not self.clauses[i].compare(other.clauses[i]): @@ -1642,36 +1672,49 @@ class _BinaryClause(ClauseElement): self.left.parens = True if isinstance(self.right, _BinaryClause) or hasattr(self.right, '_selectable'): self.right.parens = True + def copy_container(self): return self.__class__(self.left.copy_container(), self.right.copy_container(), self.operator) + def _get_from_objects(self): return self.left._get_from_objects() + self.right._get_from_objects() + def get_children(self, **kwargs): return self.left, self.right + def accept_visitor(self, visitor): visitor.visit_binary(self) + def swap(self): c = self.left self.left = self.right self.right = c + def compare(self, other): - """compares this _BinaryClause against the given _BinaryClause.""" + """Compare this ``_BinaryClause`` against the given ``_BinaryClause``.""" + return ( - isinstance(other, _BinaryClause) and self.operator == other.operator and + isinstance(other, _BinaryClause) and self.operator == other.operator and self.left.compare(other.left) and self.right.compare(other.right) ) class _BinaryExpression(_BinaryClause, ColumnElement): - """represents a binary expression, which can be in a WHERE criterion or in the column list - of a SELECT. By adding "ColumnElement" to its inherited list, it becomes a Selectable - unit which can be placed in the column list of a SELECT.""" + """Represent a binary expression, which can be in a ``WHERE`` + criterion or in the column list of a ``SELECT``. + + By adding ``ColumnElement`` to its inherited list, it becomes a + ``Selectable`` unit which can be placed in the column list of a + ``SELECT``.""" + pass class _BooleanExpression(_BinaryExpression): - """represents a boolean expression.""" + """Represent a boolean expression.""" + def __init__(self, *args, **kwargs): self.negate = kwargs.pop('negate', None) super(_BooleanExpression, self).__init__(*args, **kwargs) + def _negate(self): if self.negate is not None: return _BooleanExpression(self.left, self.right, self.negate, negate=self.operator, type=self.type) @@ -1740,9 +1783,14 @@ class Join(FromClause): constraints.add(fk.constraint) self.foreignkey = fk.parent if len(crit) == 0: - raise exceptions.ArgumentError("Cant find any foreign key relationships between '%s' and '%s'" % (primary.name, secondary.name)) + raise exceptions.ArgumentError("Can't find any foreign key relationships " + "between '%s' and '%s'" % (primary.name, secondary.name)) elif len(constraints) > 1: - raise exceptions.ArgumentError("Cant determine join between '%s' and '%s'; tables have more than one foreign key constraint relationship between them. Please specify the 'onclause' of this join explicitly." % (primary.name, secondary.name)) + raise exceptions.ArgumentError("Can't determine join between '%s' and '%s'; " + "tables have more than one foreign key " + "constraint relationship between them. " + "Please specify the 'onclause' of this " + "join explicitly." % (primary.name, secondary.name)) elif len(crit) == 1: return (crit[0]) else: @@ -1808,6 +1856,7 @@ class Join(FromClause): def get_children(self, **kwargs): return self.left, self.right, self.onclause + def accept_visitor(self, visitor): visitor.visit_join(self) @@ -1894,6 +1943,7 @@ class _Label(ColumnElement): def get_children(self, **kwargs): return self.obj, + def accept_visitor(self, visitor): visitor.visit_label(self) @@ -1925,11 +1975,12 @@ class _ColumnClause(ColumnElement): self.is_literal = is_literal def _get_label(self): - """generate a 'label' for this column. + """Generate a 'label' for this column. - the label is a product of the parent table name and column name, and - is treated as a unique identifier of this Column across all Tables and derived - selectables for a particular metadata collection. + The label is a product of the parent table name and column + name, and is treated as a unique identifier of this ``Column`` + across all ``Tables`` and derived selectables for a particular + metadata collection. """ # for a "literal" column, we've no idea what the text is @@ -2282,26 +2333,28 @@ class Select(_SelectBaseMixin, FromClause): if self.is_scalar and not hasattr(self, 'type'): self.type = column.type - # if the column is a Select statement itself, + # if the column is a Select statement itself, # accept visitor self.__correlator.traverse(column) - + # visit the FROM objects of the column looking for more Selects for f in column._get_from_objects(): if f is not self: self.__correlator.traverse(f) self._process_froms(column, False) + def _make_proxy(self, selectable, name): if self.is_scalar: return self._raw_columns[0]._make_proxy(selectable, name) else: raise exceptions.InvalidRequestError("Not a scalar select statement") + def label(self, name): if not self.is_scalar: raise exceptions.InvalidRequestError("Not a scalar select statement") else: return label(name, self) - + def _exportable_columns(self): return [c for c in self._raw_columns if isinstance(c, Selectable)] @@ -2310,7 +2363,7 @@ class Select(_SelectBaseMixin, FromClause): return column._make_proxy(self, name=column._label) else: return column._make_proxy(self) - + def _process_froms(self, elem, asfrom): for f in elem._get_from_objects(): self.__froms.add(f) @@ -2369,7 +2422,9 @@ class Select(_SelectBaseMixin, FromClause): else: return f - froms = property(_calc_froms, doc="""A collection containing all elements of the FROM clause""") + froms = property(_calc_froms, + doc="""A collection containing all elements + of the ``FROM`` clause.""") def get_children(self, column_collections=True, **kwargs): return (column_collections and list(self.columns) or []) + \ |