summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael P. Soulier <msoulier@digitaltorque.ca>2022-04-18 10:24:25 -0400
committerMichael P. Soulier <msoulier@digitaltorque.ca>2022-04-18 10:24:25 -0400
commitf0588430217f44730eb25db95f341d548a816488 (patch)
tree648747946ffcbcb6b71780e6370a5502f1673560
parent57feb2c3416c59d6c0a2061d7bc2e56bc7c87ad5 (diff)
parenta6ba72c2e6d0ce49f904556e9c1d1b765afd9e3a (diff)
downloadtftpy-f0588430217f44730eb25db95f341d548a816488.tar.gz
Merged PR 133 - handling duplicate ACK
-rw-r--r--t/test.py48
-rw-r--r--tftpy/TftpContexts.py6
-rw-r--r--tftpy/TftpServer.py5
-rw-r--r--tftpy/TftpShared.py7
-rw-r--r--tftpy/TftpStates.py6
5 files changed, 53 insertions, 19 deletions
diff --git a/t/test.py b/t/test.py
index 6a2ef61..22502ab 100644
--- a/t/test.py
+++ b/t/test.py
@@ -316,9 +316,7 @@ class TestTftpyState(unittest.TestCase):
finalstate = serverstate.state.handle(ack, raddress, rport)
self.assertTrue(finalstate is None)
- def testServerNoOptionsUnreliable(self):
- log.debug("===> Running testcase testClientServerNoOptionsUnreliable")
- tftpy.TftpStates.NETWORK_UNRELIABILITY = 1000
+ def testServerTimeoutExpectACK(self):
raddress = "127.0.0.2"
rport = 10000
timeout = 5
@@ -337,23 +335,37 @@ class TestTftpyState(unittest.TestCase):
# Start the download.
serverstate.start(rrq.encode().buffer)
- # At a 512 byte blocksize, this should be 1280 packets exactly.
- for block in range(1, 1281):
- # Should be in expectack state.
- self.assertTrue(
- isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK)
- )
- ack = tftpy.TftpPacketTypes.TftpPacketACK()
- ack.blocknumber = block % 65536
- serverstate.state = serverstate.state.handle(ack, raddress, rport)
- # The last DAT packet should be empty, indicating a completed
- # transfer.
ack = tftpy.TftpPacketTypes.TftpPacketACK()
- ack.blocknumber = 1281 % 65536
- finalstate = serverstate.state.handle(ack, raddress, rport)
- self.assertTrue(finalstate is None)
- tftpy.TftpStates.NETWORK_UNRELIABILITY = 0
+ ack.blocknumber = 1
+
+ # Server expects ACK at the beginning of transmission
+ self.assertTrue(
+ isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK)
+ )
+
+ # Receive first ACK for block 1, next block expected is 2
+ serverstate.state = serverstate.state.handle(ack, raddress, rport)
+ self.assertTrue(
+ isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK)
+ )
+ self.assertEqual(serverstate.state.context.next_block, 2)
+
+ # Receive duplicate ACK for block 1, next block expected is still 2
+ serverstate.state = serverstate.state.handle(ack, raddress, rport)
+ self.assertTrue(
+ isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK)
+ )
+ self.assertEqual(serverstate.state.context.next_block, 2)
+
+ # Receive duplicate ACK for block 1 after timeout for resending block 2
+ serverstate.state.context.metrics.last_dat_time -= 10 # Simulate 10 seconds time warp
+ self.assertRaises(
+ tftpy.TftpTimeoutExpectACK, serverstate.state.handle, ack, raddress, rport
+ )
+ self.assertTrue(
+ isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK)
+ )
def testServerNoOptionsSubdir(self):
raddress = "127.0.0.2"
diff --git a/tftpy/TftpContexts.py b/tftpy/TftpContexts.py
index d4f127c..ab0f9f8 100644
--- a/tftpy/TftpContexts.py
+++ b/tftpy/TftpContexts.py
@@ -43,6 +43,7 @@ class TftpMetrics:
self.start_time = 0
self.end_time = 0
self.duration = 0
+ self.last_dat_time = 0
# Rates
self.bps = 0
self.kbps = 0
@@ -112,6 +113,9 @@ class TftpContext:
self.last_pkt = None
# Count the number of retry attempts.
self.retry_count = 0
+ # Flag to signal timeout error when waiting for ACK of the current block
+ # and at the same time receiving duplicate ACK of previous block
+ self.timeout_expectACK = False
def getBlocksize(self):
"""Fetch the current blocksize for this session."""
@@ -127,6 +131,8 @@ class TftpContext:
"""Compare current time with last_update time, and raise an exception
if we're over the timeout time."""
log.debug("checking for timeout on session %s", self)
+ if self.timeout_expectACK:
+ raise TftpTimeout("Timeout waiting for traffic")
if now - self.last_update > self.timeout:
raise TftpTimeout("Timeout waiting for traffic")
diff --git a/tftpy/TftpServer.py b/tftpy/TftpServer.py
index 56e7944..cb01d18 100644
--- a/tftpy/TftpServer.py
+++ b/tftpy/TftpServer.py
@@ -180,6 +180,8 @@ class TftpServer(TftpSession):
)
try:
self.sessions[key].start(buffer)
+ except TftpTimeoutExpectACK:
+ self.sessions[key].timeout_expectACK = True
except TftpException as err:
deletion_list.append(key)
log.error(
@@ -199,11 +201,14 @@ class TftpServer(TftpSession):
for key in self.sessions:
if readysock == self.sessions[key].sock:
log.debug("Matched input to session key %s" % key)
+ self.sessions[key].timeout_expectACK = False
try:
self.sessions[key].cycle()
if self.sessions[key].state is None:
log.info("Successful transfer.")
deletion_list.append(key)
+ except TftpTimeoutExpectACK:
+ self.sessions[key].timeout_expectACK = True
except TftpException as err:
deletion_list.append(key)
log.error(
diff --git a/tftpy/TftpShared.py b/tftpy/TftpShared.py
index e727c94..5603ba6 100644
--- a/tftpy/TftpShared.py
+++ b/tftpy/TftpShared.py
@@ -57,6 +57,13 @@ class TftpTimeout(TftpException):
pass
+class TftpTimeoutExpectACK(TftpTimeout):
+ """This class represents a timeout error when waiting for ACK of the current block
+ and receiving duplicate ACK for previous block from the other end."""
+
+ pass
+
+
class TftpFileNotFoundError(TftpException):
"""This class represents an error condition where we received a file
not found error."""
diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py
index 1f98682..f6b4b48 100644
--- a/tftpy/TftpStates.py
+++ b/tftpy/TftpStates.py
@@ -13,6 +13,7 @@ error, in which case a TftpException is returned instead."""
import logging
import os
import random
+import time
from .TftpPacketTypes import *
from .TftpShared import *
@@ -93,7 +94,6 @@ class TftpState:
blocknumber = self.context.next_block
# Test hook
if DELAY_BLOCK and DELAY_BLOCK == blocknumber:
- import time
log.debug("Deliberately delaying 10 seconds...")
time.sleep(10)
dat = None
@@ -115,6 +115,7 @@ class TftpState:
self.context.sock.sendto(
dat.encode().buffer, (self.context.host, self.context.tidport)
)
+ self.context.metrics.last_dat_time = time.time()
if self.context.packethook:
self.context.packethook(dat)
self.context.last_pkt = dat
@@ -474,6 +475,9 @@ class TftpStateExpectACK(TftpState):
elif pkt.blocknumber < self.context.next_block:
log.warning("Received duplicate ACK for block %d" % pkt.blocknumber)
self.context.metrics.add_dup(pkt)
+ if self.context.metrics.last_dat_time > 0:
+ if time.time() - self.context.metrics.last_dat_time > self.context.timeout:
+ raise TftpTimeoutExpectACK("Timeout waiting for ACK for block %d" % self.context.next_block)
else:
log.warning(