diff options
author | Daniele Varrazzo <daniele.varrazzo@gmail.com> | 2010-10-15 08:27:07 +0100 |
---|---|---|
committer | Daniele Varrazzo <daniele.varrazzo@gmail.com> | 2010-11-05 09:34:50 +0000 |
commit | 3e658c33b5ee8a75217b9843351e44bf1856e574 (patch) | |
tree | be67b5cfe1e4b41fd1e0280d59b29afb30d2b22f | |
parent | 4f3976681ac9659efc457396f1a079c6e5b4bcc1 (diff) | |
download | psycopg2-3e658c33b5ee8a75217b9843351e44bf1856e574.tar.gz |
Ensure unicode is accepted as type for transaction ids.
We don't do somersaults to ensure people can use snowmen as transaction
ids anyway: it would require passing the connection to xid_ensure and
down below to use the correct encoding.
-rw-r--r-- | psycopg/xid_type.c | 2 | ||||
-rw-r--r-- | tests/test_connection.py | 28 |
2 files changed, 29 insertions, 1 deletions
diff --git a/psycopg/xid_type.c b/psycopg/xid_type.c index 2a21edb..626440c 100644 --- a/psycopg/xid_type.c +++ b/psycopg/xid_type.c @@ -590,7 +590,7 @@ XidObject * xid_from_string(PyObject *str) { XidObject *rv; - if (!PyString_Check(str)) { + if (!(PyString_Check(str) || PyUnicode_Check(str))) { PyErr_SetString(PyExc_TypeError, "not a valid transaction id"); return NULL; } diff --git a/tests/test_connection.py b/tests/test_connection.py index 6d35b2c..b83a90f 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -383,6 +383,34 @@ class ConnectionTwoPhaseTests(unittest.TestCase): x2 = Xid.from_string('99_xxx_yyy') self.assertEqual(str(x2), '99_xxx_yyy') + def test_xid_unicode(self): + cnn = self.connect() + x1 = cnn.xid(10, u'uni', u'code') + cnn.tpc_begin(x1) + cnn.tpc_prepare() + cnn.reset() + xid = [ xid for xid in cnn.tpc_recover() + if xid.database == tests.dbname ][0] + self.assertEqual(10, xid.format_id) + self.assertEqual('uni', xid.gtrid) + self.assertEqual('code', xid.bqual) + + def test_xid_unicode_unparsed(self): + # We don't expect people shooting snowmen as transaction ids, + # so if something explodes in an encode error I don't mind. + # Let's just check uniconde is accepted as type. + cnn = self.connect() + cnn.set_client_encoding('utf8') + cnn.tpc_begin(u"transaction-id") + cnn.tpc_prepare() + cnn.reset() + + xid = [ xid for xid in cnn.tpc_recover() + if xid.database == tests.dbname ][0] + self.assertEqual(None, xid.format_id) + self.assertEqual('transaction-id', xid.gtrid) + self.assertEqual(None, xid.bqual) + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__) |