diff options
author | Michael P. Soulier <msoulier@digitaltorque.ca> | 2022-04-18 10:24:25 -0400 |
---|---|---|
committer | Michael P. Soulier <msoulier@digitaltorque.ca> | 2022-04-18 10:24:25 -0400 |
commit | f0588430217f44730eb25db95f341d548a816488 (patch) | |
tree | 648747946ffcbcb6b71780e6370a5502f1673560 | |
parent | 57feb2c3416c59d6c0a2061d7bc2e56bc7c87ad5 (diff) | |
parent | a6ba72c2e6d0ce49f904556e9c1d1b765afd9e3a (diff) | |
download | tftpy-f0588430217f44730eb25db95f341d548a816488.tar.gz |
Merged PR 133 - handling duplicate ACK
-rw-r--r-- | t/test.py | 48 | ||||
-rw-r--r-- | tftpy/TftpContexts.py | 6 | ||||
-rw-r--r-- | tftpy/TftpServer.py | 5 | ||||
-rw-r--r-- | tftpy/TftpShared.py | 7 | ||||
-rw-r--r-- | tftpy/TftpStates.py | 6 |
5 files changed, 53 insertions, 19 deletions
@@ -316,9 +316,7 @@ class TestTftpyState(unittest.TestCase): finalstate = serverstate.state.handle(ack, raddress, rport) self.assertTrue(finalstate is None) - def testServerNoOptionsUnreliable(self): - log.debug("===> Running testcase testClientServerNoOptionsUnreliable") - tftpy.TftpStates.NETWORK_UNRELIABILITY = 1000 + def testServerTimeoutExpectACK(self): raddress = "127.0.0.2" rport = 10000 timeout = 5 @@ -337,23 +335,37 @@ class TestTftpyState(unittest.TestCase): # Start the download. serverstate.start(rrq.encode().buffer) - # At a 512 byte blocksize, this should be 1280 packets exactly. - for block in range(1, 1281): - # Should be in expectack state. - self.assertTrue( - isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK) - ) - ack = tftpy.TftpPacketTypes.TftpPacketACK() - ack.blocknumber = block % 65536 - serverstate.state = serverstate.state.handle(ack, raddress, rport) - # The last DAT packet should be empty, indicating a completed - # transfer. ack = tftpy.TftpPacketTypes.TftpPacketACK() - ack.blocknumber = 1281 % 65536 - finalstate = serverstate.state.handle(ack, raddress, rport) - self.assertTrue(finalstate is None) - tftpy.TftpStates.NETWORK_UNRELIABILITY = 0 + ack.blocknumber = 1 + + # Server expects ACK at the beginning of transmission + self.assertTrue( + isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK) + ) + + # Receive first ACK for block 1, next block expected is 2 + serverstate.state = serverstate.state.handle(ack, raddress, rport) + self.assertTrue( + isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK) + ) + self.assertEqual(serverstate.state.context.next_block, 2) + + # Receive duplicate ACK for block 1, next block expected is still 2 + serverstate.state = serverstate.state.handle(ack, raddress, rport) + self.assertTrue( + isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK) + ) + self.assertEqual(serverstate.state.context.next_block, 2) + + # Receive duplicate ACK for block 1 after timeout for resending block 2 + serverstate.state.context.metrics.last_dat_time -= 10 # Simulate 10 seconds time warp + self.assertRaises( + tftpy.TftpTimeoutExpectACK, serverstate.state.handle, ack, raddress, rport + ) + self.assertTrue( + isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK) + ) def testServerNoOptionsSubdir(self): raddress = "127.0.0.2" diff --git a/tftpy/TftpContexts.py b/tftpy/TftpContexts.py index d4f127c..ab0f9f8 100644 --- a/tftpy/TftpContexts.py +++ b/tftpy/TftpContexts.py @@ -43,6 +43,7 @@ class TftpMetrics: self.start_time = 0 self.end_time = 0 self.duration = 0 + self.last_dat_time = 0 # Rates self.bps = 0 self.kbps = 0 @@ -112,6 +113,9 @@ class TftpContext: self.last_pkt = None # Count the number of retry attempts. self.retry_count = 0 + # Flag to signal timeout error when waiting for ACK of the current block + # and at the same time receiving duplicate ACK of previous block + self.timeout_expectACK = False def getBlocksize(self): """Fetch the current blocksize for this session.""" @@ -127,6 +131,8 @@ class TftpContext: """Compare current time with last_update time, and raise an exception if we're over the timeout time.""" log.debug("checking for timeout on session %s", self) + if self.timeout_expectACK: + raise TftpTimeout("Timeout waiting for traffic") if now - self.last_update > self.timeout: raise TftpTimeout("Timeout waiting for traffic") diff --git a/tftpy/TftpServer.py b/tftpy/TftpServer.py index 56e7944..cb01d18 100644 --- a/tftpy/TftpServer.py +++ b/tftpy/TftpServer.py @@ -180,6 +180,8 @@ class TftpServer(TftpSession): ) try: self.sessions[key].start(buffer) + except TftpTimeoutExpectACK: + self.sessions[key].timeout_expectACK = True except TftpException as err: deletion_list.append(key) log.error( @@ -199,11 +201,14 @@ class TftpServer(TftpSession): for key in self.sessions: if readysock == self.sessions[key].sock: log.debug("Matched input to session key %s" % key) + self.sessions[key].timeout_expectACK = False try: self.sessions[key].cycle() if self.sessions[key].state is None: log.info("Successful transfer.") deletion_list.append(key) + except TftpTimeoutExpectACK: + self.sessions[key].timeout_expectACK = True except TftpException as err: deletion_list.append(key) log.error( diff --git a/tftpy/TftpShared.py b/tftpy/TftpShared.py index e727c94..5603ba6 100644 --- a/tftpy/TftpShared.py +++ b/tftpy/TftpShared.py @@ -57,6 +57,13 @@ class TftpTimeout(TftpException): pass +class TftpTimeoutExpectACK(TftpTimeout): + """This class represents a timeout error when waiting for ACK of the current block + and receiving duplicate ACK for previous block from the other end.""" + + pass + + class TftpFileNotFoundError(TftpException): """This class represents an error condition where we received a file not found error.""" diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index 1f98682..f6b4b48 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -13,6 +13,7 @@ error, in which case a TftpException is returned instead.""" import logging import os import random +import time from .TftpPacketTypes import * from .TftpShared import * @@ -93,7 +94,6 @@ class TftpState: blocknumber = self.context.next_block # Test hook if DELAY_BLOCK and DELAY_BLOCK == blocknumber: - import time log.debug("Deliberately delaying 10 seconds...") time.sleep(10) dat = None @@ -115,6 +115,7 @@ class TftpState: self.context.sock.sendto( dat.encode().buffer, (self.context.host, self.context.tidport) ) + self.context.metrics.last_dat_time = time.time() if self.context.packethook: self.context.packethook(dat) self.context.last_pkt = dat @@ -474,6 +475,9 @@ class TftpStateExpectACK(TftpState): elif pkt.blocknumber < self.context.next_block: log.warning("Received duplicate ACK for block %d" % pkt.blocknumber) self.context.metrics.add_dup(pkt) + if self.context.metrics.last_dat_time > 0: + if time.time() - self.context.metrics.last_dat_time > self.context.timeout: + raise TftpTimeoutExpectACK("Timeout waiting for ACK for block %d" % self.context.next_block) else: log.warning( |