summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/ext/mutable.py6
-rw-r--r--test/ext/test_mutable.py53
2 files changed, 56 insertions, 3 deletions
diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py
index 7469bcbda..1a4568f23 100644
--- a/lib/sqlalchemy/ext/mutable.py
+++ b/lib/sqlalchemy/ext/mutable.py
@@ -627,10 +627,10 @@ class MutableDict(Mutable, dict):
@classmethod
def coerce(cls, key, value):
- """Convert plain dictionary to MutableDict."""
- if not isinstance(value, MutableDict):
+ """Convert plain dictionary to instance of this class."""
+ if not isinstance(value, cls):
if isinstance(value, dict):
- return MutableDict(value)
+ return cls(value)
return Mutable.coerce(key, value)
else:
return value
diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py
index dc0b5ba1c..305eb8c3a 100644
--- a/test/ext/test_mutable.py
+++ b/test/ext/test_mutable.py
@@ -332,6 +332,59 @@ class MutableAssociationScalarJSONTest(_MutableDictTestBase, fixtures.MappedTest
)
+class CustomMutableAssociationScalarJSONTest(_MutableDictTestBase, fixtures.MappedTest):
+
+ CustomMutableDict = None
+
+ @classmethod
+ def _type_fixture(cls):
+ if not(getattr(cls, 'CustomMutableDict')):
+ MutableDict = super(CustomMutableAssociationScalarJSONTest, cls)._type_fixture()
+ class CustomMutableDict(MutableDict):
+ pass
+ cls.CustomMutableDict = CustomMutableDict
+ return cls.CustomMutableDict
+
+ @classmethod
+ def define_tables(cls, metadata):
+ import json
+
+ class JSONEncodedDict(TypeDecorator):
+ impl = VARCHAR(50)
+
+ def process_bind_param(self, value, dialect):
+ if value is not None:
+ value = json.dumps(value)
+
+ return value
+
+ def process_result_value(self, value, dialect):
+ if value is not None:
+ value = json.loads(value)
+ return value
+
+ CustomMutableDict = cls._type_fixture()
+ CustomMutableDict.associate_with(JSONEncodedDict)
+
+ Table('foo', metadata,
+ Column('id', Integer, primary_key=True,
+ test_needs_autoincrement=True),
+ Column('data', JSONEncodedDict),
+ Column('unrelated_data', String(50))
+ )
+
+ def test_pickle_parent(self):
+ # Picklers don't know how to pickle CustomMutableDict, but we aren't testing that here
+ pass
+
+ def test_coerce(self):
+ sess = Session()
+ f1 = Foo(data={'a': 'b'})
+ sess.add(f1)
+ sess.flush()
+ eq_(type(f1.data), self._type_fixture())
+
+
class _CompositeTestBase(object):
@classmethod
def define_tables(cls, metadata):