From 659910ee8196e04b24ce96d9b4ccb7ed93aed6ac Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Tue, 21 Jul 2020 01:42:34 +0100 Subject: Allow most of the async tests to pass on CockroachDB Added function to get crdb version from a connection --- tests/test_async.py | 7 ++++++- tests/testutils.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/tests/test_async.py b/tests/test_async.py index d62eb3b..6738c07 100755 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -33,7 +33,8 @@ import psycopg2 import psycopg2.errors from psycopg2 import extensions as ext -from .testutils import ConnectingTestCase, StringIO, skip_before_postgres, slow +from .testutils import (ConnectingTestCase, StringIO, skip_before_postgres, + crdb_version, slow) class PollableStub(object): @@ -62,6 +63,10 @@ class AsyncTests(ConnectingTestCase): self.wait(self.conn) curs = self.conn.cursor() + if crdb_version(self.sync_conn) is not None: + curs.execute("set experimental_enable_temp_tables = 'on'") + self.wait(curs) + curs.execute(''' CREATE TEMPORARY TABLE table1 ( id int PRIMARY KEY diff --git a/tests/testutils.py b/tests/testutils.py index 26f6cc7..42f940c 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -407,6 +407,38 @@ def skip_if_windows(cls): return decorator(cls) +def crdb_version(conn, __crdb_version=[]): + """ + Return the CockroachDB version if that's the db testing, else None. + + Return the number as an integer similar to PQserverVersion: return + v20.1.3 as 200103. + + Assume all the connections are on the same db: return a chached result on + following runs. + + """ + if __crdb_version: + return __crdb_version[0] + + with conn.cursor() as cur: + try: + cur.execute("show crdb_version") + except psycopg2.ProgrammingError: + __crdb_version.append(None) + else: + sver = cur.fetchone()[0] + m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver) + if not m: + raise ValueError( + "can't parse CockroachDB version from %s" % sver) + + ver = int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3)) + __crdb_version.append(ver) + + return __crdb_version[0] + + class py3_raises_typeerror(object): def __enter__(self): pass -- cgit v1.2.1