summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Daniele Varrazzo <daniele.varrazzo@gmail.com> 2017-02-02 17:29:17 +0000
committer: Daniele Varrazzo <daniele.varrazzo@gmail.com> 2017-02-02 17:29:17 +0000
commit: dc1b4fff9001964c719e3f4471cc5a6fe6533e3a (patch)
tree: f3ce470b63ba65b21c963d8b6d47c87aa5b22cf4
parent: d2fdc5ca9f6d5ac76ee39fc6b7db626345a6c84c (diff)
download: psycopg2-dc1b4fff9001964c719e3f4471cc5a6fe6533e3a.tar.gz
Avoid a useless encode/decode roundtrip in execute_values()
Tests moved into a separate module.
-rw-r--r-- lib/extras.py 48
-rwxr-xr-x tests/__init__.py 2
-rwxr-xr-x tests/test_fast_executemany.py 237
-rwxr-xr-x tests/test_types_extras.py 178
4 files changed, 283 insertions, 182 deletions
diff --git a/lib/extras.py b/lib/extras.py
index 1aad3d1..80034e6 100644
--- a/lib/extras.py
+++ b/lib/extras.py
@@ -1232,10 +1232,50 @@ def execute_values(cur, sql, argslist, template=None, page_size=100):
[(1, 20, 3), (4, 50, 6), (7, 8, 9)])
'''
+ # we can't just use sql % vals because vals is bytes: if sql is bytes
+ # there will be some decoding error because of stupid codec used, and Py3
+ # doesn't implement % on bytes.
+ if not isinstance(sql, bytes):
+ sql = sql.encode(_ext.encodings[cur.connection.encoding])
+ pre, post = _split_sql(sql)
+
for page in _paginate(argslist, page_size=page_size):
if template is None:
template = '(%s)' % ','.join(['%s'] * len(page[0]))
- values = b",".join(cur.mogrify(template, args) for args in page)
- if isinstance(values, bytes):
- values = values.decode(_ext.encodings[cur.connection.encoding])
- cur.execute(sql % (values,))
+ parts = [pre]
+ for args in page:
+ parts.append(cur.mogrify(template, args))
+ parts.append(b',')
+ parts[-1] = post
+ cur.execute(b''.join(parts))
+
+
+def _split_sql(sql):
+ """Split *sql* on a single ``%s`` placeholder.
+
+ Return a (pre, post) pair around the ``%s``, with ``%%`` -> ``%`` replacement.
+ """
+ curr = pre = []
+ post = []
+ tokens = _re.split(br'(%.)', sql)
+ for token in tokens:
+ if len(token) != 2 or token[:1] != b'%':
+ curr.append(token)
+ continue
+
+ if token[1:] == b's':
+ if curr is pre:
+ curr = post
+ else:
+ raise ValueError(
+ "the query contains more than one '%s' placeholder")
+ elif token[1:] == b'%':
+ curr.append(b'%')
+ else:
+ raise ValueError("unsupported format character: '%s'"
+ % token[1:].decode('ascii', 'replace'))
+
+ if curr is pre:
+ raise ValueError("the query doesn't contain any '%s' placeholder")
+
+ return b''.join(pre), b''.join(post)
diff --git a/tests/__init__.py b/tests/__init__.py
index 1a24099..35837e8 100755
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -37,6 +37,7 @@ import test_cursor
import test_dates
import test_errcodes
import test_extras_dictcursor
+import test_fast_executemany
import test_green
import test_ipaddress
import test_lobject
@@ -74,6 +75,7 @@ def test_suite():
suite.addTest(test_dates.test_suite())
suite.addTest(test_errcodes.test_suite())
suite.addTest(test_extras_dictcursor.test_suite())
+ suite.addTest(test_fast_executemany.test_suite())
suite.addTest(test_green.test_suite())
suite.addTest(test_ipaddress.test_suite())
suite.addTest(test_lobject.test_suite())
diff --git a/tests/test_fast_executemany.py b/tests/test_fast_executemany.py
new file mode 100755
index 0000000..9222274
--- /dev/null
+++ b/tests/test_fast_executemany.py
@@ -0,0 +1,237 @@
+#!/usr/bin/env python
+#
+# test_fast_executemany.py - tests for fast executemany implementations
+#
+# Copyright (C) 2017 Daniele Varrazzo <daniele.varrazzo@gmail.com>
+#
+# psycopg2 is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Lesser General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
+# License for more details.
+
+import unittest
+from datetime import date
+
+from testutils import ConnectingTestCase
+
+import psycopg2
+import psycopg2.extras
+import psycopg2.extensions as ext
+
+
+class TestPaginate(unittest.TestCase):
+ def test_paginate(self):
+ def pag(seq):
+ return psycopg2.extras._paginate(seq, 100)
+
+ self.assertEqual(list(pag([])), [])
+ self.assertEqual(list(pag([1])), [[1]])
+ self.assertEqual(list(pag(range(99))), [list(range(99))])
+ self.assertEqual(list(pag(range(100))), [list(range(100))])
+ self.assertEqual(list(pag(range(101))), [list(range(100)), [100]])
+ self.assertEqual(
+ list(pag(range(200))), [list(range(100)), list(range(100, 200))])
+ self.assertEqual(
+ list(pag(range(1000))),
+ [list(range(i * 100, (i + 1) * 100)) for i in range(10)])
+
+
+class FastExecuteTestMixin(object):
+ def setUp(self):
+ super(FastExecuteTestMixin, self).setUp()
+ cur = self.conn.cursor()
+ cur.execute("""create table testfast (
+ id serial primary key, date date, val int, data text)""")
+
+
+class TestExecuteBatch(FastExecuteTestMixin, ConnectingTestCase):
+ def test_empty(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_batch(cur,
+ "insert into testfast (id, val) values (%s, %s)",
+ [])
+ cur.execute("select * from testfast order by id")
+ self.assertEqual(cur.fetchall(), [])
+
+ def test_one(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_batch(cur,
+ "insert into testfast (id, val) values (%s, %s)",
+ iter([(1, 10)]))
+ cur.execute("select id, val from testfast order by id")
+ self.assertEqual(cur.fetchall(), [(1, 10)])
+
+ def test_tuples(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_batch(cur,
+ "insert into testfast (id, date, val) values (%s, %s, %s)",
+ ((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
+ cur.execute("select id, date, val from testfast order by id")
+ self.assertEqual(cur.fetchall(),
+ [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
+
+ def test_many(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_batch(cur,
+ "insert into testfast (id, val) values (%s, %s)",
+ ((i, i * 10) for i in range(1000)))
+ cur.execute("select id, val from testfast order by id")
+ self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
+
+ def test_pages(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_batch(cur,
+ "insert into testfast (id, val) values (%s, %s)",
+ ((i, i * 10) for i in range(25)),
+ page_size=10)
+
+ # last command was 5 statements
+ self.assertEqual(sum(c == u';' for c in cur.query.decode('ascii')), 4)
+
+ cur.execute("select id, val from testfast order by id")
+ self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])
+
+ def test_unicode(self):
+ cur = self.conn.cursor()
+ ext.register_type(ext.UNICODE, cur)
+ snowman = u"\u2603"
+
+ # unicode in statement
+ psycopg2.extras.execute_batch(cur,
+ "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman,
+ [(1, 'x')])
+ cur.execute("select id, data from testfast where id = 1")
+ self.assertEqual(cur.fetchone(), (1, 'x'))
+
+ # unicode in data
+ psycopg2.extras.execute_batch(cur,
+ "insert into testfast (id, data) values (%s, %s)",
+ [(2, snowman)])
+ cur.execute("select id, data from testfast where id = 2")
+ self.assertEqual(cur.fetchone(), (2, snowman))
+
+ # unicode in both
+ psycopg2.extras.execute_batch(cur,
+ "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman,
+ [(3, snowman)])
+ cur.execute("select id, data from testfast where id = 3")
+ self.assertEqual(cur.fetchone(), (3, snowman))
+
+
+class TestExecuteValuse(FastExecuteTestMixin, ConnectingTestCase):
+ def test_empty(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, val) values %s",
+ [])
+ cur.execute("select * from testfast order by id")
+ self.assertEqual(cur.fetchall(), [])
+
+ def test_one(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, val) values %s",
+ iter([(1, 10)]))
+ cur.execute("select id, val from testfast order by id")
+ self.assertEqual(cur.fetchall(), [(1, 10)])
+
+ def test_tuples(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, date, val) values %s",
+ ((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
+ cur.execute("select id, date, val from testfast order by id")
+ self.assertEqual(cur.fetchall(),
+ [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
+
+ def test_dicts(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, date, val) values %s",
+ (dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar")
+ for i in range(10)),
+ template='(%(id)s, %(date)s, %(val)s)')
+ cur.execute("select id, date, val from testfast order by id")
+ self.assertEqual(cur.fetchall(),
+ [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
+
+ def test_many(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, val) values %s",
+ ((i, i * 10) for i in range(1000)))
+ cur.execute("select id, val from testfast order by id")
+ self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
+
+ def test_pages(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, val) values %s",
+ ((i, i * 10) for i in range(25)),
+ page_size=10)
+
+ # last statement was 5 tuples (one parens is for the fields list)
+ self.assertEqual(sum(c == '(' for c in cur.query.decode('ascii')), 6)
+
+ cur.execute("select id, val from testfast order by id")
+ self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])
+
+ def test_unicode(self):
+ cur = self.conn.cursor()
+ ext.register_type(ext.UNICODE, cur)
+ snowman = u"\u2603"
+
+ # unicode in statement
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, data) values %%s -- %s" % snowman,
+ [(1, 'x')])
+ cur.execute("select id, data from testfast where id = 1")
+ self.assertEqual(cur.fetchone(), (1, 'x'))
+
+ # unicode in data
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, data) values %s",
+ [(2, snowman)])
+ cur.execute("select id, data from testfast where id = 2")
+ self.assertEqual(cur.fetchone(), (2, snowman))
+
+ # unicode in both
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, data) values %%s -- %s" % snowman,
+ [(3, snowman)])
+ cur.execute("select id, data from testfast where id = 3")
+ self.assertEqual(cur.fetchone(), (3, snowman))
+
+ def test_invalid_sql(self):
+ cur = self.conn.cursor()
+ self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
+ "insert", [])
+ self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
+ "insert %s and %s", [])
+ self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
+ "insert %f", [])
+ self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
+ "insert %f %s", [])
+
+ def test_percent_escape(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, data) values %s -- a%%b",
+ [(1, 'hi')])
+ self.assert_(b'a%%b' not in cur.query)
+ self.assert_(b'a%b' in cur.query)
+
+ cur.execute("select id, data from testfast")
+ self.assertEqual(cur.fetchall(), [(1, 'hi')])
+
+
+def test_suite():
+ return unittest.TestLoader().loadTestsFromName(__name__)
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_types_extras.py b/tests/test_types_extras.py
index a584c86..8e61561 100755
--- a/tests/test_types_extras.py
+++ b/tests/test_types_extras.py
@@ -1766,184 +1766,6 @@ class RangeCasterTestCase(ConnectingTestCase):
decorate_all_tests(RangeCasterTestCase, skip_if_no_range)
-class TestFastExecute(ConnectingTestCase):
- def setUp(self):
- super(TestFastExecute, self).setUp()
- cur = self.conn.cursor()
- cur.execute("""create table testfast (
- id serial primary key, date date, val int, data text)""")
-
- def test_paginate(self):
- def pag(seq):
- return psycopg2.extras._paginate(seq, 100)
-
- self.assertEqual(list(pag([])), [])
- self.assertEqual(list(pag([1])), [[1]])
- self.assertEqual(list(pag(range(99))), [list(range(99))])
- self.assertEqual(list(pag(range(100))), [list(range(100))])
- self.assertEqual(list(pag(range(101))), [list(range(100)), [100]])
- self.assertEqual(
- list(pag(range(200))), [list(range(100)), list(range(100, 200))])
- self.assertEqual(
- list(pag(range(1000))),
- [list(range(i * 100, (i + 1) * 100)) for i in range(10)])
-
- def test_execute_batch_empty(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_batch(cur,
- "insert into testfast (id, val) values (%s, %s)",
- [])
- cur.execute("select * from testfast order by id")
- self.assertEqual(cur.fetchall(), [])
-
- def test_execute_batch_one(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_batch(cur,
- "insert into testfast (id, val) values (%s, %s)",
- iter([(1, 10)]))
- cur.execute("select id, val from testfast order by id")
- self.assertEqual(cur.fetchall(), [(1, 10)])
-
- def test_execute_batch_tuples(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_batch(cur,
- "insert into testfast (id, date, val) values (%s, %s, %s)",
- ((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
- cur.execute("select id, date, val from testfast order by id")
- self.assertEqual(cur.fetchall(),
- [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
-
- def test_execute_batch_many(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_batch(cur,
- "insert into testfast (id, val) values (%s, %s)",
- ((i, i * 10) for i in range(1000)))
- cur.execute("select id, val from testfast order by id")
- self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
-
- def test_execute_batch_pages(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_batch(cur,
- "insert into testfast (id, val) values (%s, %s)",
- ((i, i * 10) for i in range(25)),
- page_size=10)
-
- # last command was 5 statements
- self.assertEqual(sum(c == u';' for c in cur.query.decode('ascii')), 4)
-
- cur.execute("select id, val from testfast order by id")
- self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])
-
- def test_execute_batch_unicode(self):
- cur = self.conn.cursor()
- ext.register_type(ext.UNICODE, cur)
- snowman = u"\u2603"
-
- # unicode in statement
- psycopg2.extras.execute_batch(cur,
- "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman,
- [(1, 'x')])
- cur.execute("select id, data from testfast where id = 1")
- self.assertEqual(cur.fetchone(), (1, 'x'))
-
- # unicode in data
- psycopg2.extras.execute_batch(cur,
- "insert into testfast (id, data) values (%s, %s)",
- [(2, snowman)])
- cur.execute("select id, data from testfast where id = 2")
- self.assertEqual(cur.fetchone(), (2, snowman))
-
- # unicode in both
- psycopg2.extras.execute_batch(cur,
- "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman,
- [(3, snowman)])
- cur.execute("select id, data from testfast where id = 3")
- self.assertEqual(cur.fetchone(), (3, snowman))
-
- def test_execute_values_empty(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_values(cur,
- "insert into testfast (id, val) values %s",
- [])
- cur.execute("select * from testfast order by id")
- self.assertEqual(cur.fetchall(), [])
-
- def test_execute_values_one(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_values(cur,
- "insert into testfast (id, val) values %s",
- iter([(1, 10)]))
- cur.execute("select id, val from testfast order by id")
- self.assertEqual(cur.fetchall(), [(1, 10)])
-
- def test_execute_values_tuples(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_values(cur,
- "insert into testfast (id, date, val) values %s",
- ((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
- cur.execute("select id, date, val from testfast order by id")
- self.assertEqual(cur.fetchall(),
- [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
-
- def test_execute_values_dicts(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_values(cur,
- "insert into testfast (id, date, val) values %s",
- (dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar")
- for i in range(10)),
- template='(%(id)s, %(date)s, %(val)s)')
- cur.execute("select id, date, val from testfast order by id")
- self.assertEqual(cur.fetchall(),
- [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
-
- def test_execute_values_many(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_values(cur,
- "insert into testfast (id, val) values %s",
- ((i, i * 10) for i in range(1000)))
- cur.execute("select id, val from testfast order by id")
- self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
-
- def test_execute_values_pages(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_values(cur,
- "insert into testfast (id, val) values %s",
- ((i, i * 10) for i in range(25)),
- page_size=10)
-
- # last statement was 5 tuples (one parens is for the fields list)
- self.assertEqual(sum(c == '(' for c in cur.query.decode('ascii')), 6)
-
- cur.execute("select id, val from testfast order by id")
- self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])
-
- def test_execute_values_unicode(self):
- cur = self.conn.cursor()
- ext.register_type(ext.UNICODE, cur)
- snowman = u"\u2603"
-
- # unicode in statement
- psycopg2.extras.execute_values(cur,
- "insert into testfast (id, data) values %%s -- %s" % snowman,
- [(1, 'x')])
- cur.execute("select id, data from testfast where id = 1")
- self.assertEqual(cur.fetchone(), (1, 'x'))
-
- # unicode in data
- psycopg2.extras.execute_values(cur,
- "insert into testfast (id, data) values %s",
- [(2, snowman)])
- cur.execute("select id, data from testfast where id = 2")
- self.assertEqual(cur.fetchone(), (2, snowman))
-
- # unicode in both
- psycopg2.extras.execute_values(cur,
- "insert into testfast (id, data) values %%s -- %s" % snowman,
- [(3, snowman)])
- cur.execute("select id, data from testfast where id = 3")
- self.assertEqual(cur.fetchone(), (3, snowman))
-
-
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)