summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Lewandowski <marcin.lewandowski@intel.com>2022-03-22 13:28:14 +0100
committerMarcin Lewandowski <marcin.lewandowski@intel.com>2022-03-23 14:21:48 +0100
commita6ba72c2e6d0ce49f904556e9c1d1b765afd9e3a (patch)
tree5a9339f938aa1910dd2636d080e6733becb42106
parent85af4e453e647088fda50c3108dfd00e9753af3a (diff)
downloadtftpy-a6ba72c2e6d0ce49f904556e9c1d1b765afd9e3a.tar.gz
Fix race condition when waiting for ACK
TFTPy is designed in a way that socket timeout is used to calculate timeout when waiting for packet. During that time another unexpected packet may arrive. After that the socket operation is restarted and timeout is calculated from start. This might be a problem because both sides have timeout and these timeout may be different or one host may be significantly faster that another. In such situation the timeout will be never triggered as another host will always retransmit his packet faster. For most cases it does not matter because TFTP is always responding to packet sent and transmission may continue. The only one exception is no response to duplicate ACK. It is necessary to prevent Sorcerer's Apprentice Syndrome. This patch introduces additional exception TftpTimeoutExpectACK raised when reaching timeout during waiting on ACK of current block but receiving duplicate ACK of previous block.
-rw-r--r--t/test.py51
-rw-r--r--tftpy/TftpContexts.py6
-rw-r--r--tftpy/TftpServer.py5
-rw-r--r--tftpy/TftpShared.py7
-rw-r--r--tftpy/TftpStates.py7
5 files changed, 74 insertions, 2 deletions
diff --git a/t/test.py b/t/test.py
index 6ef27f7..22502ab 100644
--- a/t/test.py
+++ b/t/test.py
@@ -316,6 +316,57 @@ class TestTftpyState(unittest.TestCase):
finalstate = serverstate.state.handle(ack, raddress, rport)
self.assertTrue(finalstate is None)
+ def testServerTimeoutExpectACK(self):
+ raddress = "127.0.0.2"
+ rport = 10000
+ timeout = 5
+ root = os.path.dirname(os.path.abspath(__file__))
+ # Testing without the dyn_func_file set.
+ serverstate = tftpy.TftpContexts.TftpContextServer(
+ raddress, rport, timeout, root
+ )
+
+ self.assertTrue(isinstance(serverstate, tftpy.TftpContexts.TftpContextServer))
+
+ rrq = tftpy.TftpPacketTypes.TftpPacketRRQ()
+ rrq.filename = "640KBFILE"
+ rrq.mode = "octet"
+ rrq.options = {}
+
+ # Start the download.
+ serverstate.start(rrq.encode().buffer)
+
+ ack = tftpy.TftpPacketTypes.TftpPacketACK()
+ ack.blocknumber = 1
+
+ # Server expects ACK at the beginning of transmission
+ self.assertTrue(
+ isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK)
+ )
+
+ # Receive first ACK for block 1, next block expected is 2
+ serverstate.state = serverstate.state.handle(ack, raddress, rport)
+ self.assertTrue(
+ isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK)
+ )
+ self.assertEqual(serverstate.state.context.next_block, 2)
+
+ # Receive duplicate ACK for block 1, next block expected is still 2
+ serverstate.state = serverstate.state.handle(ack, raddress, rport)
+ self.assertTrue(
+ isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK)
+ )
+ self.assertEqual(serverstate.state.context.next_block, 2)
+
+ # Receive duplicate ACK for block 1 after timeout for resending block 2
+ serverstate.state.context.metrics.last_dat_time -= 10 # Simulate 10 seconds time warp
+ self.assertRaises(
+ tftpy.TftpTimeoutExpectACK, serverstate.state.handle, ack, raddress, rport
+ )
+ self.assertTrue(
+ isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK)
+ )
+
def testServerNoOptionsSubdir(self):
raddress = "127.0.0.2"
rport = 10000
diff --git a/tftpy/TftpContexts.py b/tftpy/TftpContexts.py
index 0a82c8d..36cd7b3 100644
--- a/tftpy/TftpContexts.py
+++ b/tftpy/TftpContexts.py
@@ -43,6 +43,7 @@ class TftpMetrics:
self.start_time = 0
self.end_time = 0
self.duration = 0
+ self.last_dat_time = 0
# Rates
self.bps = 0
self.kbps = 0
@@ -112,6 +113,9 @@ class TftpContext:
self.last_pkt = None
# Count the number of retry attempts.
self.retry_count = 0
+ # Flag to signal timeout error when waiting for ACK of the current block
+ # and at the same time receiving duplicate ACK of previous block
+ self.timeout_expectACK = False
def getBlocksize(self):
"""Fetch the current blocksize for this session."""
@@ -127,6 +131,8 @@ class TftpContext:
"""Compare current time with last_update time, and raise an exception
if we're over the timeout time."""
log.debug("checking for timeout on session %s", self)
+ if self.timeout_expectACK:
+ raise TftpTimeout("Timeout waiting for traffic")
if now - self.last_update > self.timeout:
raise TftpTimeout("Timeout waiting for traffic")
diff --git a/tftpy/TftpServer.py b/tftpy/TftpServer.py
index 56e7944..cb01d18 100644
--- a/tftpy/TftpServer.py
+++ b/tftpy/TftpServer.py
@@ -180,6 +180,8 @@ class TftpServer(TftpSession):
)
try:
self.sessions[key].start(buffer)
+ except TftpTimeoutExpectACK:
+ self.sessions[key].timeout_expectACK = True
except TftpException as err:
deletion_list.append(key)
log.error(
@@ -199,11 +201,14 @@ class TftpServer(TftpSession):
for key in self.sessions:
if readysock == self.sessions[key].sock:
log.debug("Matched input to session key %s" % key)
+ self.sessions[key].timeout_expectACK = False
try:
self.sessions[key].cycle()
if self.sessions[key].state is None:
log.info("Successful transfer.")
deletion_list.append(key)
+ except TftpTimeoutExpectACK:
+ self.sessions[key].timeout_expectACK = True
except TftpException as err:
deletion_list.append(key)
log.error(
diff --git a/tftpy/TftpShared.py b/tftpy/TftpShared.py
index d464d4d..6c129b9 100644
--- a/tftpy/TftpShared.py
+++ b/tftpy/TftpShared.py
@@ -52,6 +52,13 @@ class TftpTimeout(TftpException):
pass
+class TftpTimeoutExpectACK(TftpTimeout):
+ """This class represents a timeout error when waiting for ACK of the current block
+ and receiving duplicate ACK for previous block from the other end."""
+
+ pass
+
+
class TftpFileNotFoundError(TftpException):
"""This class represents an error condition where we received a file
not found error."""
diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py
index 15d52f9..ca303c3 100644
--- a/tftpy/TftpStates.py
+++ b/tftpy/TftpStates.py
@@ -12,6 +12,7 @@ error, in which case a TftpException is returned instead."""
import logging
import os
+import time
from .TftpPacketTypes import *
from .TftpShared import *
@@ -92,8 +93,6 @@ class TftpState:
blocknumber = self.context.next_block
# Test hook
if DELAY_BLOCK and DELAY_BLOCK == blocknumber:
- import time
-
log.debug("Deliberately delaying 10 seconds...")
time.sleep(10)
dat = None
@@ -111,6 +110,7 @@ class TftpState:
self.context.sock.sendto(
dat.encode().buffer, (self.context.host, self.context.tidport)
)
+ self.context.metrics.last_dat_time = time.time()
if self.context.packethook:
self.context.packethook(dat)
self.context.last_pkt = dat
@@ -465,6 +465,9 @@ class TftpStateExpectACK(TftpState):
elif pkt.blocknumber < self.context.next_block:
log.warning("Received duplicate ACK for block %d" % pkt.blocknumber)
self.context.metrics.add_dup(pkt)
+ if self.context.metrics.last_dat_time > 0:
+ if time.time() - self.context.metrics.last_dat_time > self.context.timeout:
+ raise TftpTimeoutExpectACK("Timeout waiting for ACK for block %d" % self.context.next_block)
else:
log.warning(