diff options
Diffstat (limited to 'tftpy/TftpStates.py')
-rw-r--r-- | tftpy/TftpStates.py | 352 |
1 files changed, 255 insertions, 97 deletions
diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index 398f137..88c4fa1 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -56,17 +56,17 @@ class TftpContext(object): def end(self): return NotImplementedError, "Abstract method" - + def gethost(self): "Simple getter method for use in a property." return self.__host - + def sethost(self, host): """Setter method that also sets the address property as a result of the host that is set.""" self.__host = host self.address = socket.gethostbyname(host) - + host = property(gethost, sethost) def sendAck(self, blocknumber): @@ -76,83 +76,37 @@ class TftpContext(object): ackpkt.blocknumber = blocknumber self.sock.sendto(ackpkt.encode().buffer, (self.host, self.tidport)) - def senderror(self, errorcode): + def sendError(self, errorcode): """This method uses the socket passed, and uses the errorcode to compose and send an error packet.""" - logger.debug("In senderror, being asked to send error %d" % errorcode) + logger.debug("In sendError, being asked to send error %d" % errorcode) errpkt = TftpPacketERR() errpkt.errorcode = errorcode - sock.sendto(errpkt.encode().buffer, (self.host, self.tidport)) + self.sock.sendto(errpkt.encode().buffer, (self.host, self.tidport)) -class TftpContextServerDownload(TftpContext): - """The download context for the server during a download.""" - pass - -class TftpContextClientDownload(TftpContext): - """The download context for the client during a download.""" - def __init__(self, host, port, filename, output, options, packethook, timeout): +class TftpContextClient(TftpContext): + """This class represents shared functionality by both the download and + upload client contexts.""" + def __init__(self, host, port, filename, options, packethook, timeout): TftpContext.__init__(self, host, port) - # FIXME - need to support alternate return formats than files? - # File-like objects would be ideal, ala duck-typing. - self.requested_file = filename - self.fileobj = open(output, "wb") + self.file_to_transfer = filename self.options = options self.packethook = packethook self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.sock.settimeout(timeout) - self.state = None - self.expected_block = 0 - - ############################ - # Logging - ############################ - logger.debug("TftpContextClientDownload.__init__()") - logger.debug("requested_file = %s, options = %s" % - (self.requested_file, self.options)) + self.next_block = 0 - def setExpectedBlock(self, block): + def setNextBlock(self, block): if block > 2 ** 16: logger.debug("block number rollover to 0 again") block = 0 self.__eblock = block - def getExpectedBlock(self): + def getNextBlock(self): return self.__eblock - expected_block = property(getExpectedBlock, setExpectedBlock) - - def start(self): - """Initiate the download.""" - logger.info("sending tftp download request to %s" % self.host) - logger.info(" filename -> %s" % self.requested_file) - logger.info(" options -> %s" % self.options) - - self.metrics.start_time = time.time() - logger.debug("set metrics.start_time to %s" % self.metrics.start_time) - - # FIXME: put this in a sendRRQ method? - pkt = TftpPacketRRQ() - pkt.filename = self.requested_file - pkt.mode = "octet" # FIXME - shouldn't hardcode this - pkt.options = self.options - self.sock.sendto(pkt.encode().buffer, (self.host, self.port)) - self.expected_block = 1 - - self.state = TftpStateSentRRQ(self) - - try: - while self.state: - logger.debug("state is %s" % self.state) - self.cycle() - finally: - self.fileobj.close() - - def end(self): - """Finish up the context.""" - self.metrics.end_time = time.time() - logger.debug("set metrics.end_time to %s" % self.metrics.end_time) - self.metrics.compute() + next_block = property(getNextBlock, setNextBlock) def cycle(self): """Here we wait for a response from the server after sending it @@ -169,7 +123,7 @@ class TftpContextClientDownload(TftpContext): raise TftpException, "Hit max timeouts, giving up." # Ok, we've received a packet. Log it. - logger.debug("Received %d bytes from %s:%s" + logger.debug("Received %d bytes from %s:%s" % (len(buffer), raddress, rport)) # Decode it. @@ -196,6 +150,101 @@ class TftpContextClientDownload(TftpContext): # And handle it, possibly changing state. self.state = self.state.handle(recvpkt, raddress, rport) +class TftpContextClientUpload(TftpContextClient): + """The upload context for the client during an upload.""" + def __init__(self, host, port, filename, input, options, packethook, timeout): + TftpContextClient.__init__(self, + host, + port, + filename, + options, + packethook, + timeout) + self.fileobj = open(input, "wb") + + logger.debug("TftpContextClientUpload.__init__()") + logger.debug("file_to_transfer = %s, options = %s" % + (self.file_to_transfer, self.options)) + + def start(self): + logger.info("sending tftp upload request to %s" % self.host) + logger.info(" filename -> %s" % self.file_to_transfer) + logger.info(" options -> %s" % self.options) + + self.metrics.start_time = time.time() + logger.debug("set metrics.start_time to %s" % self.metrics.start_time) + + # FIXME: put this in a sendWRQ method? + pkt = TftpPacketWRQ() + pkt.filename = self.file_to_transfer + pkt.mode = "octet" # FIXME - shouldn't hardcode this + pkt.options = self.options + self.sock.sendto(pkt.encode().buffer, (self.host, self.port)) + self.next_block = 1 + + self.state = TftpStateSentWRQ(self) + + try: + while self.state: + logger.debug("state is %s" % self.state) + self.cycle() + finally: + self.fileobj.close() + + def end(self): + pass + +class TftpContextClientDownload(TftpContextClient): + """The download context for the client during a download.""" + def __init__(self, host, port, filename, output, options, packethook, timeout): + TftpContextClient.__init__(self, + host, + port, + filename, + options, + packethook, + timeout) + # FIXME - need to support alternate return formats than files? + # File-like objects would be ideal, ala duck-typing. + self.fileobj = open(output, "wb") + + logger.debug("TftpContextClientDownload.__init__()") + logger.debug("file_to_transfer = %s, options = %s" % + (self.file_to_transfer, self.options)) + + def start(self): + """Initiate the download.""" + logger.info("sending tftp download request to %s" % self.host) + logger.info(" filename -> %s" % self.file_to_transfer) + logger.info(" options -> %s" % self.options) + + self.metrics.start_time = time.time() + logger.debug("set metrics.start_time to %s" % self.metrics.start_time) + + # FIXME: put this in a sendRRQ method? + pkt = TftpPacketRRQ() + pkt.filename = self.file_to_transfer + pkt.mode = "octet" # FIXME - shouldn't hardcode this + pkt.options = self.options + self.sock.sendto(pkt.encode().buffer, (self.host, self.port)) + self.next_block = 1 + + self.state = TftpStateSentRRQ(self) + + try: + while self.state: + logger.debug("state is %s" % self.state) + self.cycle() + finally: + self.fileobj.close() + + def end(self): + """Finish up the context.""" + self.metrics.end_time = time.time() + logger.debug("set metrics.end_time to %s" % self.metrics.end_time) + self.metrics.compute() + + ############################################################################### # State classes ############################################################################### @@ -214,20 +263,62 @@ class TftpState(object): a TftpState object, either itself or a new state.""" raise NotImplementedError, "Abstract method" + def handleOACK(self, pkt): + """This method handles an OACK from the server, syncing any accepted + options.""" + if pkt.options.keys() > 0: + if pkt.match_options(self.context.options): + logger.info("Successful negotiation of options") + # Set options to OACK options + self.context.options = pkt.options + for key in self.context.options: + logger.info(" %s = %s" % (key, self.context.options[key])) + else: + logger.error("failed to negotiate options") + raise TftpException, "Failed to negotiate options" + else: + raise TftpException, "No options found in OACK" + +class TftpStateUpload(TftpState): + """A class holding common code for upload states.""" + def sendDat(self, resend=False): + finished = False + blocknumber = self.context.next_block + if not resend: + blksize = int(self.context.options['blksize']) + buffer = self.context.fileobj.read(blksize) + logger.debug("Read %d bytes into buffer" % len(buffer)) + if len(buffer) < blksize: + logger.info("Reached EOF on file %s" % self.context.input) + finished = True + self.context.next_block += 1 + self.bytes += len(buffer) + else: + logger.warn("Resending block number %d" % blocknumber) + dat = TftpPacketDAT() + dat.data = buffer + dat.blocknumber = blocknumber + logger.debug("Sending DAT packet %d" % blocknumber) + self.context.sock.sendto(dat.encode().buffer, + (self.context.host, self.context.port)) + if self.context.packethook: + self.context.packethook(dat) + return finished + class TftpStateDownload(TftpState): """A class holding common code for download states.""" def handleDat(self, pkt): """This method handles a DAT packet during a download.""" logger.info("handling DAT packet - block %d" % pkt.blocknumber) - logger.debug("expecting block %s" % self.context.expected_block) - if pkt.blocknumber == self.context.expected_block: - logger.debug("good, received block %d in sequence" + logger.debug("expecting block %s" % self.context.next_block) + if pkt.blocknumber == self.context.next_block: + logger.debug("good, received block %d in sequence" % pkt.blocknumber) - + self.context.sendAck(pkt.blocknumber) - self.context.expected_block += 1 + self.context.next_block += 1 - logger.debug("writing %d bytes to output file" + logger.debug("writing %d bytes to output file" % len(pkt.data)) self.context.fileobj.write(pkt.data) self.context.metrics.bytes += len(pkt.data) @@ -236,7 +327,7 @@ class TftpStateDownload(TftpState): logger.info("end of file detected") return None - elif pkt.blocknumber < self.context.expected_block: + elif pkt.blocknumber < self.context.next_block: logger.warn("dropping duplicate block %d" % pkt.blocknumber) if self.context.metrics.dups.has_key(pkt.blocknumber): self.context.metrics.dups[pkt.blocknumber] += 1 @@ -251,16 +342,87 @@ class TftpStateDownload(TftpState): else: # FIXME: should we be more tolerant and just discard instead? msg = "Whoa! Received future block %d but expected %d" \ - % (pkt.blocknumber, self.context.expected_block) + % (pkt.blocknumber, self.context.next_block) logger.error(msg) raise TftpException, msg # Default is to ack return TftpStateSentACK(self.context) +class TftpStateSentWRQ(TftpStateUpload): + """Just sent an WRQ packet for an upload.""" + def handle(self, pkt, raddress, rport): + """Handle a packet we just received.""" + if not self.context.tidport: + self.context.tidport = rport + logger.debug("Set remote port for session to %s" % rport) + + # If we're going to successfully transfer the file, then we should see + # either an OACK for accepted options, or an ACK to ignore options. + if isinstance(pkt, TftpPacketOACK): + logger.info("received OACK from server") + try: + self.handleOACK(pkt) + except TftpException, err: + logger.error("failed to negotiate options") + self.context.sendError(TftpErrors.FailedNegotiation) + raise + else: + logger.debug("sending first DAT packet") + fin = self.context.sendDat() + if fin: + logger.info("Add done") + return None + else: + logger.debug("Changing state to TftpStateSentDAT") + return TftpStateSentDAT(self.context) + + elif isinstance(pkt, TftpPacketACK): + logger.info("received ACK from server") + logger.debug("apparently the server ignored our options") + # The block number should be zero. + if pkt.blocknumber == 0: + logger.debug("ack blocknumber is zero as expected") + logger.debug("sending first DAT packet") + fin = self.context.sendDat() + if fin: + logger.info("Add done") + return None + else: + logger.debug("Changing state to TftpStateSentDAT") + return TftpStateSentDAT(self.context) + else: + logger.warn("discarding ACK to block %s" % pkt.blocknumber) + logger.debug("still waiting for valid response from server") + return self + + elif isinstance(pkt, TftpPacketERR): + self.context.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received ERR from server: " + str(pkt) + + elif isinstance(pkt, TftpPacketRRQ): + self.context.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received RRQ from server while in upload" + + elif isinstance(pkt, TftpPacketDAT): + self.context.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received DAT from server while in upload" + + else: + self.context.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received unknown packet type from server: " + str(pkt) + + # By default, no state change. + return self + +class TftpStateSentDAT(TftpStateUpload): + """This class represents the state of the transfer when a DAT was just + sent, and we are waiting for an ACK from the server. This class is the + same one used by the client during the upload, and the server during the + download.""" + class TftpStateSentRRQ(TftpStateDownload): """Just sent an RRQ packet.""" - def handle(self, pkt, raddress, rport): """Handle the packet in response to an RRQ to the server.""" if not self.context.tidport: @@ -269,24 +431,20 @@ class TftpStateSentRRQ(TftpStateDownload): # Now check the packet type and dispatch it properly. if isinstance(pkt, TftpPacketOACK): - logger.info("received OACK from server.") - if pkt.options.keys() > 0: - if pkt.match_options(self.context.options): - logger.info("Successful negotiation of options") - # Set options to OACK options - self.context.options = pkt.options - for key in self.context.options: - logger.info(" %s = %s" % (key, self.context.options[key])) - logger.debug("sending ACK to OACK") - - self.context.sendAck(blocknumber=0) - - logger.debug("Changing state to TftpStateSentACK") - return TftpStateSentACK(self.context) - else: - logger.error("failed to negotiate options") - self.senderror(self.sock, TftpErrors.FailedNegotiation, self.host, self.port) - raise TftpException, "Failed to negotiate options" + logger.info("received OACK from server") + try: + self.handleOACK(pkt) + except TftpException, err: + logger.error("failed to negotiate options: %s" % str(err)) + self.context.sendError(TftpErrors.FailedNegotiation) + raise + else: + logger.debug("sending ACK to OACK") + + self.context.sendAck(blocknumber=0) + + logger.debug("Changing state to TftpStateSentACK") + return TftpStateSentACK(self.context) elif isinstance(pkt, TftpPacketDAT): # If there are any options set, then the server didn't honour any @@ -300,19 +458,19 @@ class TftpStateSentRRQ(TftpStateDownload): # Every other packet type is a problem. elif isinstance(recvpkt, TftpPacketACK): # Umm, we ACK, the server doesn't. - self.senderror(self.sock, TftpErrors.IllegalTftpOp) + self.context.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received ACK from server while in download" elif isinstance(recvpkt, TftpPacketWRQ): - self.senderror(self.sock, TftpErrors.IllegalTftpOp) + self.context.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received WRQ from server while in download" elif isinstance(recvpkt, TftpPacketERR): - self.senderror(self.sock, TftpErrors.IllegalTftpOp) + self.context.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received ERR from server: " + str(recvpkt) else: - self.senderror(self.sock, TftpErrors.IllegalTftpOp) + self.context.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received unknown packet type from server: " + str(recvpkt) # By default, no state change. @@ -328,17 +486,17 @@ class TftpStateSentACK(TftpStateDownload): # Every other packet type is a problem. elif isinstance(recvpkt, TftpPacketACK): # Umm, we ACK, the server doesn't. - self.senderror(self.sock, TftpErrors.IllegalTftpOp) + self.context.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received ACK from server while in download" elif isinstance(recvpkt, TftpPacketWRQ): - self.senderror(self.sock, TftpErrors.IllegalTftpOp) + self.context.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received WRQ from server while in download" elif isinstance(recvpkt, TftpPacketERR): - self.senderror(self.sock, TftpErrors.IllegalTftpOp) + self.context.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received ERR from server: " + str(recvpkt) else: - self.senderror(self.sock, TftpErrors.IllegalTftpOp) + self.context.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received unknown packet type from server: " + str(recvpkt) |