summary refs log tree commit diff
diff options
context:
space:
mode:
author Daniele Varrazzo <daniele.varrazzo@gmail.com> 2017-02-03 04:40:34 +0000
committer Daniele Varrazzo <daniele.varrazzo@gmail.com> 2017-02-03 04:40:34 +0000
commit 6e89db020ca9fd93260bfff7aadef64dd4535553 (patch)
tree d66779d81133fc496e787271370f44546835a1db
parent d8b1fbd9052946e585967a3c711702a7417ea0e4 (diff)
parent 95226baa9b3e9ccff50de053a93b2ea53bb1e25c (diff)
download psycopg2-6e89db020ca9fd93260bfff7aadef64dd4535553.tar.gz
Merge branch 'fast-executemany'
-rw-r--r-- lib/extras.py 97
-rwxr-xr-x tests/__init__.py 2
-rwxr-xr-x tests/test_fast_executemany.py 237
-rwxr-xr-x tests/test_types_extras.py 178
4 files changed, 316 insertions, 198 deletions
diff --git a/lib/extras.py b/lib/extras.py
index 2d26402..38ca17a 100644
--- a/lib/extras.py
+++ b/lib/extras.py
@@ -1168,11 +1168,15 @@ def execute_batch(cur, sql, argslist, page_size=100):
Execute *sql* several times, against all parameters set (sequences or
mappings) found in *argslist*.
- The function is semantically similar to `~cursor.executemany()`, but has a
- different implementation: Psycopg will join the statements into fewer
- multi-statement commands, reducing the number of server roundtrips,
- resulting in better performances. Every command contains at most
- *page_size* statements.
+ The function is semantically similar to
+
+ .. parsed-literal::
+
+ *cur*\.\ `~cursor.executemany`\ (\ *sql*\ , *argslist*\ )
+
+ but has a different implementation: Psycopg will join the statements into
+ fewer multi-statement commands, each one containing at most *page_size*
+ statements, resulting in a reduced number of server roundtrips.
"""
for page in _paginate(argslist, page_size=page_size):
@@ -1183,18 +1187,30 @@ def execute_batch(cur, sql, argslist, page_size=100):
def execute_values(cur, sql, argslist, template=None, page_size=100):
'''Execute a statement using :sql:`VALUES` with a sequence of parameters.
- *sql* must contain a single ``%s`` placeholder, which will be replaced by a
- `VALUES list`__. Every statement will contain at most *page_size* sets of
- arguments.
+ :param cur: the cursor to use to execute the query.
- .. __: https://www.postgresql.org/docs/current/static/queries-values.html
+ :param sql: the query to execute. It must contain a single ``%s``
+ placeholder, which will be replaced by a `VALUES list`__.
+ Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``.
+
+ :param argslist: sequence of sequences or dictionaries with the arguments
+ to send to the query. The type and content must be consistent with
+ *template*.
- *template* is the part merged to the arguments, so it should be compatible
- with the content of *argslist* (it should contain the right number of
- arguments if *argslist* is a sequence of sequences, or compatible names if
- *argslist* is a sequence of mappings). If not specified, assume the
- arguments are sequence and use a simple positional template (i.e.
- ``(%s, %s, ...)``).
+ :param template: the snippet to merge to every item in *argslist* to
+ compose the query. If *argslist* items are sequences it should contain
+ positional placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)"``
+ if there are constant values...); if *argslist* items are mappings
+ it should contain named placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``).
+ If not specified, assume the arguments are sequences and use a simple
+ positional template (i.e. ``(%s, %s, ...)``), with the number of
+ placeholders sniffed from the first element in *argslist*.
+
+ :param page_size: maximum number of *argslist* items to include in every
+ statement. If there are more items the function will execute more than
+ one statement.
+
+ .. __: https://www.postgresql.org/docs/current/static/queries-values.html
While :sql:`INSERT` is an obvious candidate for this function it is
possible to use it with other statements, for example::
@@ -1216,10 +1232,51 @@ def execute_values(cur, sql, argslist, template=None, page_size=100):
[(1, 20, 3), (4, 50, 6), (7, 8, 9)])
'''
+ # we can't just use sql % vals because vals is bytes: if sql is bytes
+ # there will be some decoding error because of stupid codec used, and Py3
+ # doesn't implement % on bytes.
+ if not isinstance(sql, bytes):
+ sql = sql.encode(_ext.encodings[cur.connection.encoding])
+ pre, post = _split_sql(sql)
+
for page in _paginate(argslist, page_size=page_size):
if template is None:
- template = '(%s)' % ','.join(['%s'] * len(page[0]))
- values = b",".join(cur.mogrify(template, args) for args in page)
- if isinstance(values, bytes):
- values = values.decode(_ext.encodings[cur.connection.encoding])
- cur.execute(sql % (values,))
+ template = b'(' + b','.join([b'%s'] * len(page[0])) + b')'
+ parts = pre[:]
+ for args in page:
+ parts.append(cur.mogrify(template, args))
+ parts.append(b',')
+ parts[-1:] = post
+ cur.execute(b''.join(parts))
+
+
+def _split_sql(sql):
+ """Split *sql* on a single ``%s`` placeholder.
+
+ Split on the %s, perform %% replacement and return pre, post lists of
+ snippets.
+ """
+ curr = pre = []
+ post = []
+ tokens = _re.split(br'(%.)', sql)
+ for token in tokens:
+ if len(token) != 2 or token[:1] != b'%':
+ curr.append(token)
+ continue
+
+ if token[1:] == b's':
+ if curr is pre:
+ curr = post
+ else:
+ raise ValueError(
+ "the query contains more than one '%s' placeholder")
+ elif token[1:] == b'%':
+ curr.append(b'%')
+ else:
+ raise ValueError("unsupported format character: '%s'"
+ % token[1:].decode('ascii', 'replace'))
+
+ if curr is pre:
+ raise ValueError("the query doesn't contain any '%s' placeholder")
+
+ return pre, post
diff --git a/tests/__init__.py b/tests/__init__.py
index 1a24099..35837e8 100755
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -37,6 +37,7 @@ import test_cursor
import test_dates
import test_errcodes
import test_extras_dictcursor
+import test_fast_executemany
import test_green
import test_ipaddress
import test_lobject
@@ -74,6 +75,7 @@ def test_suite():
suite.addTest(test_dates.test_suite())
suite.addTest(test_errcodes.test_suite())
suite.addTest(test_extras_dictcursor.test_suite())
+ suite.addTest(test_fast_executemany.test_suite())
suite.addTest(test_green.test_suite())
suite.addTest(test_ipaddress.test_suite())
suite.addTest(test_lobject.test_suite())
diff --git a/tests/test_fast_executemany.py b/tests/test_fast_executemany.py
new file mode 100755
index 0000000..9222274
--- /dev/null
+++ b/tests/test_fast_executemany.py
@@ -0,0 +1,237 @@
+#!/usr/bin/env python
+#
+# test_fast_executemany.py - tests for fast executemany implementations
+#
+# Copyright (C) 2017 Daniele Varrazzo <daniele.varrazzo@gmail.com>
+#
+# psycopg2 is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Lesser General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
+# License for more details.
+
+import unittest
+from datetime import date
+
+from testutils import ConnectingTestCase
+
+import psycopg2
+import psycopg2.extras
+import psycopg2.extensions as ext
+
+
+class TestPaginate(unittest.TestCase):
+ def test_paginate(self):
+ def pag(seq):
+ return psycopg2.extras._paginate(seq, 100)
+
+ self.assertEqual(list(pag([])), [])
+ self.assertEqual(list(pag([1])), [[1]])
+ self.assertEqual(list(pag(range(99))), [list(range(99))])
+ self.assertEqual(list(pag(range(100))), [list(range(100))])
+ self.assertEqual(list(pag(range(101))), [list(range(100)), [100]])
+ self.assertEqual(
+ list(pag(range(200))), [list(range(100)), list(range(100, 200))])
+ self.assertEqual(
+ list(pag(range(1000))),
+ [list(range(i * 100, (i + 1) * 100)) for i in range(10)])
+
+
+class FastExecuteTestMixin(object):
+ def setUp(self):
+ super(FastExecuteTestMixin, self).setUp()
+ cur = self.conn.cursor()
+ cur.execute("""create table testfast (
+ id serial primary key, date date, val int, data text)""")
+
+
+class TestExecuteBatch(FastExecuteTestMixin, ConnectingTestCase):
+ def test_empty(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_batch(cur,
+ "insert into testfast (id, val) values (%s, %s)",
+ [])
+ cur.execute("select * from testfast order by id")
+ self.assertEqual(cur.fetchall(), [])
+
+ def test_one(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_batch(cur,
+ "insert into testfast (id, val) values (%s, %s)",
+ iter([(1, 10)]))
+ cur.execute("select id, val from testfast order by id")
+ self.assertEqual(cur.fetchall(), [(1, 10)])
+
+ def test_tuples(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_batch(cur,
+ "insert into testfast (id, date, val) values (%s, %s, %s)",
+ ((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
+ cur.execute("select id, date, val from testfast order by id")
+ self.assertEqual(cur.fetchall(),
+ [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
+
+ def test_many(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_batch(cur,
+ "insert into testfast (id, val) values (%s, %s)",
+ ((i, i * 10) for i in range(1000)))
+ cur.execute("select id, val from testfast order by id")
+ self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
+
+ def test_pages(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_batch(cur,
+ "insert into testfast (id, val) values (%s, %s)",
+ ((i, i * 10) for i in range(25)),
+ page_size=10)
+
+ # last command was 5 statements
+ self.assertEqual(sum(c == u';' for c in cur.query.decode('ascii')), 4)
+
+ cur.execute("select id, val from testfast order by id")
+ self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])
+
+ def test_unicode(self):
+ cur = self.conn.cursor()
+ ext.register_type(ext.UNICODE, cur)
+ snowman = u"\u2603"
+
+ # unicode in statement
+ psycopg2.extras.execute_batch(cur,
+ "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman,
+ [(1, 'x')])
+ cur.execute("select id, data from testfast where id = 1")
+ self.assertEqual(cur.fetchone(), (1, 'x'))
+
+ # unicode in data
+ psycopg2.extras.execute_batch(cur,
+ "insert into testfast (id, data) values (%s, %s)",
+ [(2, snowman)])
+ cur.execute("select id, data from testfast where id = 2")
+ self.assertEqual(cur.fetchone(), (2, snowman))
+
+ # unicode in both
+ psycopg2.extras.execute_batch(cur,
+ "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman,
+ [(3, snowman)])
+ cur.execute("select id, data from testfast where id = 3")
+ self.assertEqual(cur.fetchone(), (3, snowman))
+
+
+class TestExecuteValuse(FastExecuteTestMixin, ConnectingTestCase):
+ def test_empty(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, val) values %s",
+ [])
+ cur.execute("select * from testfast order by id")
+ self.assertEqual(cur.fetchall(), [])
+
+ def test_one(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, val) values %s",
+ iter([(1, 10)]))
+ cur.execute("select id, val from testfast order by id")
+ self.assertEqual(cur.fetchall(), [(1, 10)])
+
+ def test_tuples(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, date, val) values %s",
+ ((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
+ cur.execute("select id, date, val from testfast order by id")
+ self.assertEqual(cur.fetchall(),
+ [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
+
+ def test_dicts(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, date, val) values %s",
+ (dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar")
+ for i in range(10)),
+ template='(%(id)s, %(date)s, %(val)s)')
+ cur.execute("select id, date, val from testfast order by id")
+ self.assertEqual(cur.fetchall(),
+ [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
+
+ def test_many(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, val) values %s",
+ ((i, i * 10) for i in range(1000)))
+ cur.execute("select id, val from testfast order by id")
+ self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
+
+ def test_pages(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, val) values %s",
+ ((i, i * 10) for i in range(25)),
+ page_size=10)
+
+ # last statement was 5 tuples (one parenthesis is for the fields list)
+ self.assertEqual(sum(c == '(' for c in cur.query.decode('ascii')), 6)
+
+ cur.execute("select id, val from testfast order by id")
+ self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])
+
+ def test_unicode(self):
+ cur = self.conn.cursor()
+ ext.register_type(ext.UNICODE, cur)
+ snowman = u"\u2603"
+
+ # unicode in statement
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, data) values %%s -- %s" % snowman,
+ [(1, 'x')])
+ cur.execute("select id, data from testfast where id = 1")
+ self.assertEqual(cur.fetchone(), (1, 'x'))
+
+ # unicode in data
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, data) values %s",
+ [(2, snowman)])
+ cur.execute("select id, data from testfast where id = 2")
+ self.assertEqual(cur.fetchone(), (2, snowman))
+
+ # unicode in both
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, data) values %%s -- %s" % snowman,
+ [(3, snowman)])
+ cur.execute("select id, data from testfast where id = 3")
+ self.assertEqual(cur.fetchone(), (3, snowman))
+
+ def test_invalid_sql(self):
+ cur = self.conn.cursor()
+ self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
+ "insert", [])
+ self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
+ "insert %s and %s", [])
+ self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
+ "insert %f", [])
+ self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
+ "insert %f %s", [])
+
+ def test_percent_escape(self):
+ cur = self.conn.cursor()
+ psycopg2.extras.execute_values(cur,
+ "insert into testfast (id, data) values %s -- a%%b",
+ [(1, 'hi')])
+ self.assert_(b'a%%b' not in cur.query)
+ self.assert_(b'a%b' in cur.query)
+
+ cur.execute("select id, data from testfast")
+ self.assertEqual(cur.fetchall(), [(1, 'hi')])
+
+
+def test_suite():
+ return unittest.TestLoader().loadTestsFromName(__name__)
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_types_extras.py b/tests/test_types_extras.py
index 8fe3bae..f28c5c2 100755
--- a/tests/test_types_extras.py
+++ b/tests/test_types_extras.py
@@ -1766,184 +1766,6 @@ class RangeCasterTestCase(ConnectingTestCase):
decorate_all_tests(RangeCasterTestCase, skip_if_no_range)
-class TestFastExecute(ConnectingTestCase):
- def setUp(self):
- super(TestFastExecute, self).setUp()
- cur = self.conn.cursor()
- cur.execute("""create table testfast (
- id serial primary key, date date, val int, data text)""")
-
- def test_paginate(self):
- def pag(seq):
- return psycopg2.extras._paginate(seq, 100)
-
- self.assertEqual(list(pag([])), [])
- self.assertEqual(list(pag([1])), [[1]])
- self.assertEqual(list(pag(range(99))), [list(range(99))])
- self.assertEqual(list(pag(range(100))), [list(range(100))])
- self.assertEqual(list(pag(range(101))), [list(range(100)), [100]])
- self.assertEqual(
- list(pag(range(200))), [list(range(100)), list(range(100, 200))])
- self.assertEqual(
- list(pag(range(1000))),
- [list(range(i * 100, (i + 1) * 100)) for i in range(10)])
-
- def test_execute_batch_empty(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_batch(cur,
- "insert into testfast (id, val) values (%s, %s)",
- [])
- cur.execute("select * from testfast order by id")
- self.assertEqual(cur.fetchall(), [])
-
- def test_execute_batch_one(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_batch(cur,
- "insert into testfast (id, val) values (%s, %s)",
- iter([(1, 10)]))
- cur.execute("select id, val from testfast order by id")
- self.assertEqual(cur.fetchall(), [(1, 10)])
-
- def test_execute_batch_tuples(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_batch(cur,
- "insert into testfast (id, date, val) values (%s, %s, %s)",
- ((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
- cur.execute("select id, date, val from testfast order by id")
- self.assertEqual(cur.fetchall(),
- [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
-
- def test_execute_batch_many(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_batch(cur,
- "insert into testfast (id, val) values (%s, %s)",
- ((i, i * 10) for i in range(1000)))
- cur.execute("select id, val from testfast order by id")
- self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
-
- def test_execute_batch_pages(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_batch(cur,
- "insert into testfast (id, val) values (%s, %s)",
- ((i, i * 10) for i in range(25)),
- page_size=10)
-
- # last command was 5 statements
- self.assertEqual(sum(c == u';' for c in cur.query.decode('ascii')), 4)
-
- cur.execute("select id, val from testfast order by id")
- self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])
-
- def test_execute_batch_unicode(self):
- cur = self.conn.cursor()
- ext.register_type(ext.UNICODE, cur)
- snowman = u"\u2603"
-
- # unicode in statement
- psycopg2.extras.execute_batch(cur,
- "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman,
- [(1, 'x')])
- cur.execute("select id, data from testfast where id = 1")
- self.assertEqual(cur.fetchone(), (1, 'x'))
-
- # unicode in data
- psycopg2.extras.execute_batch(cur,
- "insert into testfast (id, data) values (%s, %s)",
- [(2, snowman)])
- cur.execute("select id, data from testfast where id = 2")
- self.assertEqual(cur.fetchone(), (2, snowman))
-
- # unicode in both
- psycopg2.extras.execute_batch(cur,
- "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman,
- [(3, snowman)])
- cur.execute("select id, data from testfast where id = 3")
- self.assertEqual(cur.fetchone(), (3, snowman))
-
- def test_execute_values_empty(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_values(cur,
- "insert into testfast (id, val) values %s",
- [])
- cur.execute("select * from testfast order by id")
- self.assertEqual(cur.fetchall(), [])
-
- def test_execute_values_one(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_values(cur,
- "insert into testfast (id, val) values %s",
- iter([(1, 10)]))
- cur.execute("select id, val from testfast order by id")
- self.assertEqual(cur.fetchall(), [(1, 10)])
-
- def test_execute_values_tuples(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_values(cur,
- "insert into testfast (id, date, val) values %s",
- ((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
- cur.execute("select id, date, val from testfast order by id")
- self.assertEqual(cur.fetchall(),
- [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
-
- def test_execute_values_dicts(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_values(cur,
- "insert into testfast (id, date, val) values %s",
- (dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar")
- for i in range(10)),
- template='(%(id)s, %(date)s, %(val)s)')
- cur.execute("select id, date, val from testfast order by id")
- self.assertEqual(cur.fetchall(),
- [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
-
- def test_execute_values_many(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_values(cur,
- "insert into testfast (id, val) values %s",
- ((i, i * 10) for i in range(1000)))
- cur.execute("select id, val from testfast order by id")
- self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
-
- def test_execute_values_pages(self):
- cur = self.conn.cursor()
- psycopg2.extras.execute_values(cur,
- "insert into testfast (id, val) values %s",
- ((i, i * 10) for i in range(25)),
- page_size=10)
-
- # last statement was 5 tuples (one parens is for the fields list)
- self.assertEqual(sum(c == '(' for c in cur.query.decode('ascii')), 6)
-
- cur.execute("select id, val from testfast order by id")
- self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])
-
- def test_execute_values_unicode(self):
- cur = self.conn.cursor()
- ext.register_type(ext.UNICODE, cur)
- snowman = u"\u2603"
-
- # unicode in statement
- psycopg2.extras.execute_values(cur,
- "insert into testfast (id, data) values %%s -- %s" % snowman,
- [(1, 'x')])
- cur.execute("select id, data from testfast where id = 1")
- self.assertEqual(cur.fetchone(), (1, 'x'))
-
- # unicode in data
- psycopg2.extras.execute_values(cur,
- "insert into testfast (id, data) values %s",
- [(2, snowman)])
- cur.execute("select id, data from testfast where id = 2")
- self.assertEqual(cur.fetchone(), (2, snowman))
-
- # unicode in both
- psycopg2.extras.execute_values(cur,
- "insert into testfast (id, data) values %%s -- %s" % snowman,
- [(3, snowman)])
- cur.execute("select id, data from testfast where id = 3")
- self.assertEqual(cur.fetchone(), (3, snowman))
-
-
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)