summaryrefslogtreecommitdiff
path: root/simplejson
diff options
context:
space:
mode:
authorBob Ippolito <bob@redivi.com>2017-11-20 10:44:48 -0800
committerGitHub <noreply@github.com>2017-11-20 10:44:48 -0800
commit529268f5098810d6d7f589b63d0e7c3ad47102d7 (patch)
tree93dd3ac580099322234b23460aedeb343d15cecb /simplejson
parent138a2ff3f6db37468f3c7509f69c984792fd49a0 (diff)
parenteb9665a751e8aea21aeedea02ae6617ff53eeaad (diff)
downloadsimplejson-529268f5098810d6d7f589b63d0e7c3ad47102d7.tar.gz
Merge pull request #188 from simplejson/bpo-31505
bpo-31505: Fix an assertion failure in json, in case _json.make_encoder() received a bad encoder() argument.
Diffstat (limited to 'simplejson')
-rw-r--r--simplejson/_speedups.c21
-rw-r--r--simplejson/tests/test_speedups.py24
2 files changed, 42 insertions, 3 deletions
diff --git a/simplejson/_speedups.c b/simplejson/_speedups.c
index bfd053a..f463c04 100644
--- a/simplejson/_speedups.c
+++ b/simplejson/_speedups.c
@@ -2811,10 +2811,25 @@ static PyObject *
encoder_encode_string(PyEncoderObject *s, PyObject *obj)
{
/* Return the JSON representation of a string */
- if (s->fast_encode)
+ PyObject *encoded;
+
+ if (s->fast_encode) {
return py_encode_basestring_ascii(NULL, obj);
- else
- return PyObject_CallFunctionObjArgs(s->encoder, obj, NULL);
+ }
+ encoded = PyObject_CallFunctionObjArgs(s->encoder, obj, NULL);
+ if (encoded != NULL &&
+#if PY_MAJOR_VERSION < 3
+ !JSON_ASCII_Check(unicode) &&
+#endif /* PY_MAJOR_VERSION < 3 */
+ !PyUnicode_Check(encoded))
+ {
+ PyErr_Format(PyExc_TypeError,
+ "encoder() must return a string, not %.80s",
+ Py_TYPE(encoded)->tp_name);
+ Py_DECREF(encoded);
+ return NULL;
+ }
+ return encoded;
}
static int
diff --git a/simplejson/tests/test_speedups.py b/simplejson/tests/test_speedups.py
index b59eeca..8b146df 100644
--- a/simplejson/tests/test_speedups.py
+++ b/simplejson/tests/test_speedups.py
@@ -60,6 +60,30 @@ class TestEncode(TestCase):
)
@skip_if_speedups_missing
+ def test_bad_str_encoder(self):
+ # Issue #31505: There shouldn't be an assertion failure in case
+ # c_make_encoder() receives a bad encoder() argument.
+ import decimal
+ def bad_encoder1(*args):
+ return None
+ enc = encoder.c_make_encoder(
+ None, lambda obj: str(obj),
+ bad_encoder1, None, ': ', ', ',
+ False, False, False, {}, False, False, False,
+ None, None, 'utf-8', False, False, decimal.Decimal, False)
+ self.assertRaises(TypeError, enc, 'spam', 4)
+ self.assertRaises(TypeError, enc, {'spam': 42}, 4)
+
+ def bad_encoder2(*args):
+ 1/0
+ enc = encoder.c_make_encoder(
+ None, lambda obj: str(obj),
+ bad_encoder2, None, ': ', ', ',
+ False, False, False, {}, False, False, False,
+ None, None, 'utf-8', False, False, decimal.Decimal, False)
+ self.assertRaises(ZeroDivisionError, enc, 'spam', 4)
+
+ @skip_if_speedups_missing
def test_bad_bool_args(self):
def test(name):
encoder.JSONEncoder(**{name: BadBool()}).encode({})