summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--NEWS2
-rw-r--r--psycopg/cursor_type.c14
-rwxr-xr-xtests/test_copy.py22
3 files changed, 38 insertions, 0 deletions
diff --git a/NEWS b/NEWS
index 3b4a11c..d97c23d 100644
--- a/NEWS
+++ b/NEWS
@@ -8,6 +8,8 @@ What's new in psycopg 2.8.6
(:ticket:`#1101`).
- Fixed search of mxDateTime headers in virtualenvs (:ticket:`#996`).
- Added missing values from errorcodes (:ticket:`#1133`).
+- `cursor.query` reports the query of the last :sql:`COPY` opearation too
+ (:ticket:`#1141`).
- `~psycopg2.errorcodes` map and `~psycopg2.errors` classes updated to
PostgreSQL 13.
- Wheel package compiled against OpenSSL 1.1.1g.
diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c
index f2dd379..c290c71 100644
--- a/psycopg/cursor_type.c
+++ b/psycopg/cursor_type.c
@@ -1446,6 +1446,11 @@ curs_copy_from(cursorObject *self, PyObject *args, PyObject *kwargs)
Dprintf("curs_copy_from: query = %s", query);
+ Py_CLEAR(self->query);
+ if (!(self->query = Bytes_FromString(query))) {
+ goto exit;
+ }
+
/* This routine stores a borrowed reference. Although it is only held
* for the duration of curs_copy_from, nested invocations of
* Py_BEGIN_ALLOW_THREADS could surrender control to another thread,
@@ -1538,6 +1543,11 @@ curs_copy_to(cursorObject *self, PyObject *args, PyObject *kwargs)
Dprintf("curs_copy_to: query = %s", query);
+ Py_CLEAR(self->query);
+ if (!(self->query = Bytes_FromString(query))) {
+ goto exit;
+ }
+
self->copysize = 0;
Py_INCREF(file);
self->copyfile = file;
@@ -1615,6 +1625,10 @@ curs_copy_expert(cursorObject *self, PyObject *args, PyObject *kwargs)
Py_INCREF(file);
self->copyfile = file;
+ Py_CLEAR(self->query);
+ Py_INCREF(sql);
+ self->query = sql;
+
/* At this point, the SQL statement must be str, not unicode */
if (pq_execute(self, Bytes_AS_STRING(sql), 0, 0, 0) >= 0) {
res = Py_None;
diff --git a/tests/test_copy.py b/tests/test_copy.py
index 05bef21..9274f1d 100755
--- a/tests/test_copy.py
+++ b/tests/test_copy.py
@@ -307,6 +307,28 @@ class CopyTests(ConnectingTestCase):
curs.copy_from, StringIO('aaa\nbbb\nccc\n'), 'tcopy')
self.assertEqual(curs.rowcount, -1)
+ def test_copy_query(self):
+ curs = self.conn.cursor()
+
+ curs.copy_from(StringIO('aaa\nbbb\nccc\n'), 'tcopy', columns=['data'])
+ self.assert_(b"copy " in curs.query.lower())
+ self.assert_(b" from stdin" in curs.query.lower())
+
+ curs.copy_expert(
+ "copy tcopy (data) from stdin",
+ StringIO('ddd\neee\n'))
+ self.assert_(b"copy " in curs.query.lower())
+ self.assert_(b" from stdin" in curs.query.lower())
+
+ curs.copy_to(StringIO(), "tcopy")
+ self.assert_(b"copy " in curs.query.lower())
+ self.assert_(b" to stdout" in curs.query.lower())
+
+ curs.execute("insert into tcopy (data) values ('fff')")
+ curs.copy_expert("copy tcopy to stdout", StringIO())
+ self.assert_(b"copy " in curs.query.lower())
+ self.assert_(b" to stdout" in curs.query.lower())
+
@slow
def test_copy_from_segfault(self):
# issue #219