diff options
author | Marcin Lewandowski <marcin.lewandowski@intel.com> | 2022-03-22 13:28:14 +0100 |
---|---|---|
committer | Marcin Lewandowski <marcin.lewandowski@intel.com> | 2022-03-23 14:21:48 +0100 |
commit | a6ba72c2e6d0ce49f904556e9c1d1b765afd9e3a (patch) | |
tree | 5a9339f938aa1910dd2636d080e6733becb42106 | |
parent | 85af4e453e647088fda50c3108dfd00e9753af3a (diff) | |
download | tftpy-a6ba72c2e6d0ce49f904556e9c1d1b765afd9e3a.tar.gz |
Fix race condition when waiting for ACK
TFTPy is designed in a way that socket timeout is used to calculate timeout
when waiting for packet. During that time another unexpected packet may arrive.
After that the socket operation is restarted and timeout is calculated from start.
This might be a problem because both sides have timeout and these timeout
may be different or one host may be significantly faster that another.
In such situation the timeout will be never triggered as another host will
always retransmit his packet faster.
For most cases it does not matter because TFTP is always responding to packet
sent and transmission may continue. The only one exception is no response
to duplicate ACK. It is necessary to prevent Sorcerer's Apprentice Syndrome.
This patch introduces additional exception TftpTimeoutExpectACK raised
when reaching timeout during waiting on ACK of current block but receiving
duplicate ACK of previous block.
-rw-r--r-- | t/test.py | 51 | ||||
-rw-r--r-- | tftpy/TftpContexts.py | 6 | ||||
-rw-r--r-- | tftpy/TftpServer.py | 5 | ||||
-rw-r--r-- | tftpy/TftpShared.py | 7 | ||||
-rw-r--r-- | tftpy/TftpStates.py | 7 |
5 files changed, 74 insertions, 2 deletions
@@ -316,6 +316,57 @@ class TestTftpyState(unittest.TestCase): finalstate = serverstate.state.handle(ack, raddress, rport) self.assertTrue(finalstate is None) + def testServerTimeoutExpectACK(self): + raddress = "127.0.0.2" + rport = 10000 + timeout = 5 + root = os.path.dirname(os.path.abspath(__file__)) + # Testing without the dyn_func_file set. + serverstate = tftpy.TftpContexts.TftpContextServer( + raddress, rport, timeout, root + ) + + self.assertTrue(isinstance(serverstate, tftpy.TftpContexts.TftpContextServer)) + + rrq = tftpy.TftpPacketTypes.TftpPacketRRQ() + rrq.filename = "640KBFILE" + rrq.mode = "octet" + rrq.options = {} + + # Start the download. + serverstate.start(rrq.encode().buffer) + + ack = tftpy.TftpPacketTypes.TftpPacketACK() + ack.blocknumber = 1 + + # Server expects ACK at the beginning of transmission + self.assertTrue( + isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK) + ) + + # Receive first ACK for block 1, next block expected is 2 + serverstate.state = serverstate.state.handle(ack, raddress, rport) + self.assertTrue( + isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK) + ) + self.assertEqual(serverstate.state.context.next_block, 2) + + # Receive duplicate ACK for block 1, next block expected is still 2 + serverstate.state = serverstate.state.handle(ack, raddress, rport) + self.assertTrue( + isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK) + ) + self.assertEqual(serverstate.state.context.next_block, 2) + + # Receive duplicate ACK for block 1 after timeout for resending block 2 + serverstate.state.context.metrics.last_dat_time -= 10 # Simulate 10 seconds time warp + self.assertRaises( + tftpy.TftpTimeoutExpectACK, serverstate.state.handle, ack, raddress, rport + ) + self.assertTrue( + isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK) + ) + def testServerNoOptionsSubdir(self): raddress = "127.0.0.2" rport = 10000 diff --git a/tftpy/TftpContexts.py b/tftpy/TftpContexts.py index 0a82c8d..36cd7b3 100644 --- a/tftpy/TftpContexts.py +++ b/tftpy/TftpContexts.py @@ -43,6 +43,7 @@ class TftpMetrics: self.start_time = 0 self.end_time = 0 self.duration = 0 + self.last_dat_time = 0 # Rates self.bps = 0 self.kbps = 0 @@ -112,6 +113,9 @@ class TftpContext: self.last_pkt = None # Count the number of retry attempts. self.retry_count = 0 + # Flag to signal timeout error when waiting for ACK of the current block + # and at the same time receiving duplicate ACK of previous block + self.timeout_expectACK = False def getBlocksize(self): """Fetch the current blocksize for this session.""" @@ -127,6 +131,8 @@ class TftpContext: """Compare current time with last_update time, and raise an exception if we're over the timeout time.""" log.debug("checking for timeout on session %s", self) + if self.timeout_expectACK: + raise TftpTimeout("Timeout waiting for traffic") if now - self.last_update > self.timeout: raise TftpTimeout("Timeout waiting for traffic") diff --git a/tftpy/TftpServer.py b/tftpy/TftpServer.py index 56e7944..cb01d18 100644 --- a/tftpy/TftpServer.py +++ b/tftpy/TftpServer.py @@ -180,6 +180,8 @@ class TftpServer(TftpSession): ) try: self.sessions[key].start(buffer) + except TftpTimeoutExpectACK: + self.sessions[key].timeout_expectACK = True except TftpException as err: deletion_list.append(key) log.error( @@ -199,11 +201,14 @@ class TftpServer(TftpSession): for key in self.sessions: if readysock == self.sessions[key].sock: log.debug("Matched input to session key %s" % key) + self.sessions[key].timeout_expectACK = False try: self.sessions[key].cycle() if self.sessions[key].state is None: log.info("Successful transfer.") deletion_list.append(key) + except TftpTimeoutExpectACK: + self.sessions[key].timeout_expectACK = True except TftpException as err: deletion_list.append(key) log.error( diff --git a/tftpy/TftpShared.py b/tftpy/TftpShared.py index d464d4d..6c129b9 100644 --- a/tftpy/TftpShared.py +++ b/tftpy/TftpShared.py @@ -52,6 +52,13 @@ class TftpTimeout(TftpException): pass +class TftpTimeoutExpectACK(TftpTimeout): + """This class represents a timeout error when waiting for ACK of the current block + and receiving duplicate ACK for previous block from the other end.""" + + pass + + class TftpFileNotFoundError(TftpException): """This class represents an error condition where we received a file not found error.""" diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index 15d52f9..ca303c3 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -12,6 +12,7 @@ error, in which case a TftpException is returned instead.""" import logging import os +import time from .TftpPacketTypes import * from .TftpShared import * @@ -92,8 +93,6 @@ class TftpState: blocknumber = self.context.next_block # Test hook if DELAY_BLOCK and DELAY_BLOCK == blocknumber: - import time - log.debug("Deliberately delaying 10 seconds...") time.sleep(10) dat = None @@ -111,6 +110,7 @@ class TftpState: self.context.sock.sendto( dat.encode().buffer, (self.context.host, self.context.tidport) ) + self.context.metrics.last_dat_time = time.time() if self.context.packethook: self.context.packethook(dat) self.context.last_pkt = dat @@ -465,6 +465,9 @@ class TftpStateExpectACK(TftpState): elif pkt.blocknumber < self.context.next_block: log.warning("Received duplicate ACK for block %d" % pkt.blocknumber) self.context.metrics.add_dup(pkt) + if self.context.metrics.last_dat_time > 0: + if time.time() - self.context.metrics.last_dat_time > self.context.timeout: + raise TftpTimeoutExpectACK("Timeout waiting for ACK for block %d" % self.context.next_block) else: log.warning( |