diff options
-rw-r--r-- | CHANGES | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/declarative.py | 17 | ||||
-rw-r--r-- | lib/sqlalchemy/util.py | 8 | ||||
-rw-r--r-- | test/ext/declarative.py | 36 |
4 files changed, 63 insertions, 6 deletions
@@ -109,7 +109,7 @@ CHANGES it when reflecting related tables. This is stickier behavior than before which is why it's off by default. -- extensions +- declarative extension - The "synonym" function is now directly usable with "declarative". Pass in the decorated property using the "descriptor" keyword argument, e.g.: somekey = @@ -137,6 +137,12 @@ CHANGES - inheritance in declarative can be disabled when sending "inherits=None" to __mapper_args__. + - declarative_base() takes optional kwarg "mapper", which + is any callable/class/method that produces a mapper, + such as declarative_base(mapper=scopedsession.mapper). + This property can also be set on individual declarative + classes using the "__mapper_cls__" property. + 0.4.4 ------ - sql diff --git a/lib/sqlalchemy/ext/declarative.py b/lib/sqlalchemy/ext/declarative.py index 62691a906..d8576d79b 100644 --- a/lib/sqlalchemy/ext/declarative.py +++ b/lib/sqlalchemy/ext/declarative.py @@ -162,6 +162,7 @@ from sqlalchemy.orm import synonym as _orm_synonym, mapper, comparable_property from sqlalchemy.orm.interfaces import MapperProperty from sqlalchemy.orm.properties import PropertyLoader, ColumnProperty from sqlalchemy import util, exceptions +import types __all__ = ['declarative_base', 'synonym_for', 'comparable_using', 'declared_synonym'] @@ -216,8 +217,12 @@ class DeclarativeMeta(type): inherits = cls.__mro__[1] inherits = cls._decl_class_registry.get(inherits.__name__, None) mapper_args['inherits'] = inherits - - cls.__mapper__ = mapper(cls, table, properties=our_stuff, **mapper_args) + + if hasattr(cls, '__mapper_cls__'): + mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__) + else: + mapper_cls = mapper + cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff, **mapper_args) return type.__init__(cls, classname, bases, dict_) def __setattr__(cls, key, value): @@ -294,13 +299,15 @@ def comparable_using(comparator_factory): return comparable_property(comparator_factory, fn) return decorate -def declarative_base(engine=None, metadata=None): +def declarative_base(engine=None, metadata=None, mapper=None): lcl_metadata = metadata or MetaData() + if engine: + lcl_metadata.bind = engine class Base(object): __metaclass__ = DeclarativeMeta metadata = lcl_metadata - if engine: - metadata.bind = engine + if mapper: + __mapper_cls__ = mapper _decl_class_registry = {} def __init__(self, **kwargs): for k in kwargs: diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 90332fdc0..8451d28b5 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -279,6 +279,14 @@ def get_func_kwargs(func): """Return the full set of legal kwargs for the given `func`.""" return inspect.getargspec(func)[0] +def unbound_method_to_callable(func_or_cls): + """Adjust the incoming callable such that a 'self' argument is not required.""" + + if isinstance(func_or_cls, types.MethodType) and not func_or_cls.im_self: + return func_or_cls.im_func + else: + return func_or_cls + # from paste.deploy.converters def asbool(obj): if isinstance(obj, (str, unicode)): diff --git a/test/ext/declarative.py b/test/ext/declarative.py index 5da2dded5..c2f49138c 100644 --- a/test/ext/declarative.py +++ b/test/ext/declarative.py @@ -2,6 +2,7 @@ import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * +from sqlalchemy.orm.interfaces import MapperExtension from sqlalchemy.ext.declarative import declarative_base, declared_synonym, \ synonym_for, comparable_using from sqlalchemy import exceptions @@ -135,6 +136,41 @@ class DeclarativeTest(TestBase, AssertsExecutionResults): self.assertEquals(a1, Address(email='two')) self.assertEquals(a1.user, User(name='u1')) + + def test_custom_mapper(self): + class MyExt(MapperExtension): + def create_instance(self): + return "CHECK" + + def mymapper(cls, tbl, **kwargs): + kwargs['extension'] = MyExt() + return mapper(cls, tbl, **kwargs) + + from sqlalchemy.orm.mapper import Mapper + class MyMapper(Mapper): + def __init__(self, *args, **kwargs): + kwargs['extension'] = MyExt() + Mapper.__init__(self, *args, **kwargs) + + from sqlalchemy.orm import scoping + ss = scoping.ScopedSession(create_session) + ss.extension = MyExt() + ss_mapper = ss.mapper + + for mapperfunc in (mymapper, MyMapper, ss_mapper): + base = declarative_base() + class Foo(base): + __tablename__ = 'foo' + __mapper_cls__ = mapperfunc + id = Column(Integer, primary_key=True) + assert Foo.__mapper__.compile().extension.create_instance() == 'CHECK' + + base = declarative_base(mapper=mapperfunc) + class Foo(base): + __tablename__ = 'foo' + id = Column(Integer, primary_key=True) + assert Foo.__mapper__.compile().extension.create_instance() == 'CHECK' + @testing.emits_warning('Ignoring declarative-like tuple value of ' 'attribute id') |