summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>2018-12-27 14:53:12 +0100
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>2018-12-27 14:53:12 +0100
commitf3695e36c700d988052a6d3b8cd1dabe8efb9a00 (patch)
treea7b570e4ebafc6fff06ad3ffacaeb9f052759d20 /lib
parent25fc044d131c2afe0a1e58ee26673f055cef66f6 (diff)
parent7c8d2f484e19ddd1290b0c5ee9b3bf003bc7db79 (diff)
downloadpsycopg2-f3695e36c700d988052a6d3b8cd1dabe8efb9a00.tar.gz
Merge remote-tracking branch 'eternalflow/execute-values-returning-clause-support'
Diffstat (limited to 'lib')
-rw-r--r--lib/extras.py11
1 files changed, 10 insertions, 1 deletions
diff --git a/lib/extras.py b/lib/extras.py
index 0764edf..76bea85 100644
--- a/lib/extras.py
+++ b/lib/extras.py
@@ -1198,7 +1198,7 @@ def execute_batch(cur, sql, argslist, page_size=100):
cur.execute(b";".join(sqls))
-def execute_values(cur, sql, argslist, template=None, page_size=100):
+def execute_values(cur, sql, argslist, template=None, page_size=100, fetch=False):
'''Execute a statement using :sql:`VALUES` with a sequence of parameters.
:param cur: the cursor to use to execute the query.
@@ -1229,6 +1229,10 @@ def execute_values(cur, sql, argslist, template=None, page_size=100):
statement. If there are more items the function will execute more than
one statement.
+ :param fetch: if `!True` return the query results into a list (like in a
+ `~cursor.fetchall()`). Useful for queries with :sql:`RETURNING`
+ clause.
+
.. __: https://www.postgresql.org/docs/current/static/queries-values.html
After the execution of the function the `cursor.rowcount` property will
@@ -1265,6 +1269,7 @@ def execute_values(cur, sql, argslist, template=None, page_size=100):
sql = sql.encode(_ext.encodings[cur.connection.encoding])
pre, post = _split_sql(sql)
+ result = [] if fetch else None
for page in _paginate(argslist, page_size=page_size):
if template is None:
template = b'(' + b','.join([b'%s'] * len(page[0])) + b')'
@@ -1274,6 +1279,10 @@ def execute_values(cur, sql, argslist, template=None, page_size=100):
parts.append(b',')
parts[-1:] = post
cur.execute(b''.join(parts))
+ if fetch:
+ result.extend(cur.fetchall())
+
+ return result
def _split_sql(sql):