summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>2010-10-15 08:27:07 +0100
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>2010-11-05 09:34:50 +0000
commit3e658c33b5ee8a75217b9843351e44bf1856e574 (patch)
treebe67b5cfe1e4b41fd1e0280d59b29afb30d2b22f
parent4f3976681ac9659efc457396f1a079c6e5b4bcc1 (diff)
downloadpsycopg2-3e658c33b5ee8a75217b9843351e44bf1856e574.tar.gz
Ensure unicode is accepted as type for transaction ids.
We don't do somersaults to ensure people can use snowmen as transaction ids anyway: it would require passing the connection to xid_ensure and down below to use the correct encoding.
-rw-r--r--psycopg/xid_type.c2
-rw-r--r--tests/test_connection.py28
2 files changed, 29 insertions, 1 deletions
diff --git a/psycopg/xid_type.c b/psycopg/xid_type.c
index 2a21edb..626440c 100644
--- a/psycopg/xid_type.c
+++ b/psycopg/xid_type.c
@@ -590,7 +590,7 @@ XidObject *
xid_from_string(PyObject *str) {
XidObject *rv;
- if (!PyString_Check(str)) {
+ if (!(PyString_Check(str) || PyUnicode_Check(str))) {
PyErr_SetString(PyExc_TypeError, "not a valid transaction id");
return NULL;
}
diff --git a/tests/test_connection.py b/tests/test_connection.py
index 6d35b2c..b83a90f 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -383,6 +383,34 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
x2 = Xid.from_string('99_xxx_yyy')
self.assertEqual(str(x2), '99_xxx_yyy')
+ def test_xid_unicode(self):
+ cnn = self.connect()
+ x1 = cnn.xid(10, u'uni', u'code')
+ cnn.tpc_begin(x1)
+ cnn.tpc_prepare()
+ cnn.reset()
+ xid = [ xid for xid in cnn.tpc_recover()
+ if xid.database == tests.dbname ][0]
+ self.assertEqual(10, xid.format_id)
+ self.assertEqual('uni', xid.gtrid)
+ self.assertEqual('code', xid.bqual)
+
+ def test_xid_unicode_unparsed(self):
+ # We don't expect people shooting snowmen as transaction ids,
+ # so if something explodes in an encode error I don't mind.
+ # Let's just check uniconde is accepted as type.
+ cnn = self.connect()
+ cnn.set_client_encoding('utf8')
+ cnn.tpc_begin(u"transaction-id")
+ cnn.tpc_prepare()
+ cnn.reset()
+
+ xid = [ xid for xid in cnn.tpc_recover()
+ if xid.database == tests.dbname ][0]
+ self.assertEqual(None, xid.format_id)
+ self.assertEqual('transaction-id', xid.gtrid)
+ self.assertEqual(None, xid.bqual)
+
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)