summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael P. Soulier <msoulier@digitaltorque.ca>2011-07-23 19:40:53 -0400
committerMichael P. Soulier <msoulier@digitaltorque.ca>2011-07-23 19:40:53 -0400
commit1e74abf010088abd4bab27de74778e41393911dd (patch)
tree65c05d85e87f92991b9bbf767cc4292628a4c80a
parent6fd9391ad86fe58cf73dabce452d5d14c0d9ac32 (diff)
downloadtftpy-1e74abf010088abd4bab27de74778e41393911dd.tar.gz
Adding retries on timeouts, still have to exhaustively test.
Should close issue #21 on github.
-rw-r--r--tftpy/TftpServer.py12
-rw-r--r--tftpy/TftpShared.py5
-rw-r--r--tftpy/TftpStates.py82
3 files changed, 74 insertions, 25 deletions
diff --git a/tftpy/TftpServer.py b/tftpy/TftpServer.py
index 9e64d83..46c662b 100644
--- a/tftpy/TftpServer.py
+++ b/tftpy/TftpServer.py
@@ -149,9 +149,17 @@ class TftpServer(TftpSession):
for key in self.sessions:
try:
self.sessions[key].checkTimeout(now)
- except TftpException, err:
+ except TftpTimeout, err:
log.error(str(err))
- deletion_list.append(key)
+ self.sessions[key].retry_count += 1
+ if self.sessions[key].retry_count >= TIMEOUT_RETRIES:
+ log.debug("hit max retries on %s, giving up"
+ % self.sessions[key])
+ deletion_list.append(key)
+ else:
+ log.debug("resending on session %s"
+ % self.sessions[key])
+ self.sessions[key].state.resendLast()
log.debug("Iterating deletion list.")
for key in deletion_list:
diff --git a/tftpy/TftpShared.py b/tftpy/TftpShared.py
index 69ade90..1039ed2 100644
--- a/tftpy/TftpShared.py
+++ b/tftpy/TftpShared.py
@@ -49,3 +49,8 @@ class TftpException(Exception):
"""This class is the parent class of all exceptions regarding the handling
of the TFTP protocol."""
pass
+
+class TftpTimeout(TftpException):
+ """This class represents a timeout error waiting for a response from the
+ other end."""
+ pass
diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py
index 6c77499..0992d6c 100644
--- a/tftpy/TftpStates.py
+++ b/tftpy/TftpStates.py
@@ -86,14 +86,16 @@ class TftpContext(object):
self.tidport = None
# Metrics
self.metrics = TftpMetrics()
- # Flag when the transfer is pending completion.
+ # Fluag when the transfer is pending completion.
self.pending_complete = False
# Time when this context last received any traffic.
# FIXME: does this belong in metrics?
self.last_update = 0
- # The last DAT packet we sent, if applicable, to make resending easy.
- self.last_dat_pkt = None
+ # The last packet we sent, if applicable, to make resending easy.
+ self.last_pkt = None
self.dyn_file_func = dyn_file_func
+ # Count the number of retry attempts.
+ self.retry_count = 0
def __del__(self):
"""Simple destructor to try to call housekeeping in the end method if
@@ -104,8 +106,9 @@ class TftpContext(object):
def checkTimeout(self, now):
"""Compare current time with last_update time, and raise an exception
if we're over SOCK_TIMEOUT time."""
+ log.debug("checking for timeout on session %s" % self)
if now - self.last_update > SOCK_TIMEOUT:
- raise TftpException, "Timeout waiting for traffic"
+ raise TftpTimeout, "Timeout waiting for traffic"
def start(self):
raise NotImplementedError, "Abstract method"
@@ -145,19 +148,11 @@ class TftpContext(object):
def cycle(self):
"""Here we wait for a response from the server after sending it
something, and dispatch appropriate action to that response."""
- # FIXME: This won't work very well in a server context with multiple
- # sessions running.
- for i in range(TIMEOUT_RETRIES):
- log.debug("In cycle, receive attempt %d" % i)
- try:
- (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
- except socket.timeout, err:
- log.warn("Timeout waiting for traffic, retrying...")
- continue
- break
- else:
- self.sock.close()
- raise TftpException, "Hit max timeouts, giving up."
+ try:
+ (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
+ except socket.timeout, err:
+ log.warn("Timeout waiting for traffic, retrying...")
+ raise TftpTimeout, "Timed-out waiting for traffic"
# Ok, we've received a packet. Log it.
log.debug("Received %d bytes from %s:%s"
@@ -188,6 +183,9 @@ class TftpContext(object):
# And handle it, possibly changing state.
self.state = self.state.handle(recvpkt, raddress, rport)
+ # If we didn't throw any exceptions here, reset the retry_count to
+ # zero.
+ self.retry_count = 0
class TftpContextServer(TftpContext):
"""The context for the server."""
@@ -279,12 +277,25 @@ class TftpContextClientUpload(TftpContext):
pkt.options = self.options
self.sock.sendto(pkt.encode().buffer, (self.host, self.port))
self.next_block = 1
+ self.last_pkt = pkt
+ # FIXME: should we centralize sendto operations so we can refactor all
+ # saving of the packet to the last_pkt field?
self.state = TftpStateSentWRQ(self)
while self.state:
- log.debug("State is %s" % self.state)
- self.cycle()
+ try:
+ log.debug("State is %s" % self.state)
+ self.cycle()
+ except TftpTimeout, err:
+ log.error(str(err))
+ self.retry_count += 1
+ if self.retry_count >= TIMEOUT_RETRIES:
+ log.debug("hit max retries, giving up")
+ raise
+ else:
+ log.warn("resending last packet")
+ self.state.resendLast()
def end(self):
"""Finish up the context."""
@@ -343,12 +354,23 @@ class TftpContextClientDownload(TftpContext):
pkt.options = self.options
self.sock.sendto(pkt.encode().buffer, (self.host, self.port))
self.next_block = 1
+ self.last_pkt = pkt
self.state = TftpStateSentRRQ(self)
while self.state:
- log.debug("State is %s" % self.state)
- self.cycle()
+ try:
+ log.debug("State is %s" % self.state)
+ self.cycle()
+ except TftpTimeout, err:
+ log.error(str(err))
+ self.retry_count += 1
+ if self.retry_count >= TIMEOUT_RETRIES:
+ log.debug("hit max retries, giving up")
+ raise
+ else:
+ log.warn("resending last packet")
+ self.state.resendLast()
def end(self):
"""Finish up the context."""
@@ -479,7 +501,7 @@ class TftpState(object):
dat = None
if resend:
log.warn("Resending block number %d" % blocknumber)
- dat = self.context.last_dat_pkt
+ dat = self.context.last_pkt
self.context.metrics.resent_bytes += len(dat.data)
self.context.metrics.add_dup(dat)
else:
@@ -499,7 +521,7 @@ class TftpState(object):
(self.context.host, self.context.tidport))
if self.context.packethook:
self.context.packethook(dat)
- self.context.last_dat_pkt = dat
+ self.context.last_pkt = dat
return finished
def sendACK(self, blocknumber=None):
@@ -515,6 +537,7 @@ class TftpState(object):
self.context.sock.sendto(ackpkt.encode().buffer,
(self.context.host,
self.context.tidport))
+ self.last_pkt = ackpkt
def sendError(self, errorcode):
"""This method uses the socket passed, and uses the errorcode to
@@ -525,6 +548,7 @@ class TftpState(object):
self.context.sock.sendto(errpkt.encode().buffer,
(self.context.host,
self.context.tidport))
+ self.last_pkt = errpkt
def sendOACK(self):
"""This method sends an OACK packet with the options from the current
@@ -535,6 +559,18 @@ class TftpState(object):
self.context.sock.sendto(pkt.encode().buffer,
(self.context.host,
self.context.tidport))
+ self.last_pkt = pkt
+
+ def resendLast(self):
+ "Resend the last sent packet due to a timeout."
+ log.warn("Resending packet %s on sessions %s"
+ % (self.last_pkt, self))
+ self.context.metrics.resent_bytes += len(self.last_pkt.data)
+ self.context.metrics.add_dup(self.last_pkt)
+ self.context.sock.sendto(self.last_pkt.encode().buffer,
+ (self.context.host, self.context.tidport))
+ if self.context.packethook:
+ self.context.packethook(self.last_pkt)
def handleDat(self, pkt):
"""This method handles a DAT packet during a client download, or a