author     Lingxian Kong <anlin.kong@gmail.com>  2020-09-02 10:10:23 +1200
committer  Lingxian Kong <anlin.kong@gmail.com>  2020-09-07 20:40:56 +1200
commit     4fb41b5198c865b46a02dd72501d12e60ec10dd6 (patch)
tree       663e32e8cf216201c17d1dc25201d992eb249787 /backup
parent     768ec34dfef660f133f87218a6246a9ce111bcb5 (diff)
download   trove-4fb41b5198c865b46a02dd72501d12e60ec10dd6.tar.gz
Postgresql: Backup and restore
Change-Id: Icf08b7dc82ce501d82b45cf5412256a43716b6ae
Diffstat (limited to 'backup')
-rw-r--r--  backup/Dockerfile               |   5
-rw-r--r--  backup/drivers/base.py          |  10
-rw-r--r--  backup/drivers/innobackupex.py  |   1
-rw-r--r--  backup/drivers/mariabackup.py   |   1
-rw-r--r--  backup/drivers/mysql_base.py    |   7
-rw-r--r--  backup/drivers/postgres.py      | 249
-rwxr-xr-x  backup/install.sh               |   9
-rw-r--r--  backup/main.py                  |  28
-rw-r--r--  backup/requirements.txt         |   2
-rw-r--r--  backup/storage/swift.py         |   8
-rw-r--r--  backup/utils/__init__.py        |  46
-rw-r--r--  backup/utils/postgresql.py      |  53
12 files changed, 396 insertions(+), 23 deletions(-)
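
As a quick orientation before the per-file diffs: the commit registers two new driver names that map to the PostgreSQL classes added in backup/drivers/postgres.py. A hedged sketch of that dispatch (the mapping values come from backup/main.py below; resolving them via oslo_utils.importutils is an assumption, not something shown in this commit):

    # Sketch only: driver names/paths are taken from backup/main.py in this
    # diff; using oslo_utils.importutils to resolve them is an assumption.
    from oslo_utils import importutils

    driver_mapping = {
        'pg_basebackup': 'backup.drivers.postgres.PgBasebackup',
        'pg_basebackup_inc': 'backup.drivers.postgres.PgBasebackupIncremental',
    }
    runner_cls = importutils.import_class(driver_mapping['pg_basebackup'])
    print(runner_cls.__name__)  # PgBasebackup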
diff --git a/backup/Dockerfile b/backup/Dockerfile
index 86c19ede..38ebb14a 100644
--- a/backup/Dockerfile
+++ b/backup/Dockerfile
@@ -4,8 +4,9 @@ LABEL maintainer="anlin.kong@gmail.com"
ARG DATASTORE="mysql"
ARG APTOPTS="-y -qq --no-install-recommends --allow-unauthenticated"
ARG PERCONA_XTRABACKUP_VERSION=24
-ENV DEBIAN_FRONTEND noninteractive \
- APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=1
+
+RUN export DEBIAN_FRONTEND="noninteractive" \
+ && export APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=1
RUN apt-get update \
&& apt-get install $APTOPTS gnupg2 lsb-release apt-utils apt-transport-https ca-certificates software-properties-common curl \
diff --git a/backup/drivers/base.py b/backup/drivers/base.py
index 033553bc..20ed75cf 100644
--- a/backup/drivers/base.py
+++ b/backup/drivers/base.py
@@ -27,12 +27,11 @@ class BaseRunner(object):
"""Base class for Backup Strategy implementations."""
# Subclasses should provide the commands.
- cmd = None
- restore_cmd = None
- prepare_cmd = None
+ cmd = ''
+ restore_cmd = ''
+ prepare_cmd = ''
encrypt_key = CONF.backup_encryption_key
- default_data_dir = '/var/lib/mysql/data'
def __init__(self, *args, **kwargs):
self.process = None
@@ -43,8 +42,9 @@ class BaseRunner(object):
self.checksum = kwargs.pop('checksum', '')
if 'restore_location' not in kwargs:
- kwargs['restore_location'] = self.default_data_dir
+ kwargs['restore_location'] = self.datadir
self.restore_location = kwargs['restore_location']
+ self.restore_content_length = 0
self.command = self.cmd % kwargs
self.restore_command = (self.decrypt_cmd +
diff --git a/backup/drivers/innobackupex.py b/backup/drivers/innobackupex.py
index e077d497..9bbebc3a 100644
--- a/backup/drivers/innobackupex.py
+++ b/backup/drivers/innobackupex.py
@@ -102,7 +102,6 @@ class InnoBackupExIncremental(InnoBackupEx):
raise AttributeError('lsn attribute missing')
self.parent_location = kwargs.pop('parent_location', '')
self.parent_checksum = kwargs.pop('parent_checksum', '')
- self.restore_content_length = 0
super(InnoBackupExIncremental, self).__init__(*args, **kwargs)
diff --git a/backup/drivers/mariabackup.py b/backup/drivers/mariabackup.py
index e10cca30..dbf3bd07 100644
--- a/backup/drivers/mariabackup.py
+++ b/backup/drivers/mariabackup.py
@@ -56,7 +56,6 @@ class MariaBackupIncremental(MariaBackup):
raise AttributeError('lsn attribute missing')
self.parent_location = kwargs.pop('parent_location', '')
self.parent_checksum = kwargs.pop('parent_checksum', '')
- self.restore_content_length = 0
super(MariaBackupIncremental, self).__init__(*args, **kwargs)
diff --git a/backup/drivers/mysql_base.py b/backup/drivers/mysql_base.py
index 2450daf0..6389cdb9 100644
--- a/backup/drivers/mysql_base.py
+++ b/backup/drivers/mysql_base.py
@@ -27,6 +27,8 @@ LOG = logging.getLogger(__name__)
class MySQLBaseRunner(base.BaseRunner):
def __init__(self, *args, **kwargs):
+ self.datadir = kwargs.pop('db_datadir', '/var/lib/mysql/data')
+
super(MySQLBaseRunner, self).__init__(*args, **kwargs)
@property
@@ -113,8 +115,8 @@ class MySQLBaseRunner(base.BaseRunner):
incremental_dir = None
if 'parent_location' in metadata:
- LOG.info("Restoring parent: %(parent_location)s"
- " checksum: %(parent_checksum)s.", metadata)
+ LOG.info("Restoring parent: %(parent_location)s, "
+ "checksum: %(parent_checksum)s.", metadata)
parent_location = metadata['parent_location']
parent_checksum = metadata['parent_checksum']
@@ -129,6 +131,7 @@ class MySQLBaseRunner(base.BaseRunner):
else:
# The parent (full backup) uses the same command from the InnobackupEx
# super class and does not set an incremental_dir.
+ LOG.info("Restoring back to full backup.")
command = self.restore_command
self.restore_content_length += self.unpack(location, checksum, command)
diff --git a/backup/drivers/postgres.py b/backup/drivers/postgres.py
new file mode 100644
index 00000000..0b6538bb
--- /dev/null
+++ b/backup/drivers/postgres.py
@@ -0,0 +1,249 @@
+# Copyright 2020 Catalyst Cloud
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import re
+
+from oslo_log import log as logging
+
+from backup import utils
+from backup.drivers import base
+from backup.utils import postgresql as psql_util
+
+LOG = logging.getLogger(__name__)
+
+
+class PgBasebackup(base.BaseRunner):
+ def __init__(self, *args, **kwargs):
+ if not kwargs.get('wal_archive_dir'):
+ raise AttributeError('wal_archive_dir attribute missing')
+ self.wal_archive_dir = kwargs.pop('wal_archive_dir')
+ self.datadir = kwargs.pop(
+ 'db_datadir', '/var/lib/postgresql/data/pgdata')
+
+ self.label = None
+ self.stop_segment = None
+ self.start_segment = None
+ self.start_wal_file = None
+ self.stop_wal_file = None
+ self.checkpoint_location = None
+ self.metadata = {}
+
+ super(PgBasebackup, self).__init__(*args, **kwargs)
+
+ self.restore_command = (f"{self.decrypt_cmd}tar xzf - -C "
+ f"{self.datadir}")
+
+ @property
+ def cmd(self):
+ cmd = (f"pg_basebackup -U postgres -Ft -z --wal-method=fetch "
+ f"--label={self.filename} --pgdata=-")
+ return cmd + self.encrypt_cmd
+
+ @property
+ def manifest(self):
+ """Target file name."""
+ return "%s.tar.gz%s" % (self.filename, self.encrypt_manifest)
+
+ def get_wal_files(self, backup_pos=0):
+ """Return the WAL files since the provided last backup.
+
+ pg_archivecleanup depends on alphanumeric sorting to decide wal order,
+ so we'll do so too:
+ https://github.com/postgres/postgres/blob/REL9_4_STABLE/contrib
+ /pg_archivecleanup/pg_archivecleanup.c#L122
+ """
+ backup_file = self.get_backup_file(backup_pos=backup_pos)
+ last_wal = backup_file.split('.')[0]
+ wal_re = re.compile("^[0-9A-F]{24}$")
+ wal_files = [wal_file for wal_file in os.listdir(self.wal_archive_dir)
+ if wal_re.search(wal_file) and wal_file >= last_wal]
+ return wal_files
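
The alphanumeric WAL ordering that get_wal_files() relies on reduces to plain string comparison in Python. A minimal sketch, with hypothetical 24-hex-digit segment names:

    import re

    # Hypothetical WAL archive contents; names follow PostgreSQL's
    # 24-hex-digit segment naming, plus one .backup history file.
    wal_re = re.compile("^[0-9A-F]{24}$")
    archive = ["000000010000000000000005",
               "000000010000000000000006",
               "000000010000000000000006.00000168.backup",
               "000000010000000000000007"]
    last_wal = "000000010000000000000006.00000168.backup".split('.')[0]
    # Plain string comparison reproduces pg_archivecleanup's ordering.
    wal_files = [w for w in archive if wal_re.search(w) and w >= last_wal]
    print(wal_files)  # ['000000010000000000000006', '000000010000000000000007']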
+
+ def get_backup_file(self, backup_pos=0):
+ """Look for the most recent .backup file that basebackup creates
+
+ :return: a string like 000000010000000000000006.00000168.backup
+ """
+ backup_re = re.compile("[0-9A-F]{24}.*.backup")
+ wal_files = [wal_file for wal_file in os.listdir(self.wal_archive_dir)
+ if backup_re.search(wal_file)]
+ wal_files = sorted(wal_files, reverse=True)
+ if not wal_files:
+ return None
+ return wal_files[backup_pos]
+
+ def get_backup_metadata(self, metadata_file):
+ """Parse the contents of the .backup file"""
+ metadata = {}
+
+ start_re = re.compile(r"START WAL LOCATION: (.*) \(file (.*)\)")
+ stop_re = re.compile(r"STOP WAL LOCATION: (.*) \(file (.*)\)")
+ checkpt_re = re.compile("CHECKPOINT LOCATION: (.*)")
+ label_re = re.compile("LABEL: (.*)")
+
+ with open(metadata_file, 'r') as file:
+ metadata_contents = file.read()
+
+ match = start_re.search(metadata_contents)
+ if match:
+ self.start_segment = match.group(1)
+ metadata['start-segment'] = self.start_segment
+ self.start_wal_file = match.group(2)
+ metadata['start-wal-file'] = self.start_wal_file
+
+ match = stop_re.search(metadata_contents)
+ if match:
+ self.stop_segment = match.group(1)
+ metadata['stop-segment'] = self.stop_segment
+ self.stop_wal_file = match.group(2)
+ metadata['stop-wal-file'] = self.stop_wal_file
+
+ match = checkpt_re.search(metadata_contents)
+ if match:
+ self.checkpoint_location = match.group(1)
+ metadata['checkpoint-location'] = self.checkpoint_location
+
+ match = label_re.search(metadata_contents)
+ if match:
+ self.label = match.group(1)
+ metadata['label'] = self.label
+
+ return metadata
+
+ def get_metadata(self):
+ """Get metadata.
+
+ pg_basebackup may complete, and we arrive here before the
+ history file is written to the wal archive. So we need to
+ handle two possibilities:
+ - this is the first backup, and no history file exists yet
+ - this isn't the first backup, so the history file we retrieve
+ might not be from the backup we just ran
+ """
+ def _metadata_found():
+ backup_file = self.get_backup_file()
+ if not backup_file:
+ return False
+
+ self.metadata = self.get_backup_metadata(
+ os.path.join(self.wal_archive_dir, backup_file))
+ LOG.info("Metadata for backup: %s.", self.metadata)
+ return self.metadata['label'] == self.filename
+
+ try:
+ LOG.debug("Polling for backup metadata... ")
+ utils.poll_until(_metadata_found, sleep_time=5, time_out=60)
+ except Exception as e:
+ raise RuntimeError(f"Failed to get backup metadata for backup "
+ f"{self.filename}: {str(e)}")
+
+ return self.metadata
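
For context, a hedged sketch of what get_backup_metadata() extracts; the sample below mimics a PostgreSQL backup history (.backup) file, and all values are made up:

    import re

    sample = ("START WAL LOCATION: 0/6000028 (file 000000010000000000000006)\n"
              "STOP WAL LOCATION: 0/6000138 (file 000000010000000000000006)\n"
              "CHECKPOINT LOCATION: 0/6000060\n"
              "LABEL: mybackup\n")

    start_re = re.compile(r"START WAL LOCATION: (.*) \(file (.*)\)")
    match = start_re.search(sample)
    print(match.group(1))  # 0/6000028 -> stored as 'start-segment'
    print(match.group(2))  # 000000010000000000000006 -> 'start-wal-file'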
+
+ def check_process(self):
+ # If any of the below variables were not set by either metadata()
+ # or direct retrieval from the pgsql backup commands, then something
+ # has gone wrong
+ if not self.start_segment or not self.start_wal_file:
+ LOG.error("Unable to determine starting WAL file/segment")
+ return False
+ if not self.stop_segment or not self.stop_wal_file:
+ LOG.error("Unable to determine ending WAL file/segment")
+ return False
+ if not self.label:
+ LOG.error("No backup label found")
+ return False
+ return True
+
+
+class PgBasebackupIncremental(PgBasebackup):
+ """Incremental backup/restore for PostgreSQL.
+
+ In PostgreSQL, restoring an incremental backup effectively means
+ replaying the WAL entries up to a designated point in time. All that is
+ required is the most recent base backup and all WAL files since then.
+ """
+
+ def __init__(self, *args, **kwargs):
+ self.parent_location = kwargs.pop('parent_location', '')
+ self.parent_checksum = kwargs.pop('parent_checksum', '')
+
+ super(PgBasebackupIncremental, self).__init__(*args, **kwargs)
+
+ self.incr_restore_cmd = f'tar -xzf - -C {self.wal_archive_dir}'
+
+ def pre_backup(self):
+ with psql_util.PostgresConnection('postgres') as conn:
+ self.start_segment = conn.query(
+ f"SELECT pg_start_backup('{self.filename}', false, false)"
+ )[0][0]
+ self.start_wal_file = conn.query(
+ f"SELECT pg_walfile_name('{self.start_segment}')")[0][0]
+ self.stop_segment = conn.query(
+ "SELECT * FROM pg_stop_backup(false, true)")[0][0]
+
+ # We have to hack this because self.command is
+ # initialized in the base class before we get here, and only at
+ # this point do we know exactly which WAL files we want to archive.
+ self.command = self._cmd()
+
+ def _cmd(self):
+ wal_file_list = self.get_wal_files(backup_pos=1)
+ cmd = (f'tar -czf - -C {self.wal_archive_dir} '
+ f'{" ".join(wal_file_list)}')
+ return cmd + self.encrypt_cmd
+
+ def get_metadata(self):
+ _meta = super(PgBasebackupIncremental, self).get_metadata()
+ _meta.update({
+ 'parent_location': self.parent_location,
+ 'parent_checksum': self.parent_checksum,
+ })
+ return _meta
+
+ def incremental_restore_cmd(self, incr=False):
+ cmd = self.restore_command
+ if incr:
+ cmd = self.incr_restore_cmd
+ return self.decrypt_cmd + cmd
+
+ def incremental_restore(self, location, checksum):
+ """Perform incremental restore.
+
+ For the child backups, restore the wal files to wal archive dir.
+ For the base backup, restore to datadir.
+ """
+ metadata = self.storage.load_metadata(location, checksum)
+ if 'parent_location' in metadata:
+ LOG.info("Restoring parent: %(parent_location)s, "
+ "checksum: %(parent_checksum)s.", metadata)
+
+ parent_location = metadata['parent_location']
+ parent_checksum = metadata['parent_checksum']
+
+ # Restore parents recursively so backups are applied sequentially
+ self.incremental_restore(parent_location, parent_checksum)
+
+ command = self.incremental_restore_cmd(incr=True)
+ else:
+ # For the parent base backup, revert to the default restore cmd
+ LOG.info("Restoring back to full backup.")
+ command = self.incremental_restore_cmd(incr=False)
+
+ self.restore_content_length += self.unpack(location, checksum, command)
+
+ def run_restore(self):
+ """Run incremental restore."""
+ LOG.debug('Running incremental restore')
+ self.incremental_restore(self.location, self.checksum)
+ return self.restore_content_length
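
The parent-first recursion in incremental_restore() ensures the base backup is unpacked into the datadir before any WAL increments land in the archive dir. A minimal sketch of that ordering, with a hypothetical chain standing in for storage metadata:

    def restore(location, chain, applied):
        meta = chain[location]  # stands in for storage.load_metadata()
        if 'parent_location' in meta:
            # Recurse first, so the base backup is applied before children.
            restore(meta['parent_location'], chain, applied)
        applied.append(location)

    chain = {
        'full': {},
        'inc1': {'parent_location': 'full'},
        'inc2': {'parent_location': 'inc1'},
    }
    applied = []
    restore('inc2', chain, applied)
    print(applied)  # ['full', 'inc1', 'inc2']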
diff --git a/backup/install.sh b/backup/install.sh
index ad1c2e4a..19177baf 100755
--- a/backup/install.sh
+++ b/backup/install.sh
@@ -1,6 +1,7 @@
#!/usr/bin/env bash
set -e
+export APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=1
APTOPTS="-y -qq --no-install-recommends --allow-unauthenticated"
case "$1" in
@@ -8,17 +9,21 @@ case "$1" in
curl -sSL https://repo.percona.com/apt/percona-release_latest.$(lsb_release -sc)_all.deb -o percona-release.deb
dpkg -i percona-release.deb
percona-release enable-only tools release
- apt-get update
apt-get install $APTOPTS percona-xtrabackup-$2
apt-get clean
;;
"mariadb")
apt-key adv --fetch-keys 'https://mariadb.org/mariadb_release_signing_key.asc'
add-apt-repository "deb [arch=amd64] http://mirror2.hs-esslingen.de/mariadb/repo/10.4/ubuntu $(lsb_release -cs) main"
- apt-get update
apt-get install $APTOPTS mariadb-backup
apt-get clean
;;
+"postgresql")
+ apt-key adv --fetch-keys 'https://www.postgresql.org/media/keys/ACCC4CF8.asc'
+ add-apt-repository "deb [arch=amd64] http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main"
+ apt-get install $APTOPTS postgresql-client-12
+ apt-get clean
+ ;;
*)
echo "datastore $1 not supported"
exit 1
diff --git a/backup/main.py b/backup/main.py
index c52becbf..a42dc4ba 100644
--- a/backup/main.py
+++ b/backup/main.py
@@ -36,13 +36,14 @@ cli_opts = [
cfg.StrOpt(
'driver',
default='innobackupex',
- choices=['innobackupex', 'xtrabackup', 'mariabackup']
+ choices=['innobackupex', 'mariabackup', 'pg_basebackup']
),
cfg.BoolOpt('backup'),
cfg.StrOpt('backup-encryption-key'),
cfg.StrOpt('db-user'),
cfg.StrOpt('db-password'),
cfg.StrOpt('db-host'),
+ cfg.StrOpt('db-datadir'),
cfg.StrOpt('os-token'),
cfg.StrOpt('os-auth-url'),
cfg.StrOpt('os-tenant-id'),
@@ -57,6 +58,7 @@ cli_opts = [
help='It is up to the storage driver to decide whether to '
'validate the checksum or not. '
),
+ cfg.StrOpt('pg-wal-archive-dir'),
]
driver_mapping = {
@@ -64,6 +66,8 @@ driver_mapping = {
'innobackupex_inc': 'backup.drivers.innobackupex.InnoBackupExIncremental',
'mariabackup': 'backup.drivers.mariabackup.MariaBackup',
'mariabackup_inc': 'backup.drivers.mariabackup.MariaBackupIncremental',
+ 'pg_basebackup': 'backup.drivers.postgres.PgBasebackup',
+ 'pg_basebackup_inc': 'backup.drivers.postgres.PgBasebackupIncremental',
}
storage_mapping = {
'swift': 'backup.storage.swift.SwiftStorage',
@@ -72,6 +76,7 @@ storage_mapping = {
def stream_backup_to_storage(runner_cls, storage):
parent_metadata = {}
+ extra_params = {}
if CONF.incremental:
if not CONF.parent_location:
@@ -88,8 +93,13 @@ def stream_backup_to_storage(runner_cls, storage):
}
)
+ if CONF.pg_wal_archive_dir:
+ extra_params['wal_archive_dir'] = CONF.pg_wal_archive_dir
+
+ extra_params.update(parent_metadata)
+
try:
- with runner_cls(filename=CONF.backup_id, **parent_metadata) as bkup:
+ with runner_cls(filename=CONF.backup_id, **extra_params) as bkup:
checksum, location = storage.save(
bkup,
metadata=CONF.swift_extra_metadata,
@@ -103,13 +113,19 @@ def stream_backup_to_storage(runner_cls, storage):
def stream_restore_from_storage(runner_cls, storage):
- lsn = ""
+ params = {
+ 'storage': storage,
+ 'location': CONF.restore_from,
+ 'checksum': CONF.restore_checksum,
+ 'wal_archive_dir': CONF.pg_wal_archive_dir,
+ 'lsn': None
+ }
+
if storage.is_incremental_backup(CONF.restore_from):
- lsn = storage.get_backup_lsn(CONF.restore_from)
+ params['lsn'] = storage.get_backup_lsn(CONF.restore_from)
try:
- runner = runner_cls(storage=storage, location=CONF.restore_from,
- checksum=CONF.restore_checksum, lsn=lsn)
+ runner = runner_cls(**params)
restore_size = runner.restore()
LOG.info('Restored successfully, restore_size: %s', restore_size)
except Exception as err:
diff --git a/backup/requirements.txt b/backup/requirements.txt
index 38358bd3..34b90614 100644
--- a/backup/requirements.txt
+++ b/backup/requirements.txt
@@ -2,5 +2,7 @@ oslo.config!=4.3.0,!=4.4.0;python_version>='3.0' # Apache-2.0
oslo.log;python_version>='3.0' # Apache-2.0
oslo.utils!=3.39.1,!=3.40.0,!=3.40.1;python_version>='3.0' # Apache-2.0
oslo.concurrency;python_version>='3.0' # Apache-2.0
+oslo.service!=1.28.1 # Apache-2.0
keystoneauth1 # Apache-2.0
python-swiftclient # Apache-2.0
+psycopg2-binary>=2.6.2 # LGPL/ZPL
diff --git a/backup/storage/swift.py b/backup/storage/swift.py
index 3930e68a..8c60cb56 100644
--- a/backup/storage/swift.py
+++ b/backup/storage/swift.py
@@ -185,7 +185,7 @@ class SwiftStorage(base.Storage):
for key, value in metadata.items():
headers[_set_attr(key)] = value
- LOG.debug('Metadata headers: %s', headers)
+ LOG.info('Metadata headers: %s', headers)
if large_object:
manifest_data = json.dumps(segment_results)
LOG.info('Creating the SLO manifest file, manifest content: %s',
@@ -212,8 +212,8 @@ class SwiftStorage(base.Storage):
headers=headers)
# Delete the old segment file that was copied
- LOG.debug('Deleting the old segment file %s.',
- stream_reader.first_segment)
+ LOG.info('Deleting the old segment file %s.',
+ stream_reader.first_segment)
self.client.delete_object(container,
stream_reader.first_segment)
@@ -288,7 +288,7 @@ class SwiftStorage(base.Storage):
return False
def get_backup_lsn(self, location):
- """Get the backup LSN."""
+ """Get the backup LSN if exists."""
_, container, filename = self._explodeLocation(location)
headers = self.client.head_object(container, filename)
return headers.get('x-object-meta-lsn')
diff --git a/backup/utils/__init__.py b/backup/utils/__init__.py
new file mode 100644
index 00000000..6c942335
--- /dev/null
+++ b/backup/utils/__init__.py
@@ -0,0 +1,46 @@
+# Copyright 2020 Catalyst Cloud
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from oslo_service import loopingcall
+
+
+def build_polling_task(retriever, condition=lambda value: value,
+ sleep_time=1, time_out=0, initial_delay=0):
+ """Run a function in a loop with backoff on error.
+
+ The condition function is evaluated against the retriever's result.
+ """
+
+ def poll_and_check():
+ obj = retriever()
+ if condition(obj):
+ raise loopingcall.LoopingCallDone(retvalue=obj)
+
+ call = loopingcall.BackOffLoopingCall(f=poll_and_check)
+ return call.start(initial_delay=initial_delay,
+ starting_interval=sleep_time,
+ max_interval=30, timeout=time_out)
+
+
+def poll_until(retriever, condition=lambda value: value,
+ sleep_time=3, time_out=0, initial_delay=0):
+ """Retrieves object until it passes condition, then returns it.
+
+ If time_out_limit is passed in, PollTimeOut will be raised once that
+ amount of time is eclipsed.
+
+ """
+ task = build_polling_task(retriever, condition=condition,
+ sleep_time=sleep_time, time_out=time_out,
+ initial_delay=initial_delay)
+ return task.wait()
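
A hedged usage sketch of poll_until(); the retriever and the /tmp/wal_archive path are hypothetical, while the sleep_time/time_out values mirror what PgBasebackup.get_metadata() uses:

    import os

    from backup import utils

    def backup_file_present():
        # Retriever: truthy once any .backup history file appears in the
        # (hypothetical) archive directory.
        return [f for f in os.listdir('/tmp/wal_archive')
                if f.endswith('.backup')]

    # Poll every 5 seconds, give up after 60.
    utils.poll_until(backup_file_present, sleep_time=5, time_out=60)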
diff --git a/backup/utils/postgresql.py b/backup/utils/postgresql.py
new file mode 100644
index 00000000..033652f0
--- /dev/null
+++ b/backup/utils/postgresql.py
@@ -0,0 +1,53 @@
+# Copyright 2020 Catalyst Cloud
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import psycopg2
+
+
+class PostgresConnection(object):
+ def __init__(self, user, password='', host='localhost', port=5432):
+ self.user = user
+ self.password = password
+ self.host = host
+ self.port = port
+
+ self.connect_str = (f"user='{self.user}' password='{self.password}' "
+ f"host='{self.host}' port='{self.port}'")
+
+ def __enter__(self, autocommit=False):
+ self.conn = psycopg2.connect(self.connect_str)
+ self.conn.autocommit = autocommit
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.conn.close()
+
+ def execute(self, statement, identifiers=None, data_values=None):
+ """Execute a non-returning statement."""
+ self._execute_stmt(statement, identifiers, data_values, False)
+
+ def query(self, query, identifiers=None, data_values=None):
+ """Execute a query and return the result set."""
+ return self._execute_stmt(query, identifiers, data_values, True)
+
+ def _execute_stmt(self, statement, identifiers, data_values, fetch):
+ cmd = self._bind(statement, identifiers)
+ with self.conn.cursor() as cursor:
+ cursor.execute(cmd, data_values)
+ if fetch:
+ return cursor.fetchall()
+
+ def _bind(self, statement, identifiers):
+ if identifiers:
+ return statement.format(*identifiers)
+ return statement
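
Finally, a hedged usage sketch of PostgresConnection; the password is a placeholder and the statements are only illustrative:

    from backup.utils.postgresql import PostgresConnection

    with PostgresConnection('postgres', password='secret') as conn:
        rows = conn.query("SELECT pg_is_in_recovery()")  # returns fetchall()
        print(rows[0][0])
        conn.execute("CHECKPOINT")  # non-returning statement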