diff options
Diffstat (limited to 'lib/extras.py')
-rw-r--r-- | lib/extras.py | 97 |
1 files changed, 77 insertions, 20 deletions
diff --git a/lib/extras.py b/lib/extras.py index 2d26402..38ca17a 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -1168,11 +1168,15 @@ def execute_batch(cur, sql, argslist, page_size=100): Execute *sql* several times, against all parameters set (sequences or mappings) found in *argslist*. - The function is semantically similar to `~cursor.executemany()`, but has a - different implementation: Psycopg will join the statements into fewer - multi-statement commands, reducing the number of server roundtrips, - resulting in better performances. Every command contains at most - *page_size* statements. + The function is semantically similar to + + .. parsed-literal:: + + *cur*\.\ `~cursor.executemany`\ (\ *sql*\ , *argslist*\ ) + + but has a different implementation: Psycopg will join the statements into + fewer multi-statement commands, each one containing at most *page_size* + statements, resulting in a reduced number of server roundtrips. """ for page in _paginate(argslist, page_size=page_size): @@ -1183,18 +1187,30 @@ def execute_batch(cur, sql, argslist, page_size=100): def execute_values(cur, sql, argslist, template=None, page_size=100): '''Execute a statement using :sql:`VALUES` with a sequence of parameters. - *sql* must contain a single ``%s`` placeholder, which will be replaced by a - `VALUES list`__. Every statement will contain at most *page_size* sets of - arguments. + :param cur: the cursor to use to execute the query. - .. __: https://www.postgresql.org/docs/current/static/queries-values.html + :param sql: the query to execute. It must contain a single ``%s`` + placeholder, which will be replaced by a `VALUES list`__. + Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``. + + :param argslist: sequence of sequences or dictionaries with the arguments + to send to the query. The type and content must be consistent with + *template*. 
- *template* is the part merged to the arguments, so it should be compatible - with the content of *argslist* (it should contain the right number of - arguments if *argslist* is a sequence of sequences, or compatible names if - *argslist* is a sequence of mappings). If not specified, assume the - arguments are sequence and use a simple positional template (i.e. - ``(%s, %s, ...)``). + :param template: the snippet to merge to every item in *argslist* to + compose the query. If *argslist* items are sequences it should contain + positional placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)"`` + if there are constant values...); if *argslist* items are mappings + it should contain named placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``). + If not specified, assume the arguments are sequences and use a simple + positional template (i.e. ``(%s, %s, ...)``), with the number of + placeholders sniffed from the first element in *argslist*. + + :param page_size: maximum number of *argslist* items to include in every + statement. If there are more items the function will execute more than + one statement. + + .. __: https://www.postgresql.org/docs/current/static/queries-values.html While :sql:`INSERT` is an obvious candidate for this function it is possible to use it with other statements, for example:: @@ -1216,10 +1232,51 @@ def execute_values(cur, sql, argslist, template=None, page_size=100): [(1, 20, 3), (4, 50, 6), (7, 8, 9)]) ''' + # we can't just use sql % vals because vals is bytes: if sql is bytes + # there will be some decoding error because of stupid codec used, and Py3 + # doesn't implement % on bytes. 
+ if not isinstance(sql, bytes): + sql = sql.encode(_ext.encodings[cur.connection.encoding]) + pre, post = _split_sql(sql) + for page in _paginate(argslist, page_size=page_size): if template is None: - template = '(%s)' % ','.join(['%s'] * len(page[0])) - values = b",".join(cur.mogrify(template, args) for args in page) - if isinstance(values, bytes): - values = values.decode(_ext.encodings[cur.connection.encoding]) - cur.execute(sql % (values,)) + template = b'(' + b','.join([b'%s'] * len(page[0])) + b')' + parts = pre[:] + for args in page: + parts.append(cur.mogrify(template, args)) + parts.append(b',') + parts[-1:] = post + cur.execute(b''.join(parts)) + + +def _split_sql(sql): + """Split *sql* on a single ``%s`` placeholder. + + Split on the %s, perform %% replacement and return pre, post lists of + snippets. + """ + curr = pre = [] + post = [] + tokens = _re.split(br'(%.)', sql) + for token in tokens: + if len(token) != 2 or token[:1] != b'%': + curr.append(token) + continue + + if token[1:] == b's': + if curr is pre: + curr = post + else: + raise ValueError( + "the query contains more than one '%s' placeholder") + elif token[1:] == b'%': + curr.append(b'%') + else: + raise ValueError("unsupported format character: '%s'" + % token[1:].decode('ascii', 'replace')) + + if curr is pre: + raise ValueError("the query doesn't contain any '%s' placeholder") + + return pre, post |