diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-11-20 12:55:46 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-11-20 12:55:46 -0500 |
commit | c5442b3adb56cae15607aad32b4219ac11ae809e (patch) | |
tree | 0502f97a14672dd089adab6184681cb7a0f3bfea | |
parent | d505ea71aed44ecae718052131dc0a2fb2c9dd99 (diff) | |
parent | 6dbf2c3314a797a39624f1e68569bfbbb2b6ac87 (diff) | |
download | sqlalchemy-c5442b3adb56cae15607aad32b4219ac11ae809e.tar.gz |
- merge hybrid attributes branch, [ticket:1903]
-rw-r--r-- | examples/derived_attributes/__init__.py | 10 | ||||
-rw-r--r-- | examples/derived_attributes/attributes.py | 168 | ||||
-rwxr-xr-x | lib/sqlalchemy/ext/declarative.py | 7 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/hybrid.py | 138 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/attributes.py | 69 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/dynamic.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 41 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/properties.py | 173 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/query.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/util.py | 22 | ||||
-rw-r--r-- | test/ext/test_hybrid.py | 102 | ||||
-rw-r--r-- | test/orm/test_mapper.py | 23 | ||||
-rw-r--r-- | test/orm/test_query.py | 73 |
15 files changed, 465 insertions, 379 deletions
diff --git a/examples/derived_attributes/__init__.py b/examples/derived_attributes/__init__.py deleted file mode 100644 index 98c946fca..000000000 --- a/examples/derived_attributes/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Illustrates a clever technique using Python descriptors to create custom attributes representing SQL expressions when used at the class level, and Python expressions when used at the instance level. In some cases this technique replaces the need to configure the attribute in the mapping, instead relying upon ordinary Python behavior to create custom expression components. - -E.g.:: - - class BaseInterval(object): - @hybrid - def contains(self,point): - return (self.start <= point) & (point < self.end) - -""" diff --git a/examples/derived_attributes/attributes.py b/examples/derived_attributes/attributes.py deleted file mode 100644 index f36cbd541..000000000 --- a/examples/derived_attributes/attributes.py +++ /dev/null @@ -1,168 +0,0 @@ -from functools import update_wrapper -import new - -class method(object): - def __init__(self, func, expr=None): - self.func = func - self.expr = expr or func - - def __get__(self, instance, owner): - if instance is None: - return new.instancemethod(self.expr, owner, owner.__class__) - else: - return new.instancemethod(self.func, instance, owner) - - def expression(self, expr): - self.expr = expr - return self - -class property_(object): - def __init__(self, fget, fset=None, fdel=None, expr=None): - self.fget = fget - self.fset = fset - self.fdel = fdel - self.expr = expr or fget - update_wrapper(self, fget) - - def __get__(self, instance, owner): - if instance is None: - return self.expr(owner) - else: - return self.fget(instance) - - def __set__(self, instance, value): - self.fset(instance, value) - - def __delete__(self, instance): - self.fdel(instance) - - def setter(self, fset): - self.fset = fset - return self - - def deleter(self, fdel): - self.fdel = fdel - return self - - def expression(self, expr): - self.expr = expr - return self - -### Example code - -from sqlalchemy import Table, Column, Integer, create_engine, func -from sqlalchemy.orm import sessionmaker, aliased -from sqlalchemy.ext.declarative import declarative_base - -Base = declarative_base() - -class BaseInterval(object): - @method - def contains(self,point): - """Return true if the interval contains the given interval.""" - - return (self.start <= point) & (point < self.end) - - @method - def intersects(self, other): - """Return true if the interval intersects the given interval.""" - - return (self.start < other.end) & (self.end > other.start) - - @method - def _max(self, x, y): - """Return the max of two values.""" - - return max(x, y) - - @_max.expression - def _max(cls, x, y): - """Return the SQL max of two values.""" - - return func.max(x, y) - - @method - def max_length(self, other): - """Return the longer length of this interval and another.""" - - return self._max(self.length, other.length) - - def __repr__(self): - return "%s(%s..%s)" % (self.__class__.__name__, self.start, self.end) - -class Interval1(BaseInterval, Base): - """Interval stored as endpoints""" - - __table__ = Table('interval1', Base.metadata, - Column('id', Integer, primary_key=True), - Column('start', Integer, nullable=False), - Column('end', Integer, nullable=False) - ) - - def __init__(self, start, end): - self.start = start - self.end = end - - @property_ - def length(self): - return self.end - self.start - -class Interval2(BaseInterval, Base): - """Interval stored as start and length""" - - __table__ = Table('interval2', Base.metadata, - Column('id', Integer, primary_key=True), - Column('start', Integer, nullable=False), - Column('length', Integer, nullable=False) - ) - - def __init__(self, start, length): - self.start = start - self.length = length - - @property_ - def end(self): - return self.start + self.length - - - -engine = create_engine('sqlite://', echo=True) - -Base.metadata.create_all(engine) - -session = sessionmaker(engine)() - -intervals = [Interval1(1,4), Interval1(3,15), Interval1(11,16)] - -for interval in intervals: - session.add(interval) - session.add(Interval2(interval.start, interval.length)) - -session.commit() - -for Interval in (Interval1, Interval2): - print "Querying using interval class %s" % Interval.__name__ - - print - print '-- length less than 10' - print [(i, i.length) for i in - session.query(Interval).filter(Interval.length < 10).all()] - - print - print '-- contains 12' - print session.query(Interval).filter(Interval.contains(12)).all() - - print - print '-- intersects 2..10' - other = Interval1(2,10) - result = session.query(Interval).\ - filter(Interval.intersects(other)).\ - order_by(Interval.length).all() - print [(interval, interval.intersects(other)) for interval in result] - - print - print '-- longer length' - interval_alias = aliased(Interval) - print session.query(Interval.length, - interval_alias.length, - Interval.max_length(interval_alias)).all() diff --git a/lib/sqlalchemy/ext/declarative.py b/lib/sqlalchemy/ext/declarative.py index 3c6cab59a..8381e5ee1 100755 --- a/lib/sqlalchemy/ext/declarative.py +++ b/lib/sqlalchemy/ext/declarative.py @@ -1265,12 +1265,13 @@ class _GetColumns(object): mapper = class_mapper(self.cls, compile=False) if mapper: - prop = mapper.get_property(key, raiseerr=False) - if prop is None: + if not mapper.has_property(key): raise exceptions.InvalidRequestError( "Class %r does not have a mapped column named %r" % (self.cls, key)) - elif not isinstance(prop, ColumnProperty): + + prop = mapper.get_property(key) + if not isinstance(prop, ColumnProperty): raise exceptions.InvalidRequestError( "Property %r is not an instance of" " ColumnProperty (i.e. does not correspond" diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py new file mode 100644 index 000000000..5bb158413 --- /dev/null +++ b/lib/sqlalchemy/ext/hybrid.py @@ -0,0 +1,138 @@ +"""Define attributes on ORM-mapped classes that have 'hybrid' behavior. + +'hybrid' means the attribute has distinct behaviors defined at the +class level and at the instance level. + +Consider a table `interval` as below:: + + from sqlalchemy import MetaData, Table, Column, Integer + from sqlalchemy.orm import mapper, create_session + + engine = create_engine('sqlite://') + metadata = MetaData() + + interval_table = Table('interval', metadata, + Column('id', Integer, primary_key=True), + Column('start', Integer, nullable=False), + Column('end', Integer, nullable=False)) + metadata.create_all(engine) + +We can define higher level functions on mapped classes that produce SQL +expressions at the class level, and Python expression evaluation at the +instance level. Below, each function decorated with :func:`hybrid.method` +or :func:`hybrid.property` may receive ``self`` as an instance of the class, +or as the class itself:: + + # A base class for intervals + + from sqlalchemy.orm import hybrid + + class Interval(object): + def __init__(self, start, end): + self.start = start + self.end = end + + @hybrid.property + def length(self): + return self.end - self.start + + @hybrid.method + def contains(self,point): + return (self.start <= point) & (point < self.end) + + @hybrid.method + def intersects(self, other): + return (self.start < other.end) & (self.end > other.start) + + mapper(Interval1, interval_table1) + + session = sessionmaker(engine)() + + session.add_all( + [Interval1(1,4), Interval1(3,15), Interval1(11,16)] + ) + intervals = + + for interval in intervals: + session.add(interval) + session.add(Interval2(interval.start, interval.length)) + + session.commit() + + ### TODO ADD EXAMPLES HERE AND STUFF THIS ISN'T FINISHED ### + +""" +from sqlalchemy import util +from sqlalchemy.orm import attributes, interfaces + +class method(object): + def __init__(self, func, expr=None): + self.func = func + self.expr = expr or func + + def __get__(self, instance, owner): + if instance is None: + return new.instancemethod(self.expr, owner, owner.__class__) + else: + return new.instancemethod(self.func, instance, owner) + + def expression(self, expr): + self.expr = expr + return self + +class property_(object): + def __init__(self, fget, fset=None, fdel=None, expr=None): + self.fget = fget + self.fset = fset + self.fdel = fdel + self.expr = expr or fget + util.update_wrapper(self, fget) + + def __get__(self, instance, owner): + if instance is None: + return self.expr(owner) + else: + return self.fget(instance) + + def __set__(self, instance, value): + self.fset(instance, value) + + def __delete__(self, instance): + self.fdel(instance) + + def setter(self, fset): + self.fset = fset + return self + + def deleter(self, fdel): + self.fdel = fdel + return self + + def expression(self, expr): + self.expr = expr + return self + + def comparator(self, comparator): + proxy_attr = attributes.\ + create_proxied_attribute(self) + def expr(owner): + return proxy_attr(self.__name__, self, comparator(owner)) + self.expr = expr + return self + + +class Comparator(interfaces.PropComparator): + def __init__(self, expression): + self.expression = expression + + def __clause_element__(self): + expr = self.expression + while hasattr(expr, '__clause_element__'): + expr = expr.__clause_element__() + return expr + + def adapted(self, adapter): + # interesting.... + return self + + diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 6872dd645..86f950813 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -130,52 +130,38 @@ class InstrumentedAttribute(QueryableAttribute): return self.impl.get(instance_state(instance), instance_dict(instance)) -class _ProxyImpl(object): - accepts_scalar_loader = False - expire_missing = True - - def __init__(self, key): - self.key = key +def create_proxied_attribute(descriptor): + """Create an QueryableAttribute / user descriptor hybrid. -def proxied_attribute_factory(descriptor): - """Create an InstrumentedAttribute / user descriptor hybrid. - - Returns a new InstrumentedAttribute type that delegates descriptor + Returns a new QueryableAttribute type that delegates descriptor behavior and getattr() to the given descriptor. """ - class Proxy(InstrumentedAttribute): + class Proxy(QueryableAttribute): """A combination of InsturmentedAttribute and a regular descriptor.""" - def __init__(self, key, descriptor, comparator, parententity): + def __init__(self, key, descriptor, comparator, adapter=None): self.key = key - # maintain ProxiedAttribute.user_prop compatability. - self.descriptor = self.user_prop = descriptor + self.descriptor = descriptor self._comparator = comparator - self._parententity = parententity - self.impl = _ProxyImpl(key) - + self.adapter = adapter + @util.memoized_property def comparator(self): if util.callable(self._comparator): self._comparator = self._comparator() + if self.adapter: + self._comparator = self._comparator.adapted(self.adapter) return self._comparator + + def adapted(self, adapter): + return self.__class__(self.key, self.descriptor, + self._comparator, + adapter) - def __get__(self, instance, owner): - """Delegate __get__ to the original descriptor.""" - if instance is None: - descriptor.__get__(instance, owner) - return self - return descriptor.__get__(instance, owner) - - def __set__(self, instance, value): - """Delegate __set__ to the original descriptor.""" - return descriptor.__set__(instance, value) - - def __delete__(self, instance): - """Delegate __delete__ to the original descriptor.""" - return descriptor.__delete__(instance) - + def __str__(self): + return self.key + def __getattr__(self, attribute): """Delegate __getattr__ to the original descriptor and/or comparator.""" @@ -184,12 +170,12 @@ def proxied_attribute_factory(descriptor): return getattr(descriptor, attribute) except AttributeError: try: - return getattr(self._comparator, attribute) + return getattr(self.comparator, attribute) except AttributeError: raise AttributeError( 'Neither %r object nor %r object has an attribute %r' % ( type(descriptor).__name__, - type(self._comparator).__name__, + type(self.comparator).__name__, attribute) ) @@ -1030,15 +1016,12 @@ def has_parent(cls, obj, key, optimistic=False): return manager.has_parent(state, key, optimistic) def register_attribute(class_, key, **kw): - proxy_property = kw.pop('proxy_property', None) - comparator = kw.pop('comparator', None) parententity = kw.pop('parententity', None) doc = kw.pop('doc', None) - register_descriptor(class_, key, proxy_property, + register_descriptor(class_, key, comparator, parententity, doc=doc) - if not proxy_property: - register_attribute_impl(class_, key, **kw) + register_attribute_impl(class_, key, **kw) def register_attribute_impl(class_, key, uselist=False, callable_=None, @@ -1074,15 +1057,11 @@ def register_attribute_impl(class_, key, manager.post_configure_attribute(key) -def register_descriptor(class_, key, proxy_property=None, comparator=None, +def register_descriptor(class_, key, comparator=None, parententity=None, property_=None, doc=None): manager = manager_of_class(class_) - if proxy_property: - proxy_type = proxied_attribute_factory(proxy_property) - descriptor = proxy_type(key, proxy_property, comparator, parententity) - else: - descriptor = InstrumentedAttribute(class_, key, comparator=comparator, + descriptor = InstrumentedAttribute(class_, key, comparator=comparator, parententity=parententity) descriptor.__doc__ = doc diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index caa057717..95a58ee84 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -193,7 +193,7 @@ class AppenderMixin(object): self.attr = attr mapper = object_mapper(instance) - prop = mapper.get_property(self.attr.key, resolve_synonyms=True) + prop = mapper.get_property(self.attr.key) self._criterion = prop.compare( operators.eq, instance, diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 6bcdc6f0f..5686bdd33 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -469,8 +469,10 @@ class PropertyOption(MapperOption): path_element = entity.path_entity mapper = entity.mapper mappers.append(mapper) - prop = mapper.get_property(token, - resolve_synonyms=True, raiseerr=raiseerr) + if mapper.has_property(token): + prop = mapper.get_property(token) + else: + prop = None key = token elif isinstance(token, PropComparator): prop = token.property diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index accb7c4be..a31064858 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -895,36 +895,22 @@ class Mapper(object): def has_property(self, key): return key in self._props - def get_property(self, key, - resolve_synonyms=False, - raiseerr=True, _compile_mappers=True): - - """return a :class:`.MapperProperty` associated with the given key. + def get_property(self, key, _compile_mappers=True): + """return a MapperProperty associated with the given key. - resolve_synonyms=False and raiseerr=False are deprecated. + Calls getattr() against the mapped class itself, so that class-level + proxies will be resolved to the underlying property, if any. """ if _compile_mappers and not self.compiled: self.compile() - - if not resolve_synonyms: - prop = self._props.get(key, None) - if prop is None and raiseerr: - raise sa_exc.InvalidRequestError( - "Mapper '%s' has no property '%s'" % - (self, key)) - return prop - else: - try: - return getattr(self.class_, key).property - except AttributeError: - if raiseerr: - raise sa_exc.InvalidRequestError( - "Mapper '%s' has no property '%s'" % (self, key)) - else: - return None - + try: + return getattr(self.class_, key).property + except AttributeError: + raise sa_exc.InvalidRequestError( + "Mapper '%s' has no property '%s'" % (self, key)) + @util.deprecated('0.6.4', 'Call to deprecated function mapper._get_col_to_pr' 'op(). Use mapper.get_property_by_column()') @@ -1125,8 +1111,11 @@ class Mapper(object): def _is_userland_descriptor(self, obj): return not isinstance(obj, - (MapperProperty, attributes.InstrumentedAttribute)) and \ - hasattr(obj, '__get__') + (MapperProperty, attributes.QueryableAttribute)) and \ + hasattr(obj, '__get__') and not \ + isinstance(obj.__get__(None, obj), + attributes.QueryableAttribute) + def _should_exclude(self, name, assigned_name, local, column): """determine whether a particular property should be implicitly diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index edfb861f4..02e883de4 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -252,7 +252,63 @@ class CompositeProperty(ColumnProperty): class DescriptorProperty(MapperProperty): """:class:`MapperProperty` which proxies access to a - plain descriptor.""" + user-defined descriptor.""" + + def instrument_class(self, mapper): + from sqlalchemy.ext import hybrid + + # hackety hack hack + class _ProxyImpl(object): + accepts_scalar_loader = False + expire_missing = True + + def __init__(self, key): + self.key = key + + if self.descriptor is None: + desc = getattr(mapper.class_, self.key, None) + if mapper._is_userland_descriptor(desc): + self.descriptor = desc + + if self.descriptor is None: + def fset(obj, value): + setattr(obj, self.name, value) + def fdel(obj): + delattr(obj, self.name) + def fget(obj): + return getattr(obj, self.name) + fget.__doc__ = self.doc + + descriptor = hybrid.property_( + fget=fget, + fset=fset, + fdel=fdel, + ) + elif isinstance(self.descriptor, property): + descriptor = hybrid.property_( + fget=self.descriptor.fget, + fset=self.descriptor.fset, + fdel=self.descriptor.fdel, + ) + else: + descriptor = hybrid.property_( + fget=self.descriptor.__get__, + fset=self.descriptor.__set__, + fdel=self.descriptor.__delete__, + ) + + proxy_attr = attributes.\ + create_proxied_attribute(self.descriptor or descriptor)\ + ( + self.key, + self.descriptor or descriptor, + lambda: self._comparator_factory(mapper) + ) + def get_comparator(owner): + return util.update_wrapper(proxy_attr, descriptor) + descriptor.expr = get_comparator + descriptor.impl = _ProxyImpl(self.key) + mapper.class_manager.instrument_attribute(self.key, descriptor) def setup(self, context, entity, path, adapter, **kwargs): pass @@ -264,14 +320,13 @@ class DescriptorProperty(MapperProperty): dest_state, dest_dict, load, _recursive): pass - class ConcreteInheritedProperty(DescriptorProperty): """A 'do nothing' :class:`MapperProperty` that disables an attribute on a concrete subclass that is only present on the inherited mapper, not the concrete classes' mapper. - + Cases where this occurs include: - + * When the superclass mapper is mapped against a "polymorphic union", which includes all attributes from all subclasses. @@ -279,10 +334,20 @@ class ConcreteInheritedProperty(DescriptorProperty): but not on the subclass mapper. Concrete mappers require that relationship() is configured explicitly on each subclass. - + """ + + def _comparator_factory(self, mapper): + comparator_callable = None + + for m in self.parent.iterate_to_root(): + p = m._props[self.key] + if not isinstance(p, ConcreteInheritedProperty): + comparator_callable = p.comparator_factory + break + return comparator_callable - def instrument_class(self, mapper): + def __init__(self): def warn(): raise AttributeError("Concrete %s does not implement " "attribute %r at the instance level. Add this " @@ -295,26 +360,12 @@ class ConcreteInheritedProperty(DescriptorProperty): def __delete__(s, obj): warn() def __get__(s, obj, owner): + if obj is None: + return self.descriptor warn() - - comparator_callable = None - # TODO: put this process into a deferred callable? - for m in self.parent.iterate_to_root(): - p = m.get_property(self.key, _compile_mappers=False) - if not isinstance(p, ConcreteInheritedProperty): - comparator_callable = p.comparator_factory - break - - attributes.register_descriptor( - mapper.class_, - self.key, - comparator=comparator_callable(self, mapper), - parententity=mapper, - property_=self, - proxy_property=NoninheritedConcreteProp() - ) - - + self.descriptor = NoninheritedConcreteProp() + + class SynonymProperty(DescriptorProperty): def __init__(self, name, map_column=None, @@ -327,6 +378,15 @@ class SynonymProperty(DescriptorProperty): self.doc = doc or (descriptor and descriptor.__doc__) or None util.set_creation_order(self) + def _comparator_factory(self, mapper): + prop = getattr(mapper.class_, self.name).property + + if self.comparator_factory: + comp = self.comparator_factory(prop, mapper) + else: + comp = prop.comparator_factory(prop, mapper) + return comp + def set_parent(self, parent, init): if self.map_column: # implement the 'map_column' option. @@ -352,50 +412,8 @@ class SynonymProperty(DescriptorProperty): init=init, setparent=True) p._mapped_by_synonym = self.key - + self.parent = parent - - def instrument_class(self, mapper): - - if self.descriptor is None: - desc = getattr(mapper.class_, self.key, None) - if mapper._is_userland_descriptor(desc): - self.descriptor = desc - - if self.descriptor is None: - class SynonymProp(object): - def __set__(s, obj, value): - setattr(obj, self.name, value) - def __delete__(s, obj): - delattr(obj, self.name) - def __get__(s, obj, owner): - if obj is None: - return s - return getattr(obj, self.name) - - self.descriptor = SynonymProp() - - def comparator_callable(prop, mapper): - def comparator(): - prop = mapper.get_property( - self.name, resolve_synonyms=True, - _compile_mappers=False) - if self.comparator_factory: - return self.comparator_factory(prop, mapper) - else: - return prop.comparator_factory(prop, mapper) - return comparator - - attributes.register_descriptor( - mapper.class_, - self.key, - comparator=comparator_callable(self, mapper), - parententity=mapper, - property_=self, - proxy_property=self.descriptor, - doc=self.doc - ) - class ComparableProperty(DescriptorProperty): """Instruments a Python property for use in query expressions.""" @@ -406,23 +424,8 @@ class ComparableProperty(DescriptorProperty): self.doc = doc or (descriptor and descriptor.__doc__) or None util.set_creation_order(self) - def instrument_class(self, mapper): - """Set up a proxy to the unmanaged descriptor.""" - - if self.descriptor is None: - desc = getattr(mapper.class_, self.key, None) - if mapper._is_userland_descriptor(desc): - self.descriptor = desc - - attributes.register_descriptor( - mapper.class_, - self.key, - comparator=self.comparator_factory(self, mapper), - parententity=mapper, - property_=self, - proxy_property=self.descriptor, - doc=self.doc, - ) + def _comparator_factory(self, mapper): + return self.comparator_factory(self, mapper) class RelationshipProperty(StrategizedProperty): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index a56b67546..87cce96d3 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -951,7 +951,6 @@ class Query(object): clauses = [_entity_descriptor(self._joinpoint_zero(), key) == value for key, value in kwargs.iteritems()] - return self.filter(sql.and_(*clauses)) @_generative(_no_statement_condition, _no_limit_offset) @@ -2683,7 +2682,10 @@ class _ColumnEntity(_QueryEntity): if isinstance(column, basestring): column = sql.literal_column(column) self._label_name = column.name - elif isinstance(column, attributes.QueryableAttribute): + elif isinstance(column, ( + attributes.QueryableAttribute, + interfaces.PropComparator + )): self._label_name = column.key column = column.__clause_element__() else: diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 578ef2de1..1c316c3c0 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -1287,7 +1287,7 @@ class LoadEagerFromAliasOption(PropertyOption): if isinstance(self.alias, basestring): mapper = mappers[-1] (root_mapper, propname) = paths[-1][-2:] - prop = mapper.get_property(propname, resolve_synonyms=True) + prop = mapper.get_property(propname) self.alias = prop.target.alias(self.alias) query._attributes[ ("user_defined_eager_row_processor", @@ -1296,7 +1296,7 @@ class LoadEagerFromAliasOption(PropertyOption): else: (root_mapper, propname) = paths[-1][-2:] mapper = mappers[-1] - prop = mapper.get_property(propname, resolve_synonyms=True) + prop = mapper.get_property(propname) adapter = query._polymorphic_adapters.get(prop.mapper, None) query._attributes[ ("user_defined_eager_row_processor", diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 9447eed30..0e0d6f568 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -275,13 +275,6 @@ class AliasedClass(object): return queryattr def __getattr__(self, key): - if self.__mapper.has_property(key): - return self.__adapt_prop( - self.__mapper.get_property( - key, _compile_mappers=False - ) - ) - for base in self.__target.__mro__: try: attr = object.__getattribute__(base, key) @@ -291,15 +284,20 @@ class AliasedClass(object): break else: raise AttributeError(key) - - if hasattr(attr, 'func_code'): + + if isinstance(attr, attributes.QueryableAttribute): + return self.__adapt_prop(attr.property) + elif hasattr(attr, 'func_code'): is_method = getattr(self.__target, key, None) if is_method and is_method.im_self is not None: return util.types.MethodType(attr.im_func, self, self) else: return None elif hasattr(attr, '__get__'): - return attr.__get__(None, self) + ret = attr.__get__(None, self) + if isinstance(ret, PropComparator): + return ret.adapted(self.__adapt_element) + return ret else: return attr @@ -437,7 +435,7 @@ def with_parent(instance, prop): """ if isinstance(prop, basestring): mapper = object_mapper(instance) - prop = mapper.get_property(prop, resolve_synonyms=True) + prop = mapper.get_property(prop) elif isinstance(prop, attributes.QueryableAttribute): prop = prop.property @@ -486,7 +484,7 @@ def _entity_descriptor(entity, key): """ if not isinstance(entity, (AliasedClass, type)): entity = entity.class_ - + try: return getattr(entity, key) except AttributeError: diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py new file mode 100644 index 000000000..3dfd4c856 --- /dev/null +++ b/test/ext/test_hybrid.py @@ -0,0 +1,102 @@ +""" + +tests for sqlalchemy.ext.hybrid TODO + + +""" + + +from sqlalchemy import * +from sqlalchemy.orm import * +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext import hybrid +from sqlalchemy.orm.interfaces import PropComparator + + +""" +from sqlalchemy import * +from sqlalchemy.orm import * +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext import hybrid + +Base = declarative_base() + + +class UCComparator(hybrid.Comparator): + + def __eq__(self, other): + if other is None: + return self.expression == None + else: + return func.upper(self.expression) == func.upper(other) + +class A(Base): + __tablename__ = 'a' + id = Column(Integer, primary_key=True) + _value = Column("value", String) + + @hybrid.property_ + def value(self): + return int(self._value) + + @value.comparator + def value(cls): + return UCComparator(cls._value) + + @value.setter + def value(self, v): + self.value = v +print aliased(A).value +print aliased(A).__tablename__ + +sess = create_session() + +print A.value == "foo" +print sess.query(A.value) +print sess.query(aliased(A).value) +print sess.query(aliased(A)).filter_by(value="foo") +""" + +""" +from sqlalchemy import * +from sqlalchemy.orm import * +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext import hybrid + +Base = declarative_base() + +class A(Base): + __tablename__ = 'a' + id = Column(Integer, primary_key=True) + _value = Column("value", String) + + @hybrid.property + def value(self): + return int(self._value) + + @value.expression + def value(cls): + return func.foo(cls._value) + cls.bar_value + + @value.setter + def value(self, v): + self.value = v + + @hybrid.property + def bar_value(cls): + return func.bar(cls._value) + +#print A.value +#print A.value.__doc__ + +print aliased(A).value +print aliased(A).__tablename__ + +sess = create_session() + +print sess.query(A).filter_by(value="foo") + +print sess.query(aliased(A)).filter_by(value="foo") + + +"""
\ No newline at end of file diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index 2ec228a0f..acdf9c718 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -907,7 +907,6 @@ class MapperTest(_fixtures.FixtureTest): args = (UCComparator, User.uc_name) else: args = (UCComparator,) - mapper(User, users, properties=dict( uc_name = sa.orm.comparable_property(*args))) return User @@ -1180,7 +1179,7 @@ class DocumentTest(testing.TestBase): backref=backref('foo',doc='foo relationship') ), 'foober':column_property(t1.c.col3, doc='alternate data col'), - 'hoho':synonym(t1.c.col4, doc="syn of col4") + 'hoho':synonym("col4", doc="syn of col4") }) mapper(Bar, t2) compile_mappers() @@ -1554,10 +1553,22 @@ class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL): class MyFactory(ColumnProperty.Comparator): __hash__ = None def __eq__(self, other): - return func.foobar(self.__clause_element__()) == func.foobar(other) - mapper(User, users, properties={'name':synonym('_name', map_column=True, comparator_factory=MyFactory)}) - self.assert_compile(User.name == 'ed', "foobar(users.name) = foobar(:foobar_1)", dialect=default.DefaultDialect()) - self.assert_compile(aliased(User).name == 'ed', "foobar(users_1.name) = foobar(:foobar_1)", dialect=default.DefaultDialect()) + return func.foobar(self.__clause_element__()) ==\ + func.foobar(other) + + mapper(User, users, properties={ + 'name':synonym('_name', map_column=True, + comparator_factory=MyFactory) + }) + self.assert_compile( + User.name == 'ed', + "foobar(users.name) = foobar(:foobar_1)", + dialect=default.DefaultDialect()) + + self.assert_compile( + aliased(User).name == 'ed', + "foobar(users_1.name) = foobar(:foobar_1)", + dialect=default.DefaultDialect()) @testing.resolve_artifact_names def test_relationship(self): diff --git a/test/orm/test_query.py b/test/orm/test_query.py index d96fa7384..bc2c5f323 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -3689,53 +3689,92 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): 'orders':relationship(Order, backref='user'), # o2m, m2o }) mapper(Order, orders, properties={ - 'items':relationship(Item, secondary=order_items, order_by=items.c.id), #m2m + 'items':relationship(Item, secondary=order_items, + order_by=items.c.id), #m2m }) mapper(Item, items, properties={ - 'keywords':relationship(Keyword, secondary=item_keywords, order_by=keywords.c.id) #m2m + 'keywords':relationship(Keyword, secondary=item_keywords, + order_by=keywords.c.id) #m2m }) mapper(Keyword, keywords) - sel = users.select(users.c.id.in_([7, 8])) sess = create_session() - - eq_(sess.query(User).select_from(sel).join('orders', 'items', 'keywords').filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ + sel = users.select(users.c.id.in_([7, 8])) + + eq_(sess.query(User).select_from(sel).\ + join('orders', 'items', 'keywords').\ + filter(Keyword.name.in_(['red', 'big', 'round'])).\ + all(), + [ User(name=u'jack',id=7) ]) - eq_(sess.query(User).select_from(sel).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ + eq_(sess.query(User).select_from(sel).\ + join('orders', 'items', 'keywords', aliased=True).\ + filter(Keyword.name.in_(['red', 'big', 'round'])).\ + all(), + [ User(name=u'jack',id=7) ]) def go(): eq_( sess.query(User).select_from(sel). - options(joinedload_all('orders.items.keywords')). - join('orders', 'items', 'keywords', aliased=True). - filter(Keyword.name.in_(['red', 'big', 'round'])).all(), + options(joinedload_all('orders.items.keywords')). + join('orders', 'items', 'keywords', aliased=True). + filter(Keyword.name.in_(['red', 'big', 'round'])).\ + all(), [ User(name=u'jack',orders=[ Order(description=u'order 1',items=[ - Item(description=u'item 1',keywords=[Keyword(name=u'red'), Keyword(name=u'big'), Keyword(name=u'round')]), - Item(description=u'item 2',keywords=[Keyword(name=u'red',id=2), Keyword(name=u'small',id=5), Keyword(name=u'square')]), - Item(description=u'item 3',keywords=[Keyword(name=u'green',id=3), Keyword(name=u'big',id=4), Keyword(name=u'round',id=6)]) + Item(description=u'item 1', + keywords=[ + Keyword(name=u'red'), + Keyword(name=u'big'), + Keyword(name=u'round') + ]), + Item(description=u'item 2', + keywords=[ + Keyword(name=u'red',id=2), + Keyword(name=u'small',id=5), + Keyword(name=u'square') + ]), + Item(description=u'item 3', + keywords=[ + Keyword(name=u'green',id=3), + Keyword(name=u'big',id=4), + Keyword(name=u'round',id=6)]) ]), Order(description=u'order 3',items=[ - Item(description=u'item 3',keywords=[Keyword(name=u'green',id=3), Keyword(name=u'big',id=4), Keyword(name=u'round',id=6)]), + Item(description=u'item 3', + keywords=[ + Keyword(name=u'green',id=3), + Keyword(name=u'big',id=4), + Keyword(name=u'round',id=6) + ]), Item(description=u'item 4',keywords=[],id=4), Item(description=u'item 5',keywords=[],id=5) ]), - Order(description=u'order 5',items=[Item(description=u'item 5',keywords=[])])]) + Order(description=u'order 5', + items=[ + Item(description=u'item 5',keywords=[])]) + ]) ]) self.assert_sql_count(testing.db, go, 1) - + sess.expunge_all() sel2 = orders.select(orders.c.id.in_([1,2,3])) - eq_(sess.query(Order).select_from(sel2).join('items', 'keywords').filter(Keyword.name == 'red').order_by(Order.id).all(), [ + eq_(sess.query(Order).select_from(sel2).\ + join('items', 'keywords').\ + filter(Keyword.name == 'red').\ + order_by(Order.id).all(), [ Order(description=u'order 1',id=1), Order(description=u'order 2',id=2), ]) - eq_(sess.query(Order).select_from(sel2).join('items', 'keywords', aliased=True).filter(Keyword.name == 'red').order_by(Order.id).all(), [ + eq_(sess.query(Order).select_from(sel2).\ + join('items', 'keywords', aliased=True).\ + filter(Keyword.name == 'red').\ + order_by(Order.id).all(), [ Order(description=u'order 1',id=1), Order(description=u'order 2',id=2), ]) |