diff options
-rw-r--r-- | CHANGES | 14 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/maxdb.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/sybase.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/properties.py | 40 | ||||
-rw-r--r-- | lib/sqlalchemy/schema.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 75 | ||||
-rw-r--r-- | test/orm/query.py | 12 | ||||
-rw-r--r-- | test/sql/testtypes.py | 28 |
9 files changed, 91 insertions, 94 deletions
@@ -59,11 +59,25 @@ CHANGES - Adjustment to Session's post-flush accounting of newly "clean" objects to better protect against operating on objects as they're asynchronously gc'ed. [ticket:1182] + + - "not equals" comparisons of simple many-to-one relation + to an instance will not drop into an EXISTS clause + and will compare foreign key columns instead. + - removed not-really-working use cases of comparing + a collection to an iterable. Use contains() to test + for collection membership. + - sql - column.in_(someselect) can now be used as a columns-clause expression without the subquery bleeding into the FROM clause [ticket:1074] + + - Further simplified SELECT compilation and its relationship + to result row processing. + + - Direct execution of a union() construct will properly set up + result-row processing. [ticket:1194] - sqlite - Overhauled SQLite date/time bind/result processing to use diff --git a/lib/sqlalchemy/databases/maxdb.py b/lib/sqlalchemy/databases/maxdb.py index 34629b298..0e7310ab6 100644 --- a/lib/sqlalchemy/databases/maxdb.py +++ b/lib/sqlalchemy/databases/maxdb.py @@ -829,7 +829,7 @@ class MaxDBCompiler(compiler.DefaultCompiler): # No ORDER BY in subqueries. if order_by: - if self.is_subquery(select): + if self.is_subquery(): # It's safe to simply drop the ORDER BY if there is no # LIMIT. Right? Other dialects seem to get away with # dropping order. @@ -845,7 +845,7 @@ class MaxDBCompiler(compiler.DefaultCompiler): def get_select_precolumns(self, select): # Convert a subquery's LIMIT to TOP sql = select._distinct and 'DISTINCT ' or '' - if self.is_subquery(select) and select._limit: + if self.is_subquery() and select._limit: if select._offset: raise exc.InvalidRequestError( 'MaxDB does not support LIMIT with an offset.') @@ -855,7 +855,7 @@ class MaxDBCompiler(compiler.DefaultCompiler): def limit_clause(self, select): # The docs say offsets are supported with LIMIT. But they're not. # TODO: maybe emulate by adding a ROWNO/ROWNUM predicate? - if self.is_subquery(select): + if self.is_subquery(): # sub queries need TOP return '' elif select._offset: diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 4c5ad1fd1..42743870a 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -994,7 +994,7 @@ class MSSQLCompiler(compiler.DefaultCompiler): order_by = self.process(select._order_by_clause) # MSSQL only allows ORDER BY in subqueries if there is a LIMIT - if order_by and (not self.is_subquery(select) or select._limit): + if order_by and (not self.is_subquery() or select._limit): return " ORDER BY " + order_by else: return "" diff --git a/lib/sqlalchemy/databases/sybase.py b/lib/sqlalchemy/databases/sybase.py index 5c64ec1ae..b464a3bcb 100644 --- a/lib/sqlalchemy/databases/sybase.py +++ b/lib/sqlalchemy/databases/sybase.py @@ -798,7 +798,7 @@ class SybaseSQLCompiler(compiler.DefaultCompiler): order_by = self.process(select._order_by_clause) # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT - if order_by and (not self.is_subquery(select) or select._limit): + if order_by and (not self.is_subquery() or select._limit): return " ORDER BY " + order_by else: return "" diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index d2f5dae0c..40bca8a11 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -347,18 +347,7 @@ class PropertyLoader(StrategizedProperty): else: return self.prop._optimized_compare(None) elif self.prop.uselist: - if not hasattr(other, '__iter__'): - raise sa_exc.InvalidRequestError("Can only compare a collection to an iterable object. Use contains().") - else: - j = self.prop.primaryjoin - if self.prop.secondaryjoin: - j = j & self.prop.secondaryjoin - clauses = [] - for o in other: - clauses.append( - sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(o))])) - ) - return sql.and_(*clauses) + raise sa_exc.InvalidRequestError("Can't compare a collection to an object or collection; use contains() to test for membership.") else: return self.prop._optimized_compare(other) @@ -418,25 +407,30 @@ class PropertyLoader(StrategizedProperty): return clause def __negated_contains_or_equals(self, other): + if self.prop.direction == MANYTOONE: + state = attributes.instance_state(other) + strategy = self.prop._get_strategy(strategies.LazyLoader) + if strategy.use_get: + return sql.and_(*[ + sql.or_( + x != + self.prop.mapper._get_committed_state_attr_by_column(state, y), + x == None) + for (x, y) in self.prop.local_remote_pairs]) + criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]) return ~self._criterion_exists(criterion) def __ne__(self, other): - # TODO: simplify MANYTOONE comparsion when - # the 'use_get' flag is enabled - if other is None: if self.prop.direction == MANYTOONE: return sql.or_(*[x!=None for x in self.prop._foreign_keys]) - elif self.prop.uselist: - return self.any() else: - return self.has() - - if self.prop.uselist and not hasattr(other, '__iter__'): - raise sa_exc.InvalidRequestError("Can only compare a collection to an iterable object") - - return self.__negated_contains_or_equals(other) + return self._criterion_exists() + elif self.prop.uselist: + raise sa_exc.InvalidRequestError("Can't compare a collection to an object or collection; use contains() to test for membership.") + else: + return self.__negated_contains_or_equals(other) def compare(self, op, value, value_is_parent=False): if op == operators.eq: diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index df994d689..d66a51de4 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -65,20 +65,20 @@ class SchemaItem(object): def __repr__(self): return "%s()" % self.__class__.__name__ + @property def bind(self): """Return the connectable associated with this SchemaItem.""" m = self.metadata return m and m.bind or None - bind = property(bind) + @property def info(self): try: return self._info except AttributeError: self._info = {} return self._info - info = property(info) def _get_table_key(name, schema): @@ -291,9 +291,9 @@ class Table(SchemaItem, expression.TableClause): def __post_init(self, *args, **kwargs): self._init_items(*args) + @property def key(self): return _get_table_key(self.name, self.schema) - key = property(key) def _set_primary_key(self, pk): if getattr(self, '_primary_key', None) in self.constraints: diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 2982a1759..573453499 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -174,19 +174,13 @@ class DefaultCompiler(engine.Compiled): def compile(self): self.string = self.process(self.statement) - def process(self, obj, stack=None, **kwargs): - if stack: - self.stack.append(stack) - try: - meth = getattr(self, "visit_%s" % obj.__visit_name__, None) - if meth: - return meth(obj, **kwargs) - finally: - if stack: - self.stack.pop(-1) + def process(self, obj, **kwargs): + meth = getattr(self, "visit_%s" % obj.__visit_name__, None) + if meth: + return meth(obj, **kwargs) - def is_subquery(self, select): - return self.stack and self.stack[-1].get('is_subquery') + def is_subquery(self): + return self.stack and self.stack[-1].get('from') def construct_params(self, params=None): """return a dictionary of bind parameter keys and values""" @@ -342,16 +336,9 @@ class DefaultCompiler(engine.Compiled): return self.functions.get(func.__class__, self.functions.get(func.name, func.name + "%(expr)s")) def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs): - stack_entry = {'select':cs} - - if asfrom: - stack_entry['is_subquery'] = True - elif self.stack and self.stack[-1].get('select'): - stack_entry['is_subquery'] = True - self.stack.append(stack_entry) - text = string.join((self.process(c, asfrom=asfrom, parens=False) - for c in cs.selects), + text = string.join((self.process(c, asfrom=asfrom, parens=False, compound_index=i) + for i, c in enumerate(cs.selects)), " " + cs.keyword + " ") group_by = self.process(cs._group_by_clause, asfrom=asfrom) if group_by: @@ -360,8 +347,6 @@ class DefaultCompiler(engine.Compiled): text += self.order_by_clause(cs) text += (cs._limit is not None or cs._offset is not None) and self.limit_clause(cs) or "" - self.stack.pop(-1) - if asfrom and parens: return "(" + text + ")" else: @@ -470,28 +455,11 @@ class DefaultCompiler(engine.Compiled): else: return column - def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, **kwargs): + def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, compound_index=1, **kwargs): - stack_entry = {'select':select} - prev_entry = self.stack and self.stack[-1] or None - - if asfrom or (prev_entry and 'select' in prev_entry): - stack_entry['is_subquery'] = True - stack_entry['iswrapper'] = iswrapper - if not iswrapper and prev_entry and 'iswrapper' in prev_entry: - column_clause_args = {'result_map':self.result_map} - else: - column_clause_args = {} - elif iswrapper: - column_clause_args = {} - stack_entry['iswrapper'] = True - else: - column_clause_args = {'result_map':self.result_map} - - if self.stack and 'from' in self.stack[-1]: - existingfroms = self.stack[-1]['from'] - else: - existingfroms = None + entry = self.stack and self.stack[-1] or {} + + existingfroms = entry.get('from', None) froms = select._get_display_froms(existingfroms) @@ -499,10 +467,15 @@ class DefaultCompiler(engine.Compiled): # TODO: might want to propigate existing froms for select(select(select)) # where innermost select should correlate to outermost -# if existingfroms: -# correlate_froms = correlate_froms.union(existingfroms) - stack_entry['from'] = correlate_froms - self.stack.append(stack_entry) + # if existingfroms: + # correlate_froms = correlate_froms.union(existingfroms) + + if compound_index==1 and not entry or entry.get('iswrapper', False): + column_clause_args = {'result_map':self.result_map} + else: + column_clause_args = {} + + self.stack.append({'from':correlate_froms, 'iswrapper':iswrapper}) # the actual list of columns to print in the SELECT column list. inner_columns = util.OrderedSet( @@ -520,13 +493,9 @@ class DefaultCompiler(engine.Compiled): text += self.get_select_precolumns(select) text += ', '.join(inner_columns) - from_strings = [] - for f in froms: - from_strings.append(self.process(f, asfrom=True)) - if froms: text += " \nFROM " - text += ', '.join(from_strings) + text += ', '.join(self.process(f, asfrom=True) for f in froms) else: text += self.default_from() diff --git a/test/orm/query.py b/test/orm/query.py index 151bada63..567ca317c 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -5,6 +5,7 @@ from sqlalchemy import exc as sa_exc, util from sqlalchemy.sql import compiler from sqlalchemy.engine import default from sqlalchemy.orm import * +from sqlalchemy.orm import attributes from testlib import * from orm import _base @@ -334,7 +335,16 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): "WHERE users.id = addresses.user_id AND addresses.id = :id_1)" ) - self._test(Address.user == User(id=7), ":param_1 = addresses.user_id") + u7 = User(id=7) + attributes.instance_state(u7).commit_all() + + self._test(Address.user == u7, ":param_1 = addresses.user_id") + + self._test(Address.user != u7, "addresses.user_id != :user_id_1 OR addresses.user_id IS NULL") + + self._test(Address.user == None, "addresses.user_id IS NULL") + + self._test(Address.user != None, "addresses.user_id IS NOT NULL") def test_selfref_relation(self): nalias = aliased(Node) diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index 5d10c5750..793695919 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -288,7 +288,7 @@ class UnicodeTest(TestBase, AssertsExecutionResults): def tearDown(self): unicode_table.delete().execute() - def testbasic(self): + def test_round_trip(self): assert unicode_table.c.unicode_varchar.type.length == 250 rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' unicodedata = rawdata.decode('utf-8') @@ -296,10 +296,6 @@ class UnicodeTest(TestBase, AssertsExecutionResults): unicode_text=unicodedata, plain_varchar=rawdata) x = unicode_table.select().execute().fetchone() - print 0, repr(unicodedata) - print 1, repr(x['unicode_varchar']) - print 2, repr(x['unicode_text']) - print 3, repr(x['plain_varchar']) self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) if isinstance(x['plain_varchar'], unicode): @@ -310,7 +306,21 @@ class UnicodeTest(TestBase, AssertsExecutionResults): else: self.assert_(not isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == rawdata) - def testassert(self): + def test_union(self): + """ensure compiler processing works for UNIONs""" + + rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' + unicodedata = rawdata.decode('utf-8') + unicode_table.insert().execute(unicode_varchar=unicodedata, + unicode_text=unicodedata, + plain_varchar=rawdata) + + x = union(unicode_table.select(), unicode_table.select()).execute().fetchone() + self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) + self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) + + + def test_assertions(self): try: unicode_table.insert().execute(unicode_varchar='not unicode') assert False @@ -337,11 +347,11 @@ class UnicodeTest(TestBase, AssertsExecutionResults): unicode_engine.dispose() @testing.fails_on('oracle') - def testblanks(self): + def test_blank_strings(self): unicode_table.insert().execute(unicode_varchar=u'') assert select([unicode_table.c.unicode_varchar]).scalar() == u'' - def testengineparam(self): + def test_engine_parameter(self): """tests engine-wide unicode conversion""" prev_unicode = testing.db.engine.dialect.convert_unicode prev_assert = testing.db.engine.dialect.assert_unicode @@ -367,7 +377,7 @@ class UnicodeTest(TestBase, AssertsExecutionResults): @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on') @testing.fails_on('firebird') # "Data type unknown" on the parameter - def testlength(self): + def test_length_function(self): """checks the database correctly understands the length of a unicode string""" teststr = u'aaa\x1234' self.assert_(testing.db.func.length(teststr).scalar() == len(teststr)) |