diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-02-25 07:12:50 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-02-25 07:12:50 +0000 |
commit | 72dd2b08beb9803269983aa220e75b44007e5158 (patch) | |
tree | 16f80b5f869ba68ae17e2fcbe9b18f1542b22e84 /lib/sqlalchemy/sql.py | |
parent | 5b81c1a2d0915d95d9928ffaaf81af814cf4ec3e (diff) | |
download | sqlalchemy-72dd2b08beb9803269983aa220e75b44007e5158.tar.gz |
merged sql_rearrangement branch , refactors sql package to work standalone with
clause elements including tables and columns, schema package deals with "physical"
representations
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r-- | lib/sqlalchemy/sql.py | 316 |
1 files changed, 95 insertions, 221 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index cbd9a82f3..8ebf7624e 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -13,7 +13,7 @@ from exceptions import * import string, re, random types = __import__('types') -__all__ = ['text', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'union', 'union_all', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists'] +__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'union', 'union_all', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists'] def desc(column): """returns a descending ORDER BY clause element, e.g.: @@ -160,11 +160,15 @@ def label(name, obj): """returns a Label object for the given selectable, used in the column list for a select statement.""" return Label(name, obj) -def column(table, text): - """returns a textual column clause, relative to a table. this differs from using straight text - or text() in that the column is treated like a regular column, i.e. gets added to a Selectable's list - of columns.""" - return ColumnClause(text, table) +def column(text, table=None, type=None): + """returns a textual column clause, relative to a table. this is also the primitive version of + a schema.Column which is a subclass. """ + return ColumnClause(text, table, type) + +def table(name, *columns): + """returns a table clause. this is a primitive version of the schema.Table object, which is a subclass + of this object.""" + return TableClause(name, *columns) def bindparam(key, value = None, type=None): """creates a bind parameter clause with the given key. @@ -172,7 +176,7 @@ def bindparam(key, value = None, type=None): An optional default value can be specified by the value parameter, and the optional type parameter is a sqlalchemy.types.TypeEngine object which indicates bind-parameter and result-set translation for this bind parameter.""" - if isinstance(key, schema.Column): + if isinstance(key, ColumnClause): return BindParamClause(key.name, value, type=key.type) else: return BindParamClause(key, value, type=type) @@ -190,7 +194,7 @@ def text(text, engine=None, *args, **kwargs): text - the text of the SQL statement to be created. use :<param> to specify bind parameters; they will be compiled to their engine-specific format. - engine - the engine to be used for this text query. Alternatively, call the + engine - an optional engine to be used for this text query. Alternatively, call the text() method off the engine directly. bindparams - a list of bindparam() instances which can be used to define the @@ -222,15 +226,15 @@ def _compound_select(keyword, *selects, **kwargs): return CompoundSelect(keyword, *selects, **kwargs) def _is_literal(element): - return not isinstance(element, ClauseElement) and not isinstance(element, schema.SchemaItem) + return not isinstance(element, ClauseElement) def is_column(col): - return isinstance(col, schema.Column) or isinstance(col, ColumnElement) + return isinstance(col, ColumnElement) -class ClauseVisitor(schema.SchemaVisitor): - """builds upon SchemaVisitor to define the visiting of SQL statement elements in - addition to Schema elements.""" - def visit_columnclause(self, column):pass +class ClauseVisitor(object): + """Defines the visiting of ClauseElements.""" + def visit_column(self, column):pass + def visit_table(self, column):pass def visit_fromclause(self, fromclause):pass def visit_bindparam(self, bindparam):pass def visit_textclause(self, textclause):pass @@ -309,18 +313,6 @@ class Compiled(ClauseVisitor): class ClauseElement(object): """base class for elements of a programmatically constructed SQL expression.""" - def hash_key(self): - """returns a string that uniquely identifies the concept this ClauseElement - represents. - - two ClauseElements can have the same value for hash_key() iff they both correspond to - the exact same generated SQL. This allows the hash_key() values of a collection of - ClauseElements to be constructed into a larger identifying string for the purpose of - caching a SQL expression. - - Note that since ClauseElements may be mutable, the hash_key() value is subject to - change if the underlying structure of the ClauseElement changes.""" - raise NotImplementedError(repr(self)) def _get_from_objects(self): """returns objects represented in this ClauseElement that should be added to the FROM list of a query.""" @@ -357,19 +349,24 @@ class ClauseElement(object): return False def _find_engine(self): + """default strategy for locating an engine within the clause element. + relies upon a local engine property, or looks in the "from" objects which + ultimately have to contain Tables or TableClauses. """ try: if self._engine is not None: return self._engine except AttributeError: pass for f in self._get_from_objects(): + if f is self: + continue engine = f.engine if engine is not None: return engine else: return None - engine = property(lambda s: s._find_engine()) + engine = property(lambda s: s._find_engine(), doc="attempts to locate a SQLEngine within this ClauseElement structure, or returns None if none found.") def compile(self, engine = None, parameters = None, typemap=None): """compiles this SQL expression using its underlying SQLEngine to produce @@ -380,16 +377,13 @@ class ClauseElement(object): engine = self.engine if engine is None: - raise InvalidRequestError("no SQLEngine could be located within this ClauseElement.") + import sqlalchemy.ansisql as ansisql + engine = ansisql.engine() return engine.compile(self, parameters=parameters, typemap=typemap) def __str__(self): - e = self.engine - if e is None: - import sqlalchemy.ansisql as ansisql - e = ansisql.engine() - return str(self.compile(e)) + return str(self.compile()) def execute(self, *multiparams, **params): """compiles and executes this SQL expression using its underlying SQLEngine. the @@ -425,6 +419,7 @@ class ClauseElement(object): return not_(self) class CompareMixin(object): + """defines comparison operations for ClauseElements.""" def __lt__(self, other): return self._compare('<', other) def __le__(self, other): @@ -500,19 +495,15 @@ class Selectable(ClauseElement): def accept_visitor(self, visitor): raise NotImplementedError(repr(self)) - def is_selectable(self): return True - def select(self, whereclauses = None, **params): return select([self], whereclauses, **params) - def _group_parenthesized(self): """indicates if this Selectable requires parenthesis when grouped into a compound statement""" return True - class ColumnElement(Selectable, CompareMixin): """represents a column element within the list of a Selectable's columns. Provides default implementations for the things a "column" needs, including a "primary_key" flag, @@ -552,8 +543,6 @@ class FromClause(Selectable): return [self.oid_column] else: return self.primary_key - def hash_key(self): - return "FromClause(%s, %s)" % (repr(self.id), repr(self.from_name)) def accept_visitor(self, visitor): visitor.visit_fromclause(self) def count(self, whereclause=None, **params): @@ -627,8 +616,6 @@ class BindParamClause(ClauseElement, CompareMixin): visitor.visit_bindparam(self) def _get_from_objects(self): return [] - def hash_key(self): - return "BindParam(%s, %s, %s)" % (repr(self.key), repr(self.value), repr(self.shortname)) def typeprocess(self, value, engine): return self._get_convert_type(engine).convert_bind_param(value, engine) def compare(self, other): @@ -674,8 +661,6 @@ class TextClause(ClauseElement): for item in self.bindparams.values(): item.accept_visitor(visitor) visitor.visit_textclause(self) - def hash_key(self): - return "TextClause(%s)" % repr(self.text) def _get_from_objects(self): return [] @@ -686,8 +671,6 @@ class Null(ClauseElement): visitor.visit_null(self) def _get_from_objects(self): return [] - def hash_key(self): - return "Null" class ClauseList(ClauseElement): """describes a list of clauses. by default, is comma-separated, @@ -698,8 +681,6 @@ class ClauseList(ClauseElement): if c is None: continue self.append(c) self.parens = kwargs.get('parens', False) - def hash_key(self): - return string.join([c.hash_key() for c in self.clauses], ",") def copy_container(self): clauses = [clause.copy_container() for clause in self.clauses] return ClauseList(parens=self.parens, *clauses) @@ -753,8 +734,6 @@ class CompoundClause(ClauseList): for c in self.clauses: f += c._get_from_objects() return f - def hash_key(self): - return string.join([c.hash_key() for c in self.clauses], self.operator or " ") def compare(self, other): """compares this CompoundClause to the given item. @@ -794,8 +773,6 @@ class Function(ClauseList, ColumnElement): return BindParamClause(self.name, obj, shortname=self.name, type=self.type) def select(self): return select([self]) - def hash_key(self): - return self.name + "(" + string.join([c.hash_key() for c in self.clauses], ", ") + ")" def _compare_type(self, obj): return self.type @@ -811,8 +788,6 @@ class BinaryClause(ClauseElement): return BinaryClause(self.left.copy_container(), self.right.copy_container(), self.operator) def _get_from_objects(self): return self.left._get_from_objects() + self.right._get_from_objects() - def hash_key(self): - return self.left.hash_key() + (self.operator or " ") + self.right.hash_key() def accept_visitor(self, visitor): self.left.accept_visitor(visitor) self.right.accept_visitor(visitor) @@ -879,16 +854,9 @@ class Join(FromClause): return and_(*crit) def _group_parenthesized(self): - """indicates if this Selectable requires parenthesis when grouped into a compound - statement""" return True - - def hash_key(self): - return "Join(%s, %s, %s, %s)" % (repr(self.left.hash_key()), repr(self.right.hash_key()), repr(self.onclause.hash_key()), repr(self.isouter)) - def select(self, whereclauses = None, **params): return select([self.left, self.right], whereclauses, from_obj=[self], **params) - def accept_visitor(self, visitor): self.left.accept_visitor(visitor) self.right.accept_visitor(visitor) @@ -941,9 +909,6 @@ class Alias(FromClause): def _exportable_columns(self): return self.selectable.columns - def hash_key(self): - return "Alias(%s, %s)" % (self.selectable.hash_key(), repr(self.name)) - def accept_visitor(self, visitor): self.selectable.accept_visitor(visitor) visitor.visit_alias(self) @@ -975,35 +940,27 @@ class Label(ColumnElement): return self.obj._get_from_objects() def _make_proxy(self, selectable, name = None): return self.obj._make_proxy(selectable, name=self.name) - - def hash_key(self): - return "Label(%s, %s)" % (self.name, self.obj.hash_key()) class ColumnClause(ColumnElement): - """represents a textual column clause in a SQL statement. allows the creation - of an additional ad-hoc column that is compiled against a particular table.""" - - def __init__(self, text, selectable=None): - self.text = text + """represents a textual column clause in a SQL statement. May or may not + be bound to an underlying Selectable.""" + def __init__(self, text, selectable=None, type=None): + self.key = self.name = self.text = text self.table = selectable - self.type = sqltypes.NullTypeEngine() - - name = property(lambda self:self.text) - key = property(lambda self:self.text) - _label = property(lambda self:self.text) - - def accept_visitor(self, visitor): - visitor.visit_columnclause(self) - - def hash_key(self): + self.type = type or sqltypes.NullTypeEngine() + def _get_label(self): if self.table is not None: - return "ColumnClause(%s, %s)" % (self.text, util.hash_key(self.table)) + return self.table.name + "_" + self.text else: - return "ColumnClause(%s)" % self.text - + return self.text + _label = property(_get_label) + def accept_visitor(self, visitor): + visitor.visit_column(self) def _get_from_objects(self): - return [] - + if self.table is not None: + return [self.table] + else: + return [] def _bind_param(self, obj): if self.table.name is None: return BindParamClause(self.text, obj, shortname=self.text, type=self.type) @@ -1013,79 +970,35 @@ class ColumnClause(ColumnElement): c = ColumnClause(name or self.text, selectable) selectable.columns[c.key] = c return c - -class ColumnImpl(ColumnElement): - """gets attached to a schema.Column object.""" - - def __init__(self, column): - self.column = column - self.name = column.name - - if column.table.name: - self._label = column.table.name + "_" + self.column.name - else: - self._label = self.column.name - - engine = property(lambda s: s.column.engine) - default_label = property(lambda s:s._label) - original = property(lambda self:self.column.original) - parent = property(lambda self:self.column.parent) - columns = property(lambda self:[self.column]) - - def label(self, name): - return Label(name, self.column) - - def copy_container(self): - return self.column - - def compare(self, other): - """compares this ColumnImpl's column to the other given Column""" - return self.column is other - + def _compare_type(self, obj): + return self.type def _group_parenthesized(self): return False - - def _get_from_objects(self): - return [self.column.table] - - def _bind_param(self, obj): - if self.column.table.name is None: - return BindParamClause(self.name, obj, shortname = self.name, type = self.column.type) - else: - return BindParamClause(self.column.table.name + "_" + self.name, obj, shortname = self.name, type = self.column.type) - def _compare_self(self): - """allows ColumnImpl to return its Column object for usage in ClauseElements, all others to - just return self""" - return self.column - def _compare_type(self, obj): - return self.column.type - - def compile(self, engine = None, parameters = None, typemap=None): - if engine is None: - engine = self.engine - if engine is None: - raise InvalidRequestError("no SQLEngine could be located within this ClauseElement.") - return engine.compile(self.column, parameters=parameters, typemap=typemap) -class TableImpl(FromClause): - """attached to a schema.Table to provide it with a Selectable interface - as well as other functions - """ - - def __init__(self, table): - self.table = table - self.id = self.table.name +class TableClause(FromClause): + def __init__(self, name, *columns): + super(TableClause, self).__init__(name) + self.name = self.id = self.fullname = name + self._columns = util.OrderedProperties() + self._foreign_keys = [] + self._primary_key = [] + for c in columns: + self.append_column(c) + def append_column(self, c): + self._columns[c.text] = c + c.table = self def _oid_col(self): + if self.engine is None: + return None # OID remains a little hackish so far if not hasattr(self, '_oid_column'): - if self.table.engine.oid_column_name() is not None: - self._oid_column = schema.Column(self.table.engine.oid_column_name(), sqltypes.Integer, hidden=True) - self._oid_column._set_parent(self.table) + if self.engine.oid_column_name() is not None: + self._oid_column = schema.Column(self.engine.oid_column_name(), sqltypes.Integer, hidden=True) + self._oid_column._set_parent(self) else: self._oid_column = None return self._oid_column - def _orig_columns(self): try: return self._orig_cols @@ -1097,47 +1010,52 @@ class TableImpl(FromClause): if oid is not None: self._orig_cols[oid.original] = oid return self._orig_cols - - oid_column = property(_oid_col) - engine = property(lambda s: s.table.engine) - columns = property(lambda self: self.table.columns) - primary_key = property(lambda self:self.table.primary_key) - foreign_keys = property(lambda self:self.table.foreign_keys) + columns = property(lambda s:s._columns) + c = property(lambda s:s._columns) + primary_key = property(lambda s:s._primary_key) + foreign_keys = property(lambda s:s._foreign_keys) original_columns = property(_orig_columns) + oid_column = property(_oid_col) + + def _clear(self): + """clears all attributes on this TableClause so that new items can be added again""" + self.columns.clear() + self.foreign_keys[:] = [] + self.primary_key[:] = [] + try: + delattr(self, '_orig_cols') + except AttributeError: + pass + def accept_visitor(self, visitor): + visitor.visit_table(self) def _exportable_columns(self): raise NotImplementedError() - def _group_parenthesized(self): return False - def _process_from_dict(self, data, asfrom): for f in self._get_from_objects(): data.setdefault(f.id, f) if asfrom: - data[self.id] = self.table + data[self.id] = self def count(self, whereclause=None, **params): - return select([func.count(1).label('count')], whereclause, from_obj=[self.table], **params) + return select([func.count(1).label('count')], whereclause, from_obj=[self], **params) def join(self, right, *args, **kwargs): - return Join(self.table, right, *args, **kwargs) + return Join(self, right, *args, **kwargs) def outerjoin(self, right, *args, **kwargs): - return Join(self.table, right, isouter = True, *args, **kwargs) + return Join(self, right, isouter = True, *args, **kwargs) def alias(self, name=None): - return Alias(self.table, name) + return Alias(self, name) def select(self, whereclause = None, **params): - return select([self.table], whereclause, **params) + return select([self], whereclause, **params) def insert(self, values = None): - return insert(self.table, values=values) + return insert(self, values=values) def update(self, whereclause = None, values = None): - return update(self.table, whereclause, values) + return update(self, whereclause, values) def delete(self, whereclause = None): - return delete(self.table, whereclause) - def create(self, **params): - self.table.engine.create(self.table) - def drop(self, **params): - self.table.engine.drop(self.table) + return delete(self, whereclause) def _get_from_objects(self): - return [self.table] + return [self] class SelectBaseMixin(object): """base class for Select and CompoundSelects""" @@ -1191,11 +1109,6 @@ class CompoundSelect(SelectBaseMixin, FromClause): order_by = kwargs.get('order_by', None) if order_by: self.order_by(*order_by) - def hash_key(self): - return "CompoundSelect(%s)" % string.join( - [util.hash_key(s) for s in self.selects] + - ["%s=%s" % (k, repr(getattr(self, k))) for k in ['use_labels', 'keyword']], - ",") def _exportable_columns(self): return self.selects[0].columns def _proxy_column(self, column): @@ -1271,6 +1184,8 @@ class Select(SelectBaseMixin, FromClause): self.is_where = is_where def visit_compound_select(self, cs): self.visit_select(cs) + def visit_column(self, c):pass + def visit_table(self, c):pass def visit_select(self, select): if select is self.select: return @@ -1288,7 +1203,6 @@ class Select(SelectBaseMixin, FromClause): for f in column._get_from_objects(): f.accept_visitor(self._correlator) column._process_from_dict(self._froms, False) - def _exportable_columns(self): return self._raw_columns def _proxy_column(self, column): @@ -1313,24 +1227,6 @@ class Select(SelectBaseMixin, FromClause): _hash_recursion = util.RecursionStack() - def hash_key(self): - # selects call alot of stuff so we do some "recursion checking" - # to eliminate loops - if Select._hash_recursion.push(self): - return "recursive_select()" - try: - return "Select(%s)" % string.join( - [ - "columns=" + string.join([util.hash_key(c) for c in self._raw_columns],','), - "where=" + util.hash_key(self.whereclause), - "from=" + string.join([util.hash_key(f) for f in self.froms],','), - "having=" + util.hash_key(self.having), - "clauses=" + string.join([util.hash_key(c) for c in self.clauses], ',') - ] + ["%s=%s" % (k, repr(getattr(self, k))) for k in ['use_labels', 'distinct', 'limit', 'offset']], "," - ) - finally: - Select._hash_recursion.pop(self) - def clear_from(self, id): self.append_from(FromClause(from_name = None, from_key = id)) @@ -1342,7 +1238,7 @@ class Select(SelectBaseMixin, FromClause): fromclause._process_from_dict(self._froms, True) def _get_froms(self): - return [f for f in self._froms.values() if self._correlated is None or not self._correlated.has_key(f.id)] + return [f for f in self._froms.values() if f is not self and (self._correlated is None or not self._correlated.has_key(f.id))] froms = property(lambda s: s._get_froms()) def accept_visitor(self, visitor): @@ -1388,9 +1284,6 @@ class Select(SelectBaseMixin, FromClause): class UpdateBase(ClauseElement): """forms the base for INSERT, UPDATE, and DELETE statements.""" - def hash_key(self): - return str(id(self)) - def _process_colparams(self, parameters): """receives the "values" of an INSERT or UPDATE statement and constructs appropriate ind parameters.""" @@ -1419,6 +1312,9 @@ class UpdateBase(ClauseElement): except KeyError: del parameters[key] return parameters + + def _find_engine(self): + return self._engine class Insert(UpdateBase): @@ -1457,25 +1353,3 @@ class Delete(UpdateBase): self.whereclause.accept_visitor(visitor) visitor.visit_delete(self) -class IndexImpl(ClauseElement): - - def __init__(self, index): - self.index = index - self.name = index.name - self._engine = self.index.table.engine - - table = property(lambda s: s.index.table) - columns = property(lambda s: s.index.columns) - - def hash_key(self): - return self.index.hash_key() - def accept_visitor(self, visitor): - visitor.visit_index(self.index) - def compare(self, other): - return self.index is other - def create(self): - self._engine.create(self.index) - def drop(self): - self._engine.drop(self.index) - def execute(self): - self.create() |