summaryrefslogtreecommitdiff
path: root/lib/extras.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/extras.py')
-rw-r--r--lib/extras.py97
1 files changed, 77 insertions, 20 deletions
diff --git a/lib/extras.py b/lib/extras.py
index 2d26402..38ca17a 100644
--- a/lib/extras.py
+++ b/lib/extras.py
@@ -1168,11 +1168,15 @@ def execute_batch(cur, sql, argslist, page_size=100):
Execute *sql* several times, against all parameters set (sequences or
mappings) found in *argslist*.
- The function is semantically similar to `~cursor.executemany()`, but has a
- different implementation: Psycopg will join the statements into fewer
- multi-statement commands, reducing the number of server roundtrips,
- resulting in better performances. Every command contains at most
- *page_size* statements.
+ The function is semantically similar to
+
+ .. parsed-literal::
+
+ *cur*\.\ `~cursor.executemany`\ (\ *sql*\ , *argslist*\ )
+
+ but has a different implementation: Psycopg will join the statements into
+ fewer multi-statement commands, each one containing at most *page_size*
+ statements, resulting in a reduced number of server roundtrips.
"""
for page in _paginate(argslist, page_size=page_size):
@@ -1183,18 +1187,30 @@ def execute_batch(cur, sql, argslist, page_size=100):
def execute_values(cur, sql, argslist, template=None, page_size=100):
'''Execute a statement using :sql:`VALUES` with a sequence of parameters.
- *sql* must contain a single ``%s`` placeholder, which will be replaced by a
- `VALUES list`__. Every statement will contain at most *page_size* sets of
- arguments.
+ :param cur: the cursor to use to execute the query.
- .. __: https://www.postgresql.org/docs/current/static/queries-values.html
+ :param sql: the query to execute. It must contain a single ``%s``
+ placeholder, which will be replaced by a `VALUES list`__.
+ Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``.
+
+ :param argslist: sequence of sequences or dictionaries with the arguments
+ to send to the query. The type and content must be consistent with
+ *template*.
- *template* is the part merged to the arguments, so it should be compatible
- with the content of *argslist* (it should contain the right number of
- arguments if *argslist* is a sequence of sequences, or compatible names if
- *argslist* is a sequence of mappings). If not specified, assume the
- arguments are sequence and use a simple positional template (i.e.
- ``(%s, %s, ...)``).
+ :param template: the snippet to merge to every item in *argslist* to
+ compose the query. If *argslist* items are sequences it should contain
+ positional placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``"
+ if there are constants value...); If *argslist* is items are mapping
+ it should contain named placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``).
+ If not specified, assume the arguments are sequence and use a simple
+ positional template (i.e. ``(%s, %s, ...)``), with the number of
+ placeholders sniffed by the first element in *argslist*.
+
+ :param page_size: maximum number of *argslist* items to include in every
+ statement. If there are more items the function will execute more than
+ one statement.
+
+ .. __: https://www.postgresql.org/docs/current/static/queries-values.html
While :sql:`INSERT` is an obvious candidate for this function it is
possible to use it with other statements, for example::
@@ -1216,10 +1232,51 @@ def execute_values(cur, sql, argslist, template=None, page_size=100):
[(1, 20, 3), (4, 50, 6), (7, 8, 9)])
'''
+ # we can't just use sql % vals because vals is bytes: if sql is bytes
+ # there will be some decoding error because of stupid codec used, and Py3
+ # doesn't implement % on bytes.
+ if not isinstance(sql, bytes):
+ sql = sql.encode(_ext.encodings[cur.connection.encoding])
+ pre, post = _split_sql(sql)
+
for page in _paginate(argslist, page_size=page_size):
if template is None:
- template = '(%s)' % ','.join(['%s'] * len(page[0]))
- values = b",".join(cur.mogrify(template, args) for args in page)
- if isinstance(values, bytes):
- values = values.decode(_ext.encodings[cur.connection.encoding])
- cur.execute(sql % (values,))
+ template = b'(' + b','.join([b'%s'] * len(page[0])) + b')'
+ parts = pre[:]
+ for args in page:
+ parts.append(cur.mogrify(template, args))
+ parts.append(b',')
+ parts[-1:] = post
+ cur.execute(b''.join(parts))
+
+
+def _split_sql(sql):
+ """Split *sql* on a single ``%s`` placeholder.
+
+ Split on the %s, perform %% replacement and return pre, post lists of
+ snippets.
+ """
+ curr = pre = []
+ post = []
+ tokens = _re.split(br'(%.)', sql)
+ for token in tokens:
+ if len(token) != 2 or token[:1] != b'%':
+ curr.append(token)
+ continue
+
+ if token[1:] == b's':
+ if curr is pre:
+ curr = post
+ else:
+ raise ValueError(
+ "the query contains more than one '%s' placeholder")
+ elif token[1:] == b'%':
+ curr.append(b'%')
+ else:
+ raise ValueError("unsupported format character: '%s'"
+ % token[1:].decode('ascii', 'replace'))
+
+ if curr is pre:
+ raise ValueError("the query doesn't contain any '%s' placeholder")
+
+ return pre, post