diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-05-25 14:20:23 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-05-25 14:20:23 +0000 |
commit | bb79e2e871d0a4585164c1a6ed626d96d0231975 (patch) | |
tree | 6d457ba6c36c408b45db24ec3c29e147fe7504ff /lib/sqlalchemy/sql.py | |
parent | 4fc3a0648699c2b441251ba4e1d37a9107bd1986 (diff) | |
download | sqlalchemy-bb79e2e871d0a4585164c1a6ed626d96d0231975.tar.gz |
merged 0.2 branch into trunk; 0.1 now in sqlalchemy/branches/rel_0_1
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r-- | lib/sqlalchemy/sql.py | 366 |
1 files changed, 212 insertions, 154 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 38866184f..d1d1d837e 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -5,11 +5,9 @@ """defines the base components of SQL expression trees.""" -import schema -import util -import types as sqltypes -from exceptions import * -import string, re, random +from sqlalchemy import util, exceptions +from sqlalchemy import types as sqltypes +import string, re, random, sets types = __import__('types') __all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'case', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists'] @@ -220,8 +218,7 @@ def text(text, engine=None, *args, **kwargs): text - the text of the SQL statement to be created. use :<param> to specify bind parameters; they will be compiled to their engine-specific format. - engine - an optional engine to be used for this text query. Alternatively, call the - text() method off the engine directly. + engine - an optional engine to be used for this text query. bindparams - a list of bindparam() instances which can be used to define the types and/or initial values for the bind parameters within the textual statement; @@ -257,28 +254,33 @@ def _is_literal(element): def is_column(col): return isinstance(col, ColumnElement) -class AbstractEngine(object): - """represents a 'thing that can produce Compiler objects an execute them'.""" +class Engine(object): + """represents a 'thing that can produce Compiled objects and execute them'.""" def execute_compiled(self, compiled, parameters, echo=None, **kwargs): raise NotImplementedError() def compiler(self, statement, parameters, **kwargs): raise NotImplementedError() +class AbstractDialect(object): + """represents the behavior of a particular database. Used by Compiled objects.""" + pass + class ClauseParameters(util.OrderedDict): """represents a dictionary/iterator of bind parameter key names/values. Includes parameters compiled with a Compiled object as well as additional arguments passed to the Compiled object's get_params() method. Parameter values will be converted as per the TypeEngine objects present in the bind parameter objects. The non-converted value can be retrieved via the get_original method. For Compiled objects that compile positional parameters, the values() iteration of the object will return the parameter values in the correct order.""" - def __init__(self, engine=None): + def __init__(self, dialect): super(ClauseParameters, self).__init__(self) - self.engine = engine + self.dialect=dialect self.binds = {} def set_parameter(self, key, value, bindparam): self[key] = value self.binds[key] = bindparam def get_original(self, key): + """returns the given parameter as it was originally placed in this ClauseParameters object, without any Type conversion""" return super(ClauseParameters, self).__getitem__(key) def __getitem__(self, key): v = super(ClauseParameters, self).__getitem__(key) - if self.engine is not None and self.binds.has_key(key): - v = self.binds[key].typeprocess(v, self.engine) + if self.binds.has_key(key): + v = self.binds[key].typeprocess(v, self.dialect) return v def values(self): return [self[key] for key in self] @@ -318,7 +320,7 @@ class Compiled(ClauseVisitor): object be dependent on the actual values of those bind parameters, even though it may reference those values as defaults.""" - def __init__(self, statement, parameters, engine=None): + def __init__(self, dialect, statement, parameters, engine=None): """constructs a new Compiled object. statement - ClauseElement to be compiled @@ -332,11 +334,12 @@ class Compiled(ClauseVisitor): clauses of an UPDATE statement. The keys of the parameter dictionary can either be the string names of columns or ColumnClause objects. - engine - optional SQLEngine to compile this statement against""" - self.parameters = parameters + engine - optional Engine to compile this statement against""" + self.dialect = dialect self.statement = statement + self.parameters = parameters self.engine = engine - + def __str__(self): """returns the string text of the generated SQL statement.""" raise NotImplementedError() @@ -357,13 +360,10 @@ class Compiled(ClauseVisitor): def execute(self, *multiparams, **params): """executes this compiled object using the AbstractEngine it is bound to.""" - if len(multiparams): - params = multiparams - e = self.engine if e is None: - raise InvalidRequestError("This Compiled object is not bound to any engine.") - return e.execute_compiled(self, params) + raise exceptions.InvalidRequestError("This Compiled object is not bound to any engine.") + return e.execute_compiled(self, *multiparams, **params) def scalar(self, *multiparams, **params): """executes this compiled object via the execute() method, then @@ -373,30 +373,25 @@ class Compiled(ClauseVisitor): # in a result set is not performance-wise any different than specifying limit=1 # else we'd have to construct a copy of the select() object with the limit # installed (else if we change the existing select, not threadsafe) - row = self.execute(*multiparams, **params).fetchone() - if row is not None: - return row[0] - else: - return None + r = self.execute(*multiparams, **params) + row = r.fetchone() + try: + if row is not None: + return row[0] + else: + return None + finally: + r.close() class Executor(object): - """handles the compilation/execution of a ClauseElement within the context of a particular AbtractEngine. This - AbstractEngine will usually be a SQLEngine or ConnectionProxy.""" + """context-sensitive executor for the using() function.""" def __init__(self, clauseelement, abstractengine=None): self.engine=abstractengine self.clauseelement = clauseelement def execute(self, *multiparams, **params): - return self.compile(*multiparams, **params).execute(*multiparams, **params) + return self.clauseelement.execute_using(self.engine) def scalar(self, *multiparams, **params): - return self.compile(*multiparams, **params).scalar(*multiparams, **params) - def compile(self, *multiparams, **params): - if len(multiparams): - bindparams = multiparams[0] - else: - bindparams = params - compiler = self.engine.compiler(self.clauseelement, bindparams) - compiler.compile() - return compiler + return self.clauseelement.scalar_using(self.engine) class ClauseElement(object): """base class for elements of a programmatically constructed SQL expression.""" @@ -454,26 +449,52 @@ class ClauseElement(object): else: return None - engine = property(lambda s: s._find_engine(), doc="attempts to locate a SQLEngine within this ClauseElement structure, or returns None if none found.") + engine = property(lambda s: s._find_engine(), doc="attempts to locate a Engine within this ClauseElement structure, or returns None if none found.") def using(self, abstractengine): return Executor(self, abstractengine) + + def execute_using(self, engine, *multiparams, **params): + compile_params = self._conv_params(*multiparams, **params) + return self.compile(engine=engine, parameters=compile_params).execute(*multiparams, **params) + def scalar_using(self, engine, *multiparams, **params): + compile_params = self._conv_params(*multiparams, **params) + return self.compile(engine=engine, parameters=compile_params).scalar(*multiparams, **params) + def _conv_params(self, *multiparams, **params): + if len(multiparams): + return multiparams[0] + else: + return params + def compile(self, engine=None, parameters=None, compiler=None, dialect=None): + """compiles this SQL expression. + + Uses the given Compiler, or the given AbstractDialect or Engine to create a Compiler. If no compiler + arguments are given, tries to use the underlying Engine this ClauseElement is bound + to to create a Compiler, if any. Finally, if there is no bound Engine, uses an ANSIDialect + to create a default Compiler. - def compile(self, engine = None, parameters = None, typemap=None, compiler=None): - """compiles this SQL expression using its underlying SQLEngine to produce - a Compiled object. If no engine can be found, an ANSICompiler is used with no engine. bindparams is a dictionary representing the default bind parameters to be used with - the statement. """ + the statement. if the bindparams is a list, it is assumed to be a list of dictionaries + and the first dictionary in the list is used with which to compile against. + The bind parameters can in some cases determine the output of the compilation, such as for UPDATE + and INSERT statements the bind parameters that are present determine the SET and VALUES clause of + those statements. + """ + + if (isinstance(parameters, list) or isinstance(parameters, tuple)): + parameters = parameters[0] if compiler is None: - if engine is not None: + if dialect is not None: + compiler = dialect.compiler(self, parameters) + elif engine is not None: compiler = engine.compiler(self, parameters) elif self.engine is not None: compiler = self.engine.compiler(self, parameters) if compiler is None: import sqlalchemy.ansisql as ansisql - compiler = ansisql.ANSICompiler(self, parameters=parameters) + compiler = ansisql.ANSIDialect().compiler(self, parameters=parameters) compiler.compile() return compiler @@ -481,10 +502,10 @@ class ClauseElement(object): return str(self.compile()) def execute(self, *multiparams, **params): - return self.using(self.engine).execute(*multiparams, **params) + return self.execute_using(self.engine, *multiparams, **params) def scalar(self, *multiparams, **params): - return self.using(self.engine).scalar(*multiparams, **params) + return self.scalar_using(self.engine, *multiparams, **params) def __and__(self, other): return and_(self, other) @@ -543,7 +564,7 @@ class CompareMixin(object): def __div__(self, other): return self._operate('/', other) def __mod__(self, other): - return self._operate('%', other) + return self._operate('%', other) def __truediv__(self, other): return self._operate('/', other) def _bind_param(self, obj): @@ -554,11 +575,11 @@ class CompareMixin(object): return BooleanExpression(self._compare_self(), null(), 'IS') elif operator == '!=': return BooleanExpression(self._compare_self(), null(), 'IS NOT') - return BooleanExpression(self._compare_self(), null(), 'IS') else: raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") elif _is_literal(obj): obj = self._bind_param(obj) + return BooleanExpression(self._compare_self(), obj, operator, type=self._compare_type(obj)) def _operate(self, operator, obj): if _is_literal(obj): @@ -588,24 +609,43 @@ class Selectable(ClauseElement): return True class ColumnElement(Selectable, CompareMixin): - """represents a column element within the list of a Selectable's columns. Provides - default implementations for the things a "column" needs, including a "primary_key" flag, - a "foreign_key" accessor, an "original" accessor which represents the ultimate column - underlying a string of labeled/select-wrapped columns, and "columns" which returns a list - of the single column, providing the same list-based interface as a FromClause.""" - primary_key = property(lambda self:getattr(self, '_primary_key', False)) - foreign_key = property(lambda self:getattr(self, '_foreign_key', False)) - original = property(lambda self:getattr(self, '_original', self)) - parent = property(lambda self:getattr(self, '_parent', self)) - columns = property(lambda self:[self]) + """represents a column element within the list of a Selectable's columns. + A ColumnElement can either be directly associated with a TableClause, or + a free-standing textual column with no table, or is a "proxy" column, indicating + it is placed on a Selectable such as an Alias or Select statement and ultimately corresponds + to a TableClause-attached column (or in the case of a CompositeSelect, a proxy ColumnElement + may correspond to several TableClause-attached columns).""" + + primary_key = property(lambda self:getattr(self, '_primary_key', False), doc="primary key flag. indicates if this Column represents part or whole of a primary key.") + foreign_key = property(lambda self:getattr(self, '_foreign_key', False), doc="foreign key accessor. points to a ForeignKey object which represents a Foreign Key placed on this column's ultimate ancestor.") + columns = property(lambda self:[self], doc="Columns accessor which just returns self, to provide compatibility with Selectable objects.") + + def _get_orig_set(self): + try: + return self.__orig_set + except AttributeError: + self.__orig_set = sets.Set([self]) + return self.__orig_set + def _set_orig_set(self, s): + if len(s) == 0: + s.add(self) + self.__orig_set = s + orig_set = property(_get_orig_set, _set_orig_set,doc="""a Set containing TableClause-bound, non-proxied ColumnElements for which this ColumnElement is a proxy. In all cases except for a column proxied from a Union (i.e. CompoundSelect), this set will be just one element.""") + + def shares_lineage(self, othercolumn): + """returns True if the given ColumnElement has a common ancestor to this ColumnElement.""" + for c in self.orig_set: + if c in othercolumn.orig_set: + return True + else: + return False def _make_proxy(self, selectable, name=None): """creates a new ColumnElement representing this ColumnElement as it appears in the select list - of an enclosing selectable. The default implementation returns a ColumnClause if a name is given, - else just returns self. This has various mechanics with schema.Column and sql.Label so that - Column objects as well as non-column objects like Function and BinaryClause can both appear in the - select list of an enclosing selectable.""" + of a descending selectable. The default implementation returns a ColumnClause if a name is given, + else just returns self.""" if name is not None: co = ColumnClause(name, selectable) + co.orig_set = self.orig_set selectable.columns[name]= co return co else: @@ -615,16 +655,17 @@ class FromClause(Selectable): """represents an element that can be used within the FROM clause of a SELECT statement.""" def __init__(self, from_name = None): self.from_name = self.name = from_name + def _display_name(self): + if self.named_with_column(): + return self.name + else: + return None + displayname = property(_display_name) def _get_from_objects(self): # this could also be [self], at the moment it doesnt matter to the Select object return [] def default_order_by(self): - if not self.engine.default_ordering: - return None - elif self.oid_column is not None: - return [self.oid_column] - else: - return self.primary_key + return [self.oid_column] def accept_visitor(self, visitor): visitor.visit_fromclause(self) def count(self, whereclause=None, **params): @@ -635,6 +676,9 @@ class FromClause(Selectable): return Join(self, right, isouter = True, *args, **kwargs) def alias(self, name=None): return Alias(self, name) + def named_with_column(self): + """True if the name of this FromClause may be prepended to a column in a generated SQL statement""" + return False def _locate_oid_column(self): """subclasses override this to return an appropriate OID column""" return None @@ -642,18 +686,24 @@ class FromClause(Selectable): if not hasattr(self, '_oid_column'): self._oid_column = self._locate_oid_column() return self._oid_column - def _get_col_by_original(self, column, raiseerr=True): - """given a column which is a schema.Column object attached to a schema.Table object - (i.e. an "original" column), return the Column object from this - Selectable which corresponds to that original Column, or None if this Selectable - does not contain the column.""" - try: - return self.original_columns[column.original] - except KeyError: + def corresponding_column(self, column, raiseerr=True, keys_ok=False): + """given a ColumnElement, return the ColumnElement object from this + Selectable which corresponds to that original Column via a proxy relationship.""" + for c in column.orig_set: + try: + return self.original_columns[c] + except KeyError: + pass + else: + if keys_ok: + try: + return self.c[column.key] + except KeyError: + pass if not raiseerr: return None else: - raise InvalidRequestError("cant get orig for " + str(column) + " with table " + column.table.name + " from table " + self.name) + raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(column.table), self.name)) def _get_exported_attribute(self, name): try: @@ -665,10 +715,12 @@ class FromClause(Selectable): c = property(lambda s:s._get_exported_attribute('_columns')) primary_key = property(lambda s:s._get_exported_attribute('_primary_key')) foreign_keys = property(lambda s:s._get_exported_attribute('_foreign_keys')) - original_columns = property(lambda s:s._get_exported_attribute('_orig_cols')) + original_columns = property(lambda s:s._get_exported_attribute('_orig_cols'), doc="a dictionary mapping an original Table-bound column to a proxied column in this FromClause.") oid_column = property(_get_oid_column) def _export_columns(self): + """this method is called the first time any of the "exported attrbutes" are called. it receives from the Selectable + a list of all columns to be exported and creates "proxy" columns for each one.""" if hasattr(self, '_columns'): # TODO: put a mutex here ? this is a key place for threading probs return @@ -681,9 +733,11 @@ class FromClause(Selectable): if column.is_selectable(): for co in column.columns: cp = self._proxy_column(co) - self._orig_cols[co.original] = cp - if getattr(self, 'oid_column', None): - self._orig_cols[self.oid_column.original] = self.oid_column + for ci in cp.orig_set: + self._orig_cols[ci] = cp + if self.oid_column is not None: + for ci in self.oid_column.orig_set: + self._orig_cols[ci] = self.oid_column def _exportable_columns(self): return [] def _proxy_column(self, column): @@ -702,8 +756,8 @@ class BindParamClause(ClauseElement, CompareMixin): return [] def copy_container(self): return BindParamClause(self.key, self.value, self.shortname, self.type) - def typeprocess(self, value, engine): - return self.type.engine_impl(engine).convert_bind_param(value, engine) + def typeprocess(self, value, dialect): + return self.type.dialect_impl(dialect).convert_bind_param(value, dialect) def compare(self, other): """compares this BindParamClause to the given clause. @@ -720,8 +774,9 @@ class TypeClause(ClauseElement): self.type = type def accept_visitor(self, visitor): visitor.visit_typeclause(self) - def _get_from_objects(self): - return [] + def _get_from_objects(self): + return [] + class TextClause(ClauseElement): """represents literal a SQL text fragment. public constructor is the text() function. @@ -909,7 +964,8 @@ class FunctionGenerator(object): self.__names.append(name) return self def __call__(self, *c, **kwargs): - return Function(self.__names[-1], packagenames=self.__names[0:-1], engine=self.__engine, *c, **kwargs) + kwargs.setdefault('engine', self.__engine) + return Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **kwargs) class BinaryClause(ClauseElement): """represents two clauses with an operator in between""" @@ -956,15 +1012,13 @@ class Join(FromClause): def __init__(self, left, right, onclause=None, isouter = False): self.left = left self.right = right - # TODO: if no onclause, do NATURAL JOIN if onclause is None: self.onclause = self._match_primaries(left, right) else: self.onclause = onclause self.isouter = isouter - name = property(lambda self: "Join on %s, %s" % (self.left.name, self.right.name)) - + name = property(lambda s: "Join object on " + s.left.name + " " + s.right.name) def _locate_oid_column(self): return self.left.oid_column @@ -981,15 +1035,15 @@ class Join(FromClause): crit = [] for fk in secondary.foreign_keys: if fk.references(primary): - crit.append(primary._get_col_by_original(fk.column) == fk.parent) + crit.append(primary.corresponding_column(fk.column) == fk.parent) self.foreignkey = fk.parent if primary is not secondary: for fk in primary.foreign_keys: if fk.references(secondary): - crit.append(secondary._get_col_by_original(fk.column) == fk.parent) + crit.append(secondary.corresponding_column(fk.column) == fk.parent) self.foreignkey = fk.parent if len(crit) == 0: - raise ArgumentError("Cant find any foreign key relationships between '%s' and '%s'" % (primary.name, secondary.name)) + raise exceptions.ArgumentError("Cant find any foreign key relationships between '%s' and '%s'" % (primary.name, secondary.name)) elif len(crit) == 1: return (crit[0]) else: @@ -1037,12 +1091,13 @@ class Alias(FromClause): self.original = baseselectable self.selectable = selectable if alias is None: - n = getattr(self.original, 'name', None) - if n is None: - n = 'anon' - elif len(n) > 15: - n = n[0:15] - alias = n + "_" + hex(random.randint(0, 65535))[2:] + if self.original.named_with_column(): + alias = getattr(self.original, 'name', None) + if alias is None: + alias = 'anon' + elif len(alias) > 15: + alias = alias[0:15] + alias = alias + "_" + hex(random.randint(0, 65535))[2:] self.name = alias def _locate_oid_column(self): @@ -1050,8 +1105,11 @@ class Alias(FromClause): return self.selectable.oid_column._make_proxy(self) else: return None - + + def named_with_column(self): + return True def _exportable_columns(self): + #return self.selectable._exportable_columns() return self.selectable.columns def accept_visitor(self, visitor): @@ -1076,10 +1134,8 @@ class Label(ColumnElement): self.type = sqltypes.to_instance(type) obj.parens=True key = property(lambda s: s.name) - _label = property(lambda s: s.name) - original = property(lambda s:s.obj.original) - parent = property(lambda s:s.obj.parent) + orig_set = property(lambda s:s.obj.orig_set) def accept_visitor(self, visitor): self.obj.accept_visitor(visitor) visitor.visit_label(self) @@ -1091,19 +1147,20 @@ class Label(ColumnElement): class ColumnClause(ColumnElement): """represents a textual column clause in a SQL statement. May or may not be bound to an underlying Selectable.""" - def __init__(self, text, selectable=None, type=None): - self.key = self.name = self.text = text + def __init__(self, text, selectable=None, type=None, hidden=False): + self.key = self.name = text self.table = selectable self.type = sqltypes.to_instance(type) + self.hidden = hidden self.__label = None def _get_label(self): if self.__label is None: - if self.table is not None and self.table.name is not None: - self.__label = self.table.name + "_" + self.text + if self.table is not None and self.table.named_with_column(): + self.__label = self.table.name + "_" + self.name + if self.table.c.has_key(self.__label) or len(self.__label) >= 30: + self.__label = self.__label[0:24] + "_" + hex(random.randint(0, 65535))[2:] else: - self.__label = self.text - if (self.table is not None and self.table.c.has_key(self.__label)) or len(self.__label) >= 30: - self.__label = self.__label[0:24] + "_" + hex(random.randint(0, 65535))[2:] + self.__label = self.name return self.__label _label = property(_get_label) def accept_visitor(self, visitor): @@ -1113,21 +1170,19 @@ class ColumnClause(ColumnElement): for example, this could translate the column "name" from a Table object to an Alias of a Select off of that Table object.""" - return selectable._get_col_by_original(self.original, False) + return selectable.corresponding_column(self.original, False) def _get_from_objects(self): if self.table is not None: return [self.table] else: return [] def _bind_param(self, obj): - if self.table.name is None: - return BindParamClause(self.text, obj, shortname=self.text, type=self.type) - else: - return BindParamClause(self._label, obj, shortname = self.text, type=self.type) + return BindParamClause(self._label, obj, shortname = self.name, type=self.type) def _make_proxy(self, selectable, name = None): - c = ColumnClause(name or self.text, selectable) - c._original = self.original - selectable.columns[c.name] = c + c = ColumnClause(name or self.name, selectable, hidden=self.hidden) + c.orig_set = self.orig_set + if not self.hidden: + selectable.columns[c.name] = c return c def _compare_type(self, obj): return self.type @@ -1144,29 +1199,25 @@ class TableClause(FromClause): self._primary_key = [] for c in columns: self.append_column(c) + self._oid_column = ColumnClause('oid', self, hidden=True) indexes = property(lambda s:s._indexes) - + + def named_with_column(self): + return True def append_column(self, c): - self._columns[c.text] = c + self._columns[c.name] = c c.table = self def _locate_oid_column(self): - if self.engine is None: - return None - if self.engine.oid_column_name() is not None: - _oid_column = schema.Column(self.engine.oid_column_name(), sqltypes.Integer, hidden=True) - _oid_column._set_parent(self) - self._orig_columns()[_oid_column.original] = _oid_column - return _oid_column - else: - return None + return self._oid_column def _orig_columns(self): try: return self._orig_cols except AttributeError: self._orig_cols= {} for c in self.columns: - self._orig_cols[c.original] = c + for ci in c.orig_set: + self._orig_cols[ci] = c return self._orig_cols columns = property(lambda s:s._columns) c = property(lambda s:s._columns) @@ -1177,6 +1228,7 @@ class TableClause(FromClause): def _clear(self): """clears all attributes on this TableClause so that new items can be added again""" self.columns.clear() + self.indexes.clear() self.foreign_keys[:] = [] self.primary_key[:] = [] try: @@ -1240,6 +1292,7 @@ class SelectBaseMixin(object): class CompoundSelect(SelectBaseMixin, FromClause): def __init__(self, keyword, *selects, **kwargs): + SelectBaseMixin.__init__(self) self.keyword = keyword self.selects = selects self.use_labels = kwargs.pop('use_labels', False) @@ -1251,21 +1304,34 @@ class CompoundSelect(SelectBaseMixin, FromClause): s.order_by(None) self.group_by(*kwargs.get('group_by', [None])) self.order_by(*kwargs.get('order_by', [None])) + self._col_map = {} +# name = property(lambda s:s.keyword + " statement") + def _foo(self): + raise "this is a temporary assertion while we refactor SQL to not call 'name' on non-table Selectables" + name = property(lambda s:s._foo()) #"SELECT statement") + def _locate_oid_column(self): return self.selects[0].oid_column - def _exportable_columns(self): for s in self.selects: for c in s.c: yield c - def _proxy_column(self, column): if self.use_labels: - return column._make_proxy(self, name=column._label) + col = column._make_proxy(self, name=column._label) else: - return column._make_proxy(self, name=column.name) + col = column._make_proxy(self, name=column.name) + try: + colset = self._col_map[col.name] + except KeyError: + colset = sets.Set() + self._col_map[col.name] = colset + [colset.add(c) for c in col.orig_set] + col.orig_set = colset + return col + def accept_visitor(self, visitor): self.order_by_clause.accept_visitor(visitor) self.group_by_clause.accept_visitor(visitor) @@ -1284,9 +1350,9 @@ class Select(SelectBaseMixin, FromClause): """represents a SELECT statement, with appendable clauses, as well as the ability to execute itself and return a result set.""" def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, for_update=False, engine=None, limit=None, offset=None, scalar=False, correlate=True): + SelectBaseMixin.__init__(self) self._froms = util.OrderedDict() self.use_labels = use_labels - self.name = None self.whereclause = None self.having = None self._engine = engine @@ -1331,8 +1397,11 @@ class Select(SelectBaseMixin, FromClause): for f in from_obj: self.append_from(f) - - + + def _foo(self): + raise "this is a temporary assertion while we refactor SQL to not call 'name' on non-table Selectables" + name = property(lambda s:s._foo()) #"SELECT statement") + class CorrelatedVisitor(ClauseVisitor): """visits a clause, locates any Select clauses, and tells them that they should correlate their FROM list to that of their parent.""" @@ -1401,6 +1470,9 @@ class Select(SelectBaseMixin, FromClause): fromclause._process_from_dict(self._froms, True) def _locate_oid_column(self): for f in self._froms.values(): + if f is self: + # TODO: why would we be in our own _froms list ? + raise exceptions.AssertionError("Select statement should not be in its own _froms list") oid = f.oid_column if oid is not None: return oid @@ -1429,16 +1501,8 @@ class Select(SelectBaseMixin, FromClause): return union(self, other, **kwargs) def union_all(self, other, **kwargs): return union_all(self, other, **kwargs) - -# def scalar(self, *multiparams, **params): - # need to set limit=1, but only in this thread. - # we probably need to make a copy of the select(). this - # is expensive. I think cursor.fetchone(), then discard remaining results - # should be fine with most DBs - # for now use base scalar() method - def _find_engine(self): - """tries to return a SQLEngine, either explicitly set in this object, or searched + """tries to return a Engine, either explicitly set in this object, or searched within the from clauses for one""" if self._engine is not None: @@ -1454,7 +1518,6 @@ class Select(SelectBaseMixin, FromClause): class UpdateBase(ClauseElement): """forms the base for INSERT, UPDATE, and DELETE statements.""" - def _process_colparams(self, parameters): """receives the "values" of an INSERT or UPDATE statement and constructs appropriate ind parameters.""" @@ -1483,17 +1546,14 @@ class UpdateBase(ClauseElement): except KeyError: del parameters[key] return parameters - def _find_engine(self): - return self._engine - + return self.table.engine class Insert(UpdateBase): def __init__(self, table, values=None, **params): self.table = table self.select = None self.parameters = self._process_colparams(values) - self._engine = self.table.engine def accept_visitor(self, visitor): if self.select is not None: @@ -1506,7 +1566,6 @@ class Update(UpdateBase): self.table = table self.whereclause = whereclause self.parameters = self._process_colparams(values) - self._engine = self.table.engine def accept_visitor(self, visitor): if self.whereclause is not None: @@ -1517,7 +1576,6 @@ class Delete(UpdateBase): def __init__(self, table, whereclause, **params): self.table = table self.whereclause = whereclause - self._engine = self.table.engine def accept_visitor(self, visitor): if self.whereclause is not None: |