diff options
-rw-r--r-- | lib/sqlalchemy/ext/mutable.py | 6 | ||||
-rw-r--r-- | test/ext/test_mutable.py | 53 |
2 files changed, 56 insertions, 3 deletions
diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index 7469bcbda..1a4568f23 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -627,10 +627,10 @@ class MutableDict(Mutable, dict): @classmethod def coerce(cls, key, value): - """Convert plain dictionary to MutableDict.""" - if not isinstance(value, MutableDict): + """Convert plain dictionary to instance of this class.""" + if not isinstance(value, cls): if isinstance(value, dict): - return MutableDict(value) + return cls(value) return Mutable.coerce(key, value) else: return value diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py index dc0b5ba1c..305eb8c3a 100644 --- a/test/ext/test_mutable.py +++ b/test/ext/test_mutable.py @@ -332,6 +332,59 @@ class MutableAssociationScalarJSONTest(_MutableDictTestBase, fixtures.MappedTest ) +class CustomMutableAssociationScalarJSONTest(_MutableDictTestBase, fixtures.MappedTest): + + CustomMutableDict = None + + @classmethod + def _type_fixture(cls): + if not(getattr(cls, 'CustomMutableDict')): + MutableDict = super(CustomMutableAssociationScalarJSONTest, cls)._type_fixture() + class CustomMutableDict(MutableDict): + pass + cls.CustomMutableDict = CustomMutableDict + return cls.CustomMutableDict + + @classmethod + def define_tables(cls, metadata): + import json + + class JSONEncodedDict(TypeDecorator): + impl = VARCHAR(50) + + def process_bind_param(self, value, dialect): + if value is not None: + value = json.dumps(value) + + return value + + def process_result_value(self, value, dialect): + if value is not None: + value = json.loads(value) + return value + + CustomMutableDict = cls._type_fixture() + CustomMutableDict.associate_with(JSONEncodedDict) + + Table('foo', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('data', JSONEncodedDict), + Column('unrelated_data', String(50)) + ) + + def test_pickle_parent(self): + # Picklers don't know how to pickle CustomMutableDict, but we aren't testing that here + pass + + def test_coerce(self): + sess = Session() + f1 = Foo(data={'a': 'b'}) + sess.add(f1) + sess.flush() + eq_(type(f1.data), self._type_fixture()) + + class _CompositeTestBase(object): @classmethod def define_tables(cls, metadata): |