From 5072f6d93c6fe5ba4f215e2fe6d646594714ef50 Mon Sep 17 00:00:00 2001 From: "Michael P. Soulier" Date: Fri, 10 Apr 2009 20:54:20 -0400 Subject: Fixed TftpClient with new state machine. --- tftpy/TftpClient.py | 9 +++---- tftpy/TftpPacketFactory.py | 1 - tftpy/TftpStates.py | 60 +++++++++++++++++++++++++++------------------- 3 files changed, 39 insertions(+), 31 deletions(-) diff --git a/tftpy/TftpClient.py b/tftpy/TftpClient.py index 5947367..c5b0c1f 100644 --- a/tftpy/TftpClient.py +++ b/tftpy/TftpClient.py @@ -15,11 +15,8 @@ class TftpClient(TftpSession): self.iport = port self.filename = None self.options = options - self.blocknumber = 0 - self.fileobj = None - self.timesent = 0 - self.buffer = None - self.bytes = 0 + # FIXME: If the blksize is DEF_BLKSIZE, we should just skip sending + # it. if self.options.has_key('blksize'): size = self.options['blksize'] tftpassert(types.IntType == type(size), "blksize must be an int") @@ -74,7 +71,7 @@ class TftpClient(TftpSession): logger.info("Duration too short, rate undetermined") else: logger.info('') - logger.info("Downloaded %d bytes in %d seconds" % (metrics.bytes, metrics.duration)) + logger.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration)) logger.info("Average rate: %.2f kbps" % metrics.kbps) logger.info("Received %d duplicate packets" % metrics.dupcount) diff --git a/tftpy/TftpPacketFactory.py b/tftpy/TftpPacketFactory.py index 366d546..642b4d8 100644 --- a/tftpy/TftpPacketFactory.py +++ b/tftpy/TftpPacketFactory.py @@ -34,5 +34,4 @@ class TftpPacketFactory(object): packet = self.classes[opcode]() - logger.debug("packet is %s" % packet) return packet diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index 090ff4a..312e623 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -25,11 +25,13 @@ class TftpMetrics(object): def compute(self): # Compute transfer time - self.duration = int(self.end_time - self.start_time) - self.bps = (metrics.bytes * 8.0) / metrics.duration - self.kbps = bps / 1024.0 + self.duration = self.end_time - self.start_time + logger.debug("TftpMetrics.compute: duration is %s" % self.duration) + self.bps = (self.bytes * 8.0) / self.duration + self.kbps = self.bps / 1024.0 + logger.debug("TftpMetrics.compute: kbps is %s" % self.kbps) for key in self.dups: - dupcount += metrics.dups[key] + dupcount += self.dups[key] ############################################################################### # Context classes @@ -70,7 +72,7 @@ class TftpContext(object): """This method sends an ack packet to the block number specified.""" logger.info("sending ack to block %d" % blocknumber) ackpkt = TftpPacketACK() - ackpkt.blocknumber = 0 + ackpkt.blocknumber = blocknumber self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port)) def senderror(self, errorcode): @@ -89,7 +91,6 @@ class TftpContextClientDownload(TftpContext): """The download context for the client during a download.""" def __init__(self, host, port, filename, output, options, packethook, timeout): TftpContext.__init__(self, host, port) - # Open the output file. # FIXME - need to support alternate return formats than files? # File-like objects would be ideal, ala duck-typing. self.requested_file = filename @@ -102,6 +103,13 @@ class TftpContextClientDownload(TftpContext): self.state = None self.expected_block = 0 + ############################ + # Logging + ############################ + logger.debug("TftpContextClientDownload.__init__()") + logger.debug("requested_file = %s, options = %s" % + (self.requested_file, self.options)) + def setExpectedBlock(self, block): if block > 2 ** 16: logger.debug("block number rollover to 0 again") @@ -115,10 +123,11 @@ class TftpContextClientDownload(TftpContext): def start(self): """Initiate the download.""" - logger.info("Sending tftp download request to %s" % self.host) + logger.info("sending tftp download request to %s" % self.host) logger.info(" filename -> %s" % self.requested_file) self.metrics.start_time = time.time() + logger.debug("set metrics.start_time to %s" % self.metrics.start_time) # FIXME: put this in a sendRRQ method? pkt = TftpPacketRRQ() @@ -132,6 +141,7 @@ class TftpContextClientDownload(TftpContext): try: while self.state: + logger.debug("state is %s" % self.state) self.cycle() finally: self.fileobj.close() @@ -139,12 +149,14 @@ class TftpContextClientDownload(TftpContext): def end(self): """Finish up the context.""" self.metrics.end_time = time.time() + logger.debug("set metrics.end_time to %s" % self.metrics.end_time) self.metrics.compute() def cycle(self): """Here we wait for a response from the server after sending it something, and dispatch appropriate action to that response.""" for i in range(TIMEOUT_RETRIES): + logger.debug("in cycle, receive attempt %d" % i) try: (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE) except socket.timeout, err: @@ -154,13 +166,13 @@ class TftpContextClientDownload(TftpContext): else: raise TftpException, "Hit max timeouts, giving up." - # Ok, we've received a packet. Decode it. - recvpkt = self.factory.parse(buffer) - - # Log it. + # Ok, we've received a packet. Log it. logger.debug("Received %d bytes from %s:%s" % (len(buffer), raddress, rport)) + # Decode it. + recvpkt = self.factory.parse(buffer) + # Check for known "connection". if raddress != self.address: logger.warn("Received traffic from %s, expected host %s. Discarding" @@ -205,13 +217,13 @@ class TftpStateDownload(TftpState): def handleDat(self, pkt): """This method handles a DAT packet during a download.""" logger.info("handling DAT packet - block %d" % pkt.blocknumber) - logger.debug("expecting block %s" % self.expected_block) - if pkt.blocknumber == self.expected_block: + logger.debug("expecting block %s" % self.context.expected_block) + if pkt.blocknumber == self.context.expected_block: logger.debug("good, received block %d in sequence" % pkt.blocknumber) self.context.sendAck(pkt.blocknumber) - self.expected_block += 1 + self.context.expected_block += 1 logger.debug("writing %d bytes to output file" % len(pkt.data)) @@ -222,22 +234,22 @@ class TftpStateDownload(TftpState): logger.info("end of file detected") return None - elif pkt.blocknumber == curblock: + elif pkt.blocknumber < self.context.expected_block: logger.warn("dropping duplicate block %d" % pkt.blocknumber) - if self.context.metrics.dups.has_key(curblock): + if self.context.metrics.dups.has_key(pkt.blocknumber): self.context.metrics.dups[pkt.blocknumber] += 1 else: self.context.metrics.dups[pkt.blocknumber] = 1 - tftpassert(self.context.metrics.dups[curblock] < MAX_DUPS, - "Max duplicates for block %d reached" % curblock) + tftpassert(self.context.metrics.dups[pkt.blocknumber] < MAX_DUPS, + "Max duplicates for block %d reached" % pkt.blocknumber) # FIXME: double-check sorceror's apprentice problem! - logger.debug("ACKing block %d again, just in case" % curblock) + logger.debug("ACKing block %d again, just in case" % pkt.blocknumber) self.context.sendAck(pkt.blocknumber) else: # FIXME: should we be more tolerant and just discard instead? - msg = "Whoa! Received block %d but expected %d" % (pkt.blocknumber, - self.expected_block) + msg = "Whoa! Received future block %d but expected %d" \ + % (pkt.blocknumber, self.context.expected_block) logger.error(msg) raise TftpException, msg @@ -249,8 +261,8 @@ class TftpStateSentRRQ(TftpStateDownload): def handle(self, pkt, raddress, rport): """Handle the packet in response to an RRQ to the server.""" - if not self.tidport: - self.tidport = rport + if not self.context.tidport: + self.context.tidport = rport logger.debug("Set remote port for session to %s" % rport) # Now check the packet type and dispatch it properly. @@ -298,7 +310,7 @@ class TftpStateSentRRQ(TftpStateDownload): # By default, no state change. return self -class TftpStateSentACK(TftpState): +class TftpStateSentACK(TftpStateDownload): """Just sent an ACK packet. Waiting for DAT.""" def handle(self, pkt, raddress, rport): """Handle the packet in response to an ACK, which should be a DAT.""" -- cgit v1.2.1