summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--t/test.py51
-rw-r--r--tftpy/TftpContexts.py6
-rw-r--r--tftpy/TftpServer.py5
-rw-r--r--tftpy/TftpShared.py7
-rw-r--r--tftpy/TftpStates.py7
5 files changed, 74 insertions, 2 deletions
diff --git a/t/test.py b/t/test.py
index 6ef27f7..22502ab 100644
--- a/t/test.py
+++ b/t/test.py
@@ -316,6 +316,57 @@ class TestTftpyState(unittest.TestCase):
finalstate = serverstate.state.handle(ack, raddress, rport)
self.assertTrue(finalstate is None)
+ def testServerTimeoutExpectACK(self):
+ raddress = "127.0.0.2"
+ rport = 10000
+ timeout = 5
+ root = os.path.dirname(os.path.abspath(__file__))
+ # Testing without the dyn_func_file set.
+ serverstate = tftpy.TftpContexts.TftpContextServer(
+ raddress, rport, timeout, root
+ )
+
+ self.assertTrue(isinstance(serverstate, tftpy.TftpContexts.TftpContextServer))
+
+ rrq = tftpy.TftpPacketTypes.TftpPacketRRQ()
+ rrq.filename = "640KBFILE"
+ rrq.mode = "octet"
+ rrq.options = {}
+
+ # Start the download.
+ serverstate.start(rrq.encode().buffer)
+
+ ack = tftpy.TftpPacketTypes.TftpPacketACK()
+ ack.blocknumber = 1
+
+ # Server expects ACK at the beginning of transmission
+ self.assertTrue(
+ isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK)
+ )
+
+ # Receive first ACK for block 1, next block expected is 2
+ serverstate.state = serverstate.state.handle(ack, raddress, rport)
+ self.assertTrue(
+ isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK)
+ )
+ self.assertEqual(serverstate.state.context.next_block, 2)
+
+ # Receive duplicate ACK for block 1, next block expected is still 2
+ serverstate.state = serverstate.state.handle(ack, raddress, rport)
+ self.assertTrue(
+ isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK)
+ )
+ self.assertEqual(serverstate.state.context.next_block, 2)
+
+ # Receive duplicate ACK for block 1 after timeout for resending block 2
+ serverstate.state.context.metrics.last_dat_time -= 10 # Simulate 10 seconds time warp
+ self.assertRaises(
+ tftpy.TftpTimeoutExpectACK, serverstate.state.handle, ack, raddress, rport
+ )
+ self.assertTrue(
+ isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK)
+ )
+
def testServerNoOptionsSubdir(self):
raddress = "127.0.0.2"
rport = 10000
diff --git a/tftpy/TftpContexts.py b/tftpy/TftpContexts.py
index 0a82c8d..36cd7b3 100644
--- a/tftpy/TftpContexts.py
+++ b/tftpy/TftpContexts.py
@@ -43,6 +43,7 @@ class TftpMetrics:
self.start_time = 0
self.end_time = 0
self.duration = 0
+ self.last_dat_time = 0
# Rates
self.bps = 0
self.kbps = 0
@@ -112,6 +113,9 @@ class TftpContext:
self.last_pkt = None
# Count the number of retry attempts.
self.retry_count = 0
+ # Flag to signal timeout error when waiting for ACK of the current block
+ # and at the same time receiving duplicate ACK of previous block
+ self.timeout_expectACK = False
def getBlocksize(self):
"""Fetch the current blocksize for this session."""
@@ -127,6 +131,8 @@ class TftpContext:
"""Compare current time with last_update time, and raise an exception
if we're over the timeout time."""
log.debug("checking for timeout on session %s", self)
+ if self.timeout_expectACK:
+ raise TftpTimeout("Timeout waiting for traffic")
if now - self.last_update > self.timeout:
raise TftpTimeout("Timeout waiting for traffic")
diff --git a/tftpy/TftpServer.py b/tftpy/TftpServer.py
index 56e7944..cb01d18 100644
--- a/tftpy/TftpServer.py
+++ b/tftpy/TftpServer.py
@@ -180,6 +180,8 @@ class TftpServer(TftpSession):
)
try:
self.sessions[key].start(buffer)
+ except TftpTimeoutExpectACK:
+ self.sessions[key].timeout_expectACK = True
except TftpException as err:
deletion_list.append(key)
log.error(
@@ -199,11 +201,14 @@ class TftpServer(TftpSession):
for key in self.sessions:
if readysock == self.sessions[key].sock:
log.debug("Matched input to session key %s" % key)
+ self.sessions[key].timeout_expectACK = False
try:
self.sessions[key].cycle()
if self.sessions[key].state is None:
log.info("Successful transfer.")
deletion_list.append(key)
+ except TftpTimeoutExpectACK:
+ self.sessions[key].timeout_expectACK = True
except TftpException as err:
deletion_list.append(key)
log.error(
diff --git a/tftpy/TftpShared.py b/tftpy/TftpShared.py
index d464d4d..6c129b9 100644
--- a/tftpy/TftpShared.py
+++ b/tftpy/TftpShared.py
@@ -52,6 +52,13 @@ class TftpTimeout(TftpException):
pass
+class TftpTimeoutExpectACK(TftpTimeout):
+ """This class represents a timeout error when waiting for ACK of the current block
+ and receiving duplicate ACK for previous block from the other end."""
+
+ pass
+
+
class TftpFileNotFoundError(TftpException):
"""This class represents an error condition where we received a file
not found error."""
diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py
index 15d52f9..ca303c3 100644
--- a/tftpy/TftpStates.py
+++ b/tftpy/TftpStates.py
@@ -12,6 +12,7 @@ error, in which case a TftpException is returned instead."""
import logging
import os
+import time
from .TftpPacketTypes import *
from .TftpShared import *
@@ -92,8 +93,6 @@ class TftpState:
blocknumber = self.context.next_block
# Test hook
if DELAY_BLOCK and DELAY_BLOCK == blocknumber:
- import time
-
log.debug("Deliberately delaying 10 seconds...")
time.sleep(10)
dat = None
@@ -111,6 +110,7 @@ class TftpState:
self.context.sock.sendto(
dat.encode().buffer, (self.context.host, self.context.tidport)
)
+ self.context.metrics.last_dat_time = time.time()
if self.context.packethook:
self.context.packethook(dat)
self.context.last_pkt = dat
@@ -465,6 +465,9 @@ class TftpStateExpectACK(TftpState):
elif pkt.blocknumber < self.context.next_block:
log.warning("Received duplicate ACK for block %d" % pkt.blocknumber)
self.context.metrics.add_dup(pkt)
+ if self.context.metrics.last_dat_time > 0:
+ if time.time() - self.context.metrics.last_dat_time > self.context.timeout:
+ raise TftpTimeoutExpectACK("Timeout waiting for ACK for block %d" % self.context.next_block)
else:
log.warning(