summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsebres <serg.brester@sebres.de>2022-09-16 19:11:53 +0200
committersebres <serg.brester@sebres.de>2022-09-16 19:14:50 +0200
commit94dac78afebc2a4a676727527d99e70d970b061a (patch)
tree9a97edba1e7a9e97f59940fb4468d227f8a42c69
parent8dccf099e4852483a4e6245f62103499dededae8 (diff)
parent485c50228a4dce38b1766173d3504b174475fd18 (diff)
downloadfail2ban-0.11.tar.gz
Merge branch '0.10' into 0.110.11
(conflicts resolved)
-rwxr-xr-xfail2ban/client/fail2banclient.py57
-rw-r--r--fail2ban/client/fail2banserver.py28
-rw-r--r--fail2ban/server/database.py129
-rw-r--r--fail2ban/server/transmitter.py2
-rw-r--r--fail2ban/tests/fail2banclienttestcase.py33
5 files changed, 156 insertions, 93 deletions
diff --git a/fail2ban/client/fail2banclient.py b/fail2ban/client/fail2banclient.py
index c72208cd..f3b0f7b2 100755
--- a/fail2ban/client/fail2banclient.py
+++ b/fail2ban/client/fail2banclient.py
@@ -175,9 +175,13 @@ class Fail2banClient(Fail2banCmdLine, Thread):
return [["server-stream", stream], ['server-status']]
+ def _set_server(self, s):
+ self._server = s
+
##
def __startServer(self, background=True):
from .fail2banserver import Fail2banServer
+ # read configuration here (in client only, in server we do that in the config-thread):
stream = self.__prepareStartServer()
self._alive = True
if not stream:
@@ -192,16 +196,19 @@ class Fail2banClient(Fail2banCmdLine, Thread):
return False
else:
# In foreground mode we should make server/client communication in different threads:
- th = Thread(target=Fail2banClient.__processStartStreamAfterWait, args=(self, stream, False))
- th.daemon = True
- th.start()
+ phase = dict()
+ self.configureServer(phase=phase, stream=stream)
# Mark current (main) thread as daemon:
self.daemon = True
# Start server direct here in main thread (not fork):
- self._server = Fail2banServer.startServerDirect(self._conf, False)
-
+ self._server = Fail2banServer.startServerDirect(self._conf, False, self._set_server)
+ if not phase.get('done', False):
+ if self._server: # pragma: no cover
+ self._server.quit()
+ self._server = None
+ exit(255)
except ExitException: # pragma: no cover
- pass
+ raise
except Exception as e: # pragma: no cover
output("")
logSys.error("Exception while starting server " + ("background" if background else "foreground"))
@@ -214,23 +221,39 @@ class Fail2banClient(Fail2banCmdLine, Thread):
return True
##
- def configureServer(self, nonsync=True, phase=None):
+ def configureServer(self, nonsync=True, phase=None, stream=None):
# if asynchronous start this operation in the new thread:
if nonsync:
- th = Thread(target=Fail2banClient.configureServer, args=(self, False, phase))
+ if phase is not None:
+ # event for server ready flag:
+ def _server_ready():
+ phase['start-ready'] = True
+ logSys.log(5, ' server phase %s', phase)
+ # notify waiting thread if server really ready
+ self._conf['onstart'] = _server_ready
+ th = Thread(target=Fail2banClient.configureServer, args=(self, False, phase, stream))
th.daemon = True
- return th.start()
+ th.start()
+ # if we need to read configuration stream:
+ if stream is None and phase is not None:
+ # wait, do not continue if configuration is not 100% valid:
+ Utils.wait_for(lambda: phase.get('ready', None) is not None, self._conf["timeout"], 0.001)
+ logSys.log(5, ' server phase %s', phase)
+ if not phase.get('start', False):
+ raise ServerExecutionException('Async configuration of server failed')
+ return True
# prepare: read config, check configuration is valid, etc.:
if phase is not None:
phase['start'] = True
logSys.log(5, ' client phase %s', phase)
- stream = self.__prepareStartServer()
+ if stream is None:
+ stream = self.__prepareStartServer()
if phase is not None:
phase['ready'] = phase['start'] = (True if stream else False)
logSys.log(5, ' client phase %s', phase)
if not stream:
return False
- # wait a litle bit for phase "start-ready" before enter active waiting:
+ # wait a little bit for phase "start-ready" before enter active waiting:
if phase is not None:
Utils.wait_for(lambda: phase.get('start-ready', None) is not None, 0.5, 0.001)
phase['configure'] = (True if stream else False)
@@ -321,13 +344,14 @@ class Fail2banClient(Fail2banCmdLine, Thread):
def __processStartStreamAfterWait(self, *args):
+ ret = False
try:
# Wait for the server to start
if not self.__waitOnServer(): # pragma: no cover
logSys.error("Could not find server, waiting failed")
return False
# Configure the server
- self.__processCmd(*args)
+ ret = self.__processCmd(*args)
except ServerExecutionException as e: # pragma: no cover
if self._conf["verbose"] > 1:
logSys.exception(e)
@@ -336,10 +360,11 @@ class Fail2banClient(Fail2banCmdLine, Thread):
"remove " + self._conf["socket"] + ". If "
"you used fail2ban-client to start the "
"server, adding the -x option will do it")
- if self._server:
- self._server.quit()
- return False
- return True
+
+ if not ret and self._server: # stop on error (foreground, config read in another thread):
+ self._server.quit()
+ self._server = None
+ return ret
def __waitOnServer(self, alive=True, maxtime=None):
if maxtime is None:
diff --git a/fail2ban/client/fail2banserver.py b/fail2ban/client/fail2banserver.py
index d94d13ff..eee78d5f 100644
--- a/fail2ban/client/fail2banserver.py
+++ b/fail2ban/client/fail2banserver.py
@@ -44,7 +44,7 @@ class Fail2banServer(Fail2banCmdLine):
# Start the Fail2ban server in background/foreground (daemon mode or not).
@staticmethod
- def startServerDirect(conf, daemon=True):
+ def startServerDirect(conf, daemon=True, setServer=None):
logSys.debug(" direct starting of server in %s, deamon: %s", os.getpid(), daemon)
from ..server.server import Server
server = None
@@ -52,6 +52,10 @@ class Fail2banServer(Fail2banCmdLine):
# Start it in foreground (current thread, not new process),
# server object will internally fork self if daemon is True
server = Server(daemon)
+ # notify caller - set server handle:
+ if setServer:
+ setServer(server)
+ # run:
server.start(conf["socket"],
conf["pidfile"], conf["force"],
conf=conf)
@@ -63,6 +67,10 @@ class Fail2banServer(Fail2banCmdLine):
if conf["verbose"] > 1:
logSys.exception(e2)
raise
+ finally:
+ # notify waiting thread server ready resp. done (background execution, error case, etc):
+ if conf.get('onstart'):
+ conf['onstart']()
return server
@@ -179,27 +187,15 @@ class Fail2banServer(Fail2banCmdLine):
# Start new thread with client to read configuration and
# transfer it to the server:
cli = self._Fail2banClient()
+ cli._conf = self._conf
phase = dict()
logSys.debug('Configure via async client thread')
cli.configureServer(phase=phase)
- # wait, do not continue if configuration is not 100% valid:
- Utils.wait_for(lambda: phase.get('ready', None) is not None, self._conf["timeout"], 0.001)
- logSys.log(5, ' server phase %s', phase)
- if not phase.get('start', False):
- raise ServerExecutionException('Async configuration of server failed')
- # event for server ready flag:
- def _server_ready():
- phase['start-ready'] = True
- logSys.log(5, ' server phase %s', phase)
- # notify waiting thread if server really ready
- self._conf['onstart'] = _server_ready
# Start server, daemonize it, etc.
pid = os.getpid()
- server = Fail2banServer.startServerDirect(self._conf, background)
- # notify waiting thread server ready resp. done (background execution, error case, etc):
- if not nonsync:
- _server_ready()
+ server = Fail2banServer.startServerDirect(self._conf, background,
+ cli._set_server if cli else None)
# If forked - just exit other processes
if pid != os.getpid(): # pragma: no cover
os._exit(0)
diff --git a/fail2ban/server/database.py b/fail2ban/server/database.py
index ed736a7a..59eeb8fd 100644
--- a/fail2ban/server/database.py
+++ b/fail2ban/server/database.py
@@ -104,7 +104,11 @@ def commitandrollback(f):
def wrapper(self, *args, **kwargs):
with self._lock: # Threading lock
with self._db: # Auto commit and rollback on exception
- return f(self, self._db.cursor(), *args, **kwargs)
+ cur = self._db.cursor()
+ try:
+ return f(self, cur, *args, **kwargs)
+ finally:
+ cur.close()
return wrapper
@@ -253,7 +257,7 @@ class Fail2BanDb(object):
self.repairDB()
else:
version = cur.fetchone()[0]
- if version < Fail2BanDb.__version__:
+ if version != Fail2BanDb.__version__:
newversion = self.updateDb(version)
if newversion == Fail2BanDb.__version__:
logSys.warning( "Database updated from '%r' to '%r'",
@@ -301,9 +305,11 @@ class Fail2BanDb(object):
try:
# backup
logSys.info("Trying to repair database %s", self._dbFilename)
- shutil.move(self._dbFilename, self._dbBackupFilename)
- logSys.info(" Database backup created: %s", self._dbBackupFilename)
-
+ if not os.path.isfile(self._dbBackupFilename):
+ shutil.move(self._dbFilename, self._dbBackupFilename)
+ logSys.info(" Database backup created: %s", self._dbBackupFilename)
+ elif os.path.isfile(self._dbFilename):
+ os.remove(self._dbFilename)
# first try to repair using dump/restore in order
Utils.executeCmd((r"""f2b_db=$0; f2b_dbbk=$1; sqlite3 "$f2b_dbbk" ".dump" | sqlite3 "$f2b_db" """,
self._dbFilename, self._dbBackupFilename))
@@ -415,7 +421,7 @@ class Fail2BanDb(object):
logSys.error("Failed to upgrade database '%s': %s",
self._dbFilename, e.args[0],
exc_info=logSys.getEffectiveLevel() <= 10)
- raise
+ self.repairDB()
@commitandrollback
def addJail(self, cur, jail):
@@ -789,7 +795,6 @@ class Fail2BanDb(object):
queryArgs.append(fromtime)
if overalljails or jail is None:
query += " GROUP BY ip ORDER BY timeofban DESC LIMIT 1"
- cur = self._db.cursor()
# repack iterator as long as in lock:
return list(cur.execute(query, queryArgs))
@@ -812,11 +817,9 @@ class Fail2BanDb(object):
query += " GROUP BY ip ORDER BY ip, timeofban DESC"
else:
query += " ORDER BY timeofban DESC LIMIT 1"
- cur = self._db.cursor()
return cur.execute(query, queryArgs)
- @commitandrollback
- def getCurrentBans(self, cur, jail=None, ip=None, forbantime=None, fromtime=None,
+ def getCurrentBans(self, jail=None, ip=None, forbantime=None, fromtime=None,
correctBanTime=True, maxmatches=None
):
"""Reads tickets (with merged info) currently affected from ban from the database.
@@ -828,57 +831,63 @@ class Fail2BanDb(object):
(and therefore endOfBan) of the ticket (normally it is ban-time of jail as maximum)
for all tickets with ban-time greater (or persistent).
"""
- if fromtime is None:
- fromtime = MyTime.time()
- tickets = []
- ticket = None
- if correctBanTime is True:
- correctBanTime = jail.getMaxBanTime() if jail is not None else None
- # don't change if persistent allowed:
- if correctBanTime == -1: correctBanTime = None
-
- for ticket in self._getCurrentBans(cur, jail=jail, ip=ip,
- forbantime=forbantime, fromtime=fromtime
- ):
- # can produce unpack error (database may return sporadical wrong-empty row):
- try:
- banip, timeofban, bantime, bancount, data = ticket
- # additionally check for empty values:
- if banip is None or banip == "": # pragma: no cover
- raise ValueError('unexpected value %r' % (banip,))
- # if bantime unknown (after upgrade-db from earlier version), just use min known ban-time:
- if bantime == -2: # todo: remove it in future version
- bantime = jail.actions.getBanTime() if jail is not None else (
- correctBanTime if correctBanTime else 600)
- elif correctBanTime and correctBanTime >= 0:
- # if persistent ban (or greater as max), use current max-bantime of the jail:
- if bantime == -1 or bantime > correctBanTime:
- bantime = correctBanTime
- # after correction check the end of ban again:
- if bantime != -1 and timeofban + bantime <= fromtime:
- # not persistent and too old - ignore it:
- logSys.debug("ignore ticket (with new max ban-time %r): too old %r <= %r, ticket: %r",
- bantime, timeofban + bantime, fromtime, ticket)
+ cur = self._db.cursor()
+ try:
+ if fromtime is None:
+ fromtime = MyTime.time()
+ tickets = []
+ ticket = None
+ if correctBanTime is True:
+ correctBanTime = jail.getMaxBanTime() if jail is not None else None
+ # don't change if persistent allowed:
+ if correctBanTime == -1: correctBanTime = None
+
+ with self._lock:
+ bans = self._getCurrentBans(cur, jail=jail, ip=ip,
+ forbantime=forbantime, fromtime=fromtime
+ )
+ for ticket in bans:
+ # can produce unpack error (database may return sporadical wrong-empty row):
+ try:
+ banip, timeofban, bantime, bancount, data = ticket
+ # additionally check for empty values:
+ if banip is None or banip == "": # pragma: no cover
+ raise ValueError('unexpected value %r' % (banip,))
+ # if bantime unknown (after upgrade-db from earlier version), just use min known ban-time:
+ if bantime == -2: # todo: remove it in future version
+ bantime = jail.actions.getBanTime() if jail is not None else (
+ correctBanTime if correctBanTime else 600)
+ elif correctBanTime and correctBanTime >= 0:
+ # if persistent ban (or greater as max), use current max-bantime of the jail:
+ if bantime == -1 or bantime > correctBanTime:
+ bantime = correctBanTime
+ # after correction check the end of ban again:
+ if bantime != -1 and timeofban + bantime <= fromtime:
+ # not persistent and too old - ignore it:
+ logSys.debug("ignore ticket (with new max ban-time %r): too old %r <= %r, ticket: %r",
+ bantime, timeofban + bantime, fromtime, ticket)
+ continue
+ except ValueError as e: # pragma: no cover
+ logSys.debug("get current bans: ignore row %r - %s", ticket, e)
continue
- except ValueError as e: # pragma: no cover
- logSys.debug("get current bans: ignore row %r - %s", ticket, e)
- continue
- # logSys.debug('restore ticket %r, %r, %r', banip, timeofban, data)
- ticket = FailTicket(banip, timeofban, data=data)
- # filter matches if expected (current count > as maxmatches specified):
- if maxmatches is None:
- maxmatches = self.maxMatches
- if maxmatches:
- matches = ticket.getMatches()
- if matches and len(matches) > maxmatches:
- ticket.setMatches(matches[-maxmatches:])
- else:
- ticket.setMatches(None)
- # logSys.debug('restored ticket: %r', ticket)
- ticket.setBanTime(bantime)
- ticket.setBanCount(bancount)
- if ip is not None: return ticket
- tickets.append(ticket)
+ # logSys.debug('restore ticket %r, %r, %r', banip, timeofban, data)
+ ticket = FailTicket(banip, timeofban, data=data)
+ # filter matches if expected (current count > as maxmatches specified):
+ if maxmatches is None:
+ maxmatches = self.maxMatches
+ if maxmatches:
+ matches = ticket.getMatches()
+ if matches and len(matches) > maxmatches:
+ ticket.setMatches(matches[-maxmatches:])
+ else:
+ ticket.setMatches(None)
+ # logSys.debug('restored ticket: %r', ticket)
+ ticket.setBanTime(bantime)
+ ticket.setBanCount(bancount)
+ if ip is not None: return ticket
+ tickets.append(ticket)
+ finally:
+ cur.close()
return tickets
diff --git a/fail2ban/server/transmitter.py b/fail2ban/server/transmitter.py
index 8e17d862..6de60f94 100644
--- a/fail2ban/server/transmitter.py
+++ b/fail2ban/server/transmitter.py
@@ -58,7 +58,7 @@ class Transmitter:
ret = self.__commandHandler(command)
ack = 0, ret
except Exception as e:
- logSys.warning("Command %r has failed. Received %r",
+ logSys.error("Command %r has failed. Received %r",
command, e,
exc_info=logSys.getEffectiveLevel()<=logging.DEBUG)
ack = 1, e
diff --git a/fail2ban/tests/fail2banclienttestcase.py b/fail2ban/tests/fail2banclienttestcase.py
index 86480f02..d7213010 100644
--- a/fail2ban/tests/fail2banclienttestcase.py
+++ b/fail2ban/tests/fail2banclienttestcase.py
@@ -491,6 +491,39 @@ class Fail2banClientServerBase(LogCaptureTestCase):
self.execCmd(FAILED, startparams, "~~unknown~cmd~failed~~")
self.execCmd(SUCCESS, startparams, "echo", "TEST-ECHO")
+ @with_tmpdir
+ @with_kill_srv
+ def testStartFailsInForeground(self, tmp):
+ if not server.Fail2BanDb: # pragma: no cover
+ raise unittest.SkipTest('Skip test because no database')
+ dbname = pjoin(tmp,"tmp.db")
+ db = server.Fail2BanDb(dbname)
+ # set inappropriate DB version to simulate an irreparable error by start:
+ cur = db._db.cursor()
+ cur.executescript("UPDATE fail2banDb SET version = 555")
+ cur.close()
+ # timeout (thread will stop foreground server):
+ startparams = _start_params(tmp, db=dbname, logtarget='INHERITED')
+ phase = {'stop': True}
+ def _stopTimeout(startparams, phase):
+ if not Utils.wait_for(lambda: not phase['stop'], MAX_WAITTIME):
+ # print('==== STOP ====')
+ self.execCmdDirect(startparams, 'stop')
+ th = Thread(
+ name="_TestCaseWorker",
+ target=_stopTimeout,
+ args=(startparams, phase)
+ )
+ th.start()
+ # test:
+ try:
+ self.execCmd(FAILED, ("-f",) + startparams, "start")
+ finally:
+ phase['stop'] = False
+ th.join()
+ self.assertLogged("Attempt to travel to future version of database",
+ "Exit with code 255", all=True)
+
class Fail2banClientTest(Fail2banClientServerBase):