diff options
author | Paul Weaver <paul.weaver@osirium.com> | 2016-07-07 09:32:15 +0100 |
---|---|---|
committer | Paul Weaver <paul.weaver@osirium.com> | 2016-07-07 09:32:15 +0100 |
commit | c5a7b52ec818e6c7b8b02f7a88a5a6236e9e56a2 (patch) | |
tree | 191cda0a7a349faf6b73cd9cbab7c6bd6252d18d | |
parent | b938db8647f47fa4a47716a1db7a6c6e09739c7f (diff) | |
parent | a68924f2a66916b2d967b011d981259752556ae4 (diff) | |
download | tftpy-c5a7b52ec818e6c7b8b02f7a88a5a6236e9e56a2.tar.gz |
Merge branch 'master' into dynamic
To resolve conflict with latest changes on master
Conflicts:
tftpy/TftpServer.py
-rw-r--r-- | t/test.py | 20 | ||||
-rw-r--r-- | tftpy/TftpClient.py | 11 | ||||
-rw-r--r-- | tftpy/TftpContexts.py | 21 | ||||
-rw-r--r-- | tftpy/TftpPacketFactory.py | 7 | ||||
-rw-r--r-- | tftpy/TftpPacketTypes.py | 61 | ||||
-rw-r--r-- | tftpy/TftpServer.py | 125 | ||||
-rw-r--r-- | tftpy/TftpShared.py | 3 | ||||
-rw-r--r-- | tftpy/TftpStates.py | 70 | ||||
-rw-r--r-- | tftpy/__init__.py | 17 |
9 files changed, 159 insertions, 176 deletions
@@ -26,8 +26,8 @@ class TestTftpyClasses(unittest.TestCase): rrq.encode() self.assert_(rrq.buffer != None, "Buffer populated") rrq.decode() - self.assertEqual(rrq.filename, "myfilename", "Filename correct") - self.assertEqual(rrq.mode, "octet", "Mode correct") + self.assertEqual(rrq.filename, b"myfilename", "Filename correct") + self.assertEqual(rrq.mode, b"octet", "Mode correct") self.assertEqual(rrq.options, options, "Options correct") # repeat test with options rrq.options = { 'blksize': '1024' } @@ -36,8 +36,8 @@ class TestTftpyClasses(unittest.TestCase): rrq.encode() self.assert_(rrq.buffer != None, "Buffer populated") rrq.decode() - self.assertEqual(rrq.filename, "myfilename", "Filename correct") - self.assertEqual(rrq.mode, "octet", "Mode correct") + self.assertEqual(rrq.filename, b"myfilename", "Filename correct") + self.assertEqual(rrq.mode, b"octet", "Mode correct") self.assertEqual(rrq.options['blksize'], '1024', "Blksize correct") def testTftpPacketWRQ(self): @@ -51,8 +51,8 @@ class TestTftpyClasses(unittest.TestCase): self.assert_(wrq.buffer != None, "Buffer populated") wrq.decode() self.assertEqual(wrq.opcode, 2, "Opcode correct") - self.assertEqual(wrq.filename, "myfilename", "Filename correct") - self.assertEqual(wrq.mode, "octet", "Mode correct") + self.assertEqual(wrq.filename, b"myfilename", "Filename correct") + self.assertEqual(wrq.mode, b"octet", "Mode correct") self.assertEqual(wrq.options, options, "Options correct") # repeat test with options wrq.options = { 'blksize': '1024' } @@ -62,8 +62,8 @@ class TestTftpyClasses(unittest.TestCase): self.assert_(wrq.buffer != None, "Buffer populated") wrq.decode() self.assertEqual(wrq.opcode, 2, "Opcode correct") - self.assertEqual(wrq.filename, "myfilename", "Filename correct") - self.assertEqual(wrq.mode, "octet", "Mode correct") + self.assertEqual(wrq.filename, b"myfilename", "Filename correct") + self.assertEqual(wrq.mode, b"octet", "Mode correct") self.assertEqual(wrq.options['blksize'], '1024', "Blksize correct") @@ -406,7 +406,7 @@ class TestTftpyState(unittest.TestCase): try: server.listen('localhost', 20001) log.error("server didn't throw exception") - except Exception, err: + except Exception as err: log.error("server got unexpected exception %s" % err) # Wait until parent kills us while True: @@ -448,7 +448,7 @@ class TestTftpyState(unittest.TestCase): signal.alarm(2) try: server.listen('localhost', 20001) - except Exception, err: + except Exception as err: log.error("server threw exception %s" % err) # Wait until parent kills us while True: diff --git a/tftpy/TftpClient.py b/tftpy/TftpClient.py index 62f1dda..6244051 100644 --- a/tftpy/TftpClient.py +++ b/tftpy/TftpClient.py @@ -2,10 +2,11 @@ instance of the client, and then use its upload or download method. Logging is performed via a standard logging object set in TftpShared.""" +from __future__ import absolute_import, division, print_function, unicode_literals import types -from TftpShared import * -from TftpPacketTypes import * -from TftpContexts import TftpContextClientDownload, TftpContextClientUpload +from .TftpShared import * +from .TftpPacketTypes import * +from .TftpContexts import TftpContextClientDownload, TftpContextClientUpload class TftpClient(TftpSession): """This class is an implementation of a tftp client. Once instantiated, a @@ -20,11 +21,11 @@ class TftpClient(TftpSession): self.filename = None self.options = options self.localip = localip - if self.options.has_key('blksize'): + if 'blksize' in self.options: size = self.options['blksize'] tftpassert(types.IntType == type(size), "blksize must be an int") if size < MIN_BLKSIZE or size > MAX_BLKSIZE: - raise TftpException, "Invalid blksize: %d" % size + raise TftpException("Invalid blksize: %d" % size) def download(self, filename, output, packethook=None, timeout=SOCK_TIMEOUT): """This method initiates a tftp download from the configured remote diff --git a/tftpy/TftpContexts.py b/tftpy/TftpContexts.py index b14fc39..57b9673 100644 --- a/tftpy/TftpContexts.py +++ b/tftpy/TftpContexts.py @@ -8,10 +8,11 @@ the next packet in the transfer, and returns a state object until the transfer is complete, at which point it returns None. That is, unless there is a fatal error, in which case a TftpException is returned instead.""" -from TftpShared import * -from TftpPacketTypes import * -from TftpPacketFactory import TftpPacketFactory -from TftpStates import * +from __future__ import absolute_import, division, print_function, unicode_literals +from .TftpShared import * +from .TftpPacketTypes import * +from .TftpPacketFactory import TftpPacketFactory +from .TftpStates import * import socket, time, sys ############################################################################### @@ -54,7 +55,7 @@ class TftpMetrics(object): """This method adds a dup for a packet to the metrics.""" log.debug("Recording a dup of %s", pkt) s = str(pkt) - if self.dups.has_key(s): + if s in self.dups: self.dups[s] += 1 else: self.dups[s] = 1 @@ -114,10 +115,10 @@ class TftpContext(object): if we're over the timeout time.""" log.debug("checking for timeout on session %s", self) if now - self.last_update > self.timeout: - raise TftpTimeout, "Timeout waiting for traffic" + raise TftpTimeout("Timeout waiting for traffic") def start(self): - raise NotImplementedError, "Abstract method" + raise NotImplementedError("Abstract method") def end(self): """Perform session cleanup, since the end method should always be @@ -159,7 +160,7 @@ class TftpContext(object): (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE) except socket.timeout: log.warn("Timeout waiting for traffic, retrying...") - raise TftpTimeout, "Timed-out waiting for traffic" + raise TftpTimeout("Timed-out waiting for traffic") # Ok, we've received a packet. Log it. log.debug("Received %d bytes from %s:%s", @@ -307,7 +308,7 @@ class TftpContextClientUpload(TftpContext): try: log.debug("State is %s", self.state) self.cycle() - except TftpTimeout, err: + except TftpTimeout as err: log.error(str(err)) self.retry_count += 1 if self.retry_count >= TIMEOUT_RETRIES: @@ -386,7 +387,7 @@ class TftpContextClientDownload(TftpContext): try: log.debug("State is %s", self.state) self.cycle() - except TftpTimeout, err: + except TftpTimeout as err: log.error(str(err)) self.retry_count += 1 if self.retry_count >= TIMEOUT_RETRIES: diff --git a/tftpy/TftpPacketFactory.py b/tftpy/TftpPacketFactory.py index 154aec8..d0ed83e 100644 --- a/tftpy/TftpPacketFactory.py +++ b/tftpy/TftpPacketFactory.py @@ -2,8 +2,9 @@ buffer, and return the appropriate TftpPacket object to represent it, via the parse() method.""" -from TftpShared import * -from TftpPacketTypes import * +from __future__ import absolute_import, division, print_function, unicode_literals +from .TftpShared import * +from .TftpPacketTypes import * class TftpPacketFactory(object): """This class generates TftpPacket objects. It is responsible for parsing @@ -33,7 +34,7 @@ class TftpPacketFactory(object): def __create(self, opcode): """This method returns the appropriate class object corresponding to the passed opcode.""" - tftpassert(self.classes.has_key(opcode), + tftpassert(opcode in self.classes, "Unsupported opcode: %d" % opcode) packet = self.classes[opcode]() diff --git a/tftpy/TftpPacketTypes.py b/tftpy/TftpPacketTypes.py index 6a6bdeb..3cbb191 100644 --- a/tftpy/TftpPacketTypes.py +++ b/tftpy/TftpPacketTypes.py @@ -1,8 +1,10 @@ """This module implements the packet types of TFTP itself, and the corresponding encode and decode methods for them.""" +from __future__ import absolute_import, division, print_function, unicode_literals import struct -from TftpShared import * +import sys +from .TftpShared import * class TftpSession(object): """This class is the base class for the tftp client and server. Any shared @@ -63,7 +65,7 @@ class TftpPacketWithOptions(object): format += "%dsx" % length length = -1 else: - raise TftpException, "Invalid options in buffer" + raise TftpException("Invalid options in buffer") length += 1 log.debug("about to unpack, format is: %s", format) @@ -92,7 +94,7 @@ class TftpPacket(object): order suitable for sending over the wire. This is an abstract method.""" - raise NotImplementedError, "Abstract method" + raise NotImplementedError("Abstract method") def decode(self): """The decode method of a TftpPacket takes a buffer off of the wire in @@ -102,7 +104,7 @@ class TftpPacket(object): datagram. This is an abstract method.""" - raise NotImplementedError, "Abstract method" + raise NotImplementedError("Abstract method") class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): """This class is a common parent class for the RRQ and WRQ packets, as @@ -117,6 +119,9 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): """Encode the packet's buffer from the instance variables.""" tftpassert(self.filename, "filename required in initial packet") tftpassert(self.mode, "mode required in initial packet") + # Make sure filename and mode are bytestrings. + self.filename = self.filename.encode('ascii') + self.mode = self.mode.encode('ascii') ptype = None if self.opcode == 1: ptype = "RRQ" @@ -126,23 +131,23 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): for key in self.options: log.debug(" Option %s = %s", key, self.options[key]) - format = "!H" - format += "%dsx" % len(self.filename) - if self.mode == "octet": - format += "5sx" + format = b"!H" + format += b"%dsx" % len(self.filename) + if self.mode == b"octet": + format += b"5sx" else: - raise AssertionError, "Unsupported mode: %s" % mode + raise AssertionError("Unsupported mode: %s" % self.mode) # Add options. options_list = [] - if self.options.keys() > 0: + if len(self.options.keys()) > 0: log.debug("there are options to encode") for key in self.options: # Populate the option name - format += "%dsx" % len(key) - options_list.append(key) + format += b"%dsx" % len(key) + options_list.append(key.encode('ascii')) # Populate the option value - format += "%dsx" % len(str(self.options[key])) - options_list.append(str(self.options[key])) + format += b"%dsx" % len(self.options[key].encode('ascii')) + options_list.append(self.options[key].encode('ascii')) log.debug("format is %s", format) log.debug("options_list is %s", options_list) @@ -167,7 +172,9 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): log.debug("in decode: about to iterate buffer counting nulls") subbuf = self.buffer[2:] for c in subbuf: - if ord(c) == 0: + if sys.version_info[0] <= 2: + c = ord(c) + if c == 0: nulls += 1 log.debug("found a null at length %d, now have %d", length, nulls) format += "%dsx" % length @@ -345,14 +352,14 @@ class TftpPacketERR(TftpPacket): self.errmsg = None # FIXME - integrate in TftpErrors references? self.errmsgs = { - 1: "File not found", - 2: "Access violation", - 3: "Disk full or allocation exceeded", - 4: "Illegal TFTP operation", - 5: "Unknown transfer ID", - 6: "File already exists", - 7: "No such user", - 8: "Failed to negotiate options" + 1: b"File not found", + 2: b"Access violation", + 3: b"Disk full or allocation exceeded", + 4: b"Illegal TFTP operation", + 5: b"Unknown transfer ID", + 6: b"File already exists", + 7: b"No such user", + 8: b"Failed to negotiate options" } def __str__(self): @@ -433,18 +440,18 @@ class TftpPacketOACK(TftpPacket, TftpPacketWithOptions): the options so that the session can update itself to the negotiated options.""" for name in self.options: - if options.has_key(name): + if name in options: if name == 'blksize': # We can accept anything between the min and max values. size = int(self.options[name]) if size >= MIN_BLKSIZE and size <= MAX_BLKSIZE: log.debug("negotiated blksize of %d bytes", size) else: - raise TftpException, "blksize %s option outside allowed range" % size + raise TftpException("blksize %s option outside allowed range" % size) elif name == 'tsize': size = int(self.options[name]) if size < 0: - raise TftpException, "Negative file sizes not supported" + raise TftpException("Negative file sizes not supported") else: - raise TftpException, "Unsupported option: %s" % name + raise TftpException("Unsupported option: %s" % name) return True diff --git a/tftpy/TftpServer.py b/tftpy/TftpServer.py index c35bd8f..0d3037c 100644 --- a/tftpy/TftpServer.py +++ b/tftpy/TftpServer.py @@ -3,14 +3,15 @@ instance of the server, and then run the listen() method to listen for client requests. Logging is performed via a standard logging object set in TftpShared.""" +from __future__ import absolute_import, division, print_function, unicode_literals import socket, os, time -import select, errno +import select import threading from errno import EINTR -from TftpShared import * -from TftpPacketTypes import * -from TftpPacketFactory import TftpPacketFactory -from TftpContexts import TftpContextServer +from .TftpShared import * +from .TftpPacketTypes import * +from .TftpPacketFactory import TftpPacketFactory +from .TftpContexts import TftpContextServer class TftpServer(TftpSession): """This class implements a tftp server object. Run the listen() method to @@ -42,7 +43,6 @@ class TftpServer(TftpSession): # A dict of sessions, where each session is keyed by a string like # ip:tid for the remote end. self.sessions = {} - self.session_keys = {} # A threading event to help threads synchronize with the server # is_running state. self.is_running = threading.Event() @@ -50,18 +50,6 @@ class TftpServer(TftpSession): self.shutdown_gracefully = False self.shutdown_immediately = False - # Poll structure for the listen loop - try: - self.poll = select.epoll() - self.poll_mask = select.EPOLLIN | select.EPOLLERR - self.poll_mask_in = select.EPOLLIN - self.poll_mask_err = select.EPOLLERR | select.EPOLLHUP - except: - self.poll = select.poll() - self.poll_mask = select.POLLIN | select.POLLERR - self.poll_mask_in = select.POLLIN - self.poll_mask_err = select.POLLERR | select.POLLHUP - for name in 'dyn_file_func', 'upload_open': attr = getattr(self, name) if attr and not callable(attr): @@ -70,19 +58,19 @@ class TftpServer(TftpSession): if os.path.exists(self.root): log.debug("tftproot %s does exist", self.root) if not os.path.isdir(self.root): - raise TftpException, "The tftproot must be a directory." + raise TftpException("The tftproot must be a directory.") else: log.debug("tftproot %s is a directory", self.root) if os.access(self.root, os.R_OK): log.debug("tftproot %s is readable", self.root) else: - raise TftpException, "The tftproot must be readable" + raise TftpException("The tftproot must be readable") if os.access(self.root, os.W_OK): log.debug("tftproot %s is writable", self.root) else: log.warning("The tftproot %s is not writable" % self.root) else: - raise TftpException, "The tftproot does not exist." + raise TftpException("The tftproot does not exist.") def listen(self, listenip="", @@ -103,26 +91,20 @@ class TftpServer(TftpSession): self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.sock.bind((listenip, listenport)) _, self.listenport = self.sock.getsockname() - except socket.error, err: + except socket.error as err: # Reraise it for now. raise self.is_running.set() - self.poll.register(self.sock.fileno(), self.poll_mask) - log.info("Starting receive loop...") while True: log.debug("shutdown_immediately is %s", self.shutdown_immediately) log.debug("shutdown_gracefully is %s", self.shutdown_gracefully) if self.shutdown_immediately: log.warn("Shutting down now. Session count: %d" % len(self.sessions)) - self.poll.unregister(self.sock.fileno()) self.sock.close() for key in self.sessions: - fd = self.sessions[key].sock.fileno() - self.poll.unregister(fd) - del self.session_keys[fd] self.sessions[key].end() self.sessions = [] break @@ -130,39 +112,34 @@ class TftpServer(TftpSession): elif self.shutdown_gracefully: if not self.sessions: log.warn("In graceful shutdown mode and all sessions complete.") - self.poll.unregister(self.sock.fileno()) self.sock.close() break + # Build the inputlist array of sockets to select() on. + inputlist = [] + inputlist.append(self.sock) + for key in self.sessions: + inputlist.append(self.sessions[key].sock) + # Block until some socket has input on it. + log.debug("Performing select on this inputlist: %s", inputlist) try: - log.debug("Performing poll with timeout %s", SOCK_TIMEOUT) - events = self.poll.poll(SOCK_TIMEOUT * 1000) - except select.error, (err, _): - if err != errno.EAGAIN and err != errno.EINTR: - log.error("poll failed with: %d", err) - self.shutdown_immediately = True - continue - except IOError, e: - if e.errno != errno.EINTR: + readyinput, readyoutput, readyspecial = \ + select.select(inputlist, [], [], SOCK_TIMEOUT) + except select.error as err: + if err[0] == EINTR: + # Interrupted system call + log.debug("Interrupted syscall, retrying") + continue + else: raise - events = [] deletion_list = [] - log.debug("Woke up with events: %s", events) # Handle the available data, if any. Maybe we timed-out. - for readysock, event in events: - if event & self.poll_mask_err: - log.error("poll received error or HUP: %d", err) - self.shutdown_immediately = True - continue - elif not (event & self.poll_mask_in): - log.warn("poll received bad event: %x", event) - continue - + for readysock in readyinput: # Is the traffic on the main server socket? ie. new session? - if readysock == self.sock.fileno(): + if readysock == self.sock: log.debug("Data ready on our main socket") buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE) @@ -176,7 +153,7 @@ class TftpServer(TftpSession): # which should safely work through NAT. key = "%s:%s" % (raddress, rport) - if not self.sessions.has_key(key): + if not key in self.sessions: log.debug("Creating new server context for " "session key = %s", key) self.sessions[key] = TftpContextServer(raddress, @@ -187,10 +164,7 @@ class TftpServer(TftpSession): self.upload_open) try: self.sessions[key].start(buffer) - fd = self.sessions[key].sock.fileno() - self.poll.register(fd, self.poll_mask) - self.session_keys[fd] = key - except TftpException, err: + except TftpException as err: deletion_list.append(key) log.error("Fatal exception thrown from " "session %s: %s" % (key, str(err))) @@ -203,23 +177,25 @@ class TftpServer(TftpSession): else: # Must find the owner of this traffic. - try: - key = self.session_keys[readysock] - log.info("Matched input to session key %s" % key) - try: - self.sessions[key].cycle() - if self.sessions[key].state == None: - log.info("Successful transfer.") + for key in self.sessions: + if readysock == self.sessions[key].sock: + log.debug("Matched input to session key %s" + % key) + try: + self.sessions[key].cycle() + if self.sessions[key].state == None: + log.info("Successful transfer.") + deletion_list.append(key) + except TftpException as err: deletion_list.append(key) - except TftpException, err: - deletion_list.append(key) - log.error("Fatal exception thrown from " - "session %s: %s" - % (key, str(err))) - # Break out of for loop since we found the correct - # session. - break - except KeyError: + log.error("Fatal exception thrown from " + "session %s: %s" + % (key, str(err))) + # Break out of for loop since we found the correct + # session. + break + + else: log.error("Can't find the owner for this packet. " "Discarding.") @@ -228,7 +204,7 @@ class TftpServer(TftpSession): for key in self.sessions: try: self.sessions[key].checkTimeout(now) - except TftpTimeout, err: + except TftpTimeout as err: log.error(str(err)) self.sessions[key].retry_count += 1 if self.sessions[key].retry_count >= TIMEOUT_RETRIES: @@ -243,11 +219,8 @@ class TftpServer(TftpSession): for key in deletion_list: log.info('') log.info("Session %s complete" % key) - if self.sessions.has_key(key): + if key in self.sessions: log.debug("Gathering up metrics from session before deleting") - fd = self.sessions[key].sock.fileno() - self.poll.unregister(fd) - del self.session_keys[fd] self.sessions[key].end() metrics = self.sessions[key].metrics if metrics.duration == 0: diff --git a/tftpy/TftpShared.py b/tftpy/TftpShared.py index 9c33346..6252ebd 100644 --- a/tftpy/TftpShared.py +++ b/tftpy/TftpShared.py @@ -1,5 +1,6 @@ """This module holds all objects shared by all other modules in tftpy.""" +from __future__ import absolute_import, division, print_function, unicode_literals import logging from logging.handlers import RotatingFileHandler @@ -55,7 +56,7 @@ def tftpassert(condition, msg): with the message passed. This just makes the code throughout cleaner by refactoring.""" if not condition: - raise TftpException, msg + raise TftpException(msg) def setLogLevel(level): """This function is a utility function for setting the internal log level. diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index 25e21d4..fe75f89 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -8,8 +8,9 @@ the next packet in the transfer, and returns a state object until the transfer is complete, at which point it returns None. That is, unless there is a fatal error, in which case a TftpException is returned instead.""" -from TftpShared import * -from TftpPacketTypes import * +from __future__ import absolute_import, division, print_function, unicode_literals +from .TftpShared import * +from .TftpPacketTypes import * import os ############################################################################### @@ -28,7 +29,7 @@ class TftpState(object): def handle(self, pkt, raddress, rport): """An abstract method for handling a packet. It is expected to return a TftpState object, either itself or a new state.""" - raise NotImplementedError, "Abstract method" + raise NotImplementedError("Abstract method") def handleOACK(self, pkt): """This method handles an OACK from the server, syncing any accepted @@ -42,9 +43,9 @@ class TftpState(object): log.info(" %s = %s" % (key, self.context.options[key])) else: log.error("Failed to negotiate options") - raise TftpException, "Failed to negotiate options" + raise TftpException("Failed to negotiate options") else: - raise TftpException, "No options found in OACK" + raise TftpException("No options found in OACK") def returnSupportedOptions(self, options): """This method takes a requested options list from a client, and @@ -181,7 +182,7 @@ class TftpState(object): if pkt.blocknumber == 0: log.warn("There is no block zero!") self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "There is no block zero!" + raise TftpException("There is no block zero!") log.warn("Dropping duplicate block %d" % pkt.blocknumber) self.context.metrics.add_dup(pkt) log.debug("ACKing block %d again, just in case", pkt.blocknumber) @@ -192,7 +193,7 @@ class TftpState(object): msg = "Whoa! Received future block %d but expected %d" \ % (pkt.blocknumber, self.context.next_block) log.error(msg) - raise TftpException, msg + raise TftpException(msg) # Default is to ack return TftpStateExpectDAT(self.context) @@ -230,9 +231,9 @@ class TftpServerState(TftpState): # FIXME - only octet mode is supported at this time. if pkt.mode != 'octet': - self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, \ - "Only octet transfers are supported at this time." + #self.sendError(TftpErrors.IllegalTftpOp) + #raise TftpException("Only octet transfers are supported at this time.") + log.warning("Received non-octet mode request. I'll reply with binary data.") # test host/port of client end if self.context.host != raddress or self.context.port != rport: @@ -260,11 +261,10 @@ class TftpServerState(TftpState): # begin with a '/' strip it off as otherwise os.path.join will # treat it as absolute (regardless of whether it is ntpath or # posixpath module - if pkt.filename.startswith(self.context.root): + if pkt.filename.startswith(self.context.root.encode()): full_path = pkt.filename else: - full_path = os.path.join( - self.context.root, pkt.filename.lstrip('/')) + full_path = os.path.join(self.context.root, pkt.filename.decode().lstrip('/')) # Use abspath to eliminate any remaining relative elements # (e.g. '..') and ensure that is still within the server's @@ -276,7 +276,7 @@ class TftpServerState(TftpState): else: log.warn("requested file is not within the server root - bad") self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "bad file path" + raise TftpException("bad file path") self.context.file_to_transfer = pkt.filename @@ -299,16 +299,16 @@ class TftpStateServerRecvRRQ(TftpServerState): elif self.context.dyn_file_func: log.debug("No such file %s but using dyn_file_func", path) self.context.fileobj = \ - self.context.dyn_file_func(self.context.file_to_transfer) + self.context.dyn_file_func(self.context.file_to_transfer, raddress=raddress, rport=rport) if self.context.fileobj is None: log.debug("dyn_file_func returned 'None', treating as " "FileNotFound") self.sendError(TftpErrors.FileNotFound) - raise TftpException, "File not found: %s" % path + raise TftpException("File not found: %s" % path) else: self.sendError(TftpErrors.FileNotFound) - raise TftpException, "File not found: %s" % path + raise TftpException("File not found: %s" % path) # Options negotiation. if sendoack and self.context.options.has_key('tsize'): @@ -357,7 +357,7 @@ class TftpStateServerRecvWRQ(TftpServerState): if os.path.isdir(current): log.debug("%s is already an existing directory", current) else: - os.mkdir(current, 0700) + os.mkdir(current) def handle(self, pkt, raddress, rport): "Handle an initial WRQ packet as a server." @@ -419,8 +419,7 @@ class TftpStateServerStart(TftpState): rport) else: self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, \ - "Invalid packet to begin up/download: %s" % pkt + raise TftpException("Invalid packet to begin up/download: %s" % pkt) class TftpStateExpectACK(TftpState): """This class represents the state of the transfer when a DAT was just @@ -430,7 +429,7 @@ class TftpStateExpectACK(TftpState): def handle(self, pkt, raddress, rport): "Handle a packet, hopefully an ACK since we just sent a DAT." if isinstance(pkt, TftpPacketACK): - log.info("Received ACK for packet %d" % pkt.blocknumber) + log.debug("Received ACK for packet %d" % pkt.blocknumber) # Is this an ack to the one we just sent? if self.context.next_block == pkt.blocknumber: if self.context.pending_complete: @@ -455,8 +454,7 @@ class TftpStateExpectACK(TftpState): return self elif isinstance(pkt, TftpPacketERR): log.error("Received ERR packet from peer: %s" % str(pkt)) - raise TftpException, \ - "Received ERR packet from peer: %s" % str(pkt) + raise TftpException("Received ERR packet from peer: %s" % str(pkt)) else: log.warn("Discarding unsupported packet: %s" % str(pkt)) return self @@ -472,19 +470,19 @@ class TftpStateExpectDAT(TftpState): elif isinstance(pkt, TftpPacketACK): # Umm, we ACK, you don't. self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received ACK from peer when expecting DAT" + raise TftpException("Received ACK from peer when expecting DAT") elif isinstance(pkt, TftpPacketWRQ): self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received WRQ from peer when expecting DAT" + raise TftpException("Received WRQ from peer when expecting DAT") elif isinstance(pkt, TftpPacketERR): self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received ERR from peer: " + str(pkt) + raise TftpException("Received ERR from peer: " + str(pkt)) else: self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received unknown packet type from peer: " + str(pkt) + raise TftpException("Received unknown packet type from peer: " + str(pkt)) class TftpStateSentWRQ(TftpState): """Just sent an WRQ packet for an upload.""" @@ -527,19 +525,19 @@ class TftpStateSentWRQ(TftpState): elif isinstance(pkt, TftpPacketERR): self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received ERR from server: " + str(pkt) + raise TftpException("Received ERR from server: " + str(pkt)) elif isinstance(pkt, TftpPacketRRQ): self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received RRQ from server while in upload" + raise TftpException("Received RRQ from server while in upload") elif isinstance(pkt, TftpPacketDAT): self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received DAT from server while in upload" + raise TftpException("Received DAT from server while in upload") else: self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received unknown packet type from server: " + str(pkt) + raise TftpException("Received unknown packet type from server: " + str(pkt)) # By default, no state change. return self @@ -557,7 +555,7 @@ class TftpStateSentRRQ(TftpState): log.info("Received OACK from server") try: self.handleOACK(pkt) - except TftpException, err: + except TftpException as err: log.error("Failed to negotiate options: %s" % str(err)) self.sendError(TftpErrors.FailedNegotiation) raise @@ -582,19 +580,19 @@ class TftpStateSentRRQ(TftpState): elif isinstance(pkt, TftpPacketACK): # Umm, we ACK, the server doesn't. self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received ACK from server while in download" + raise TftpException("Received ACK from server while in download") elif isinstance(pkt, TftpPacketWRQ): self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received WRQ from server while in download" + raise TftpException("Received WRQ from server while in download") elif isinstance(pkt, TftpPacketERR): self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received ERR from server: " + str(pkt) + raise TftpException("Received ERR from server: " + str(pkt)) else: self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received unknown packet type from server: " + str(pkt) + raise TftpException("Received unknown packet type from server: " + str(pkt)) # By default, no state change. return self diff --git a/tftpy/__init__.py b/tftpy/__init__.py index fba9a9f..33f988c 100644 --- a/tftpy/__init__.py +++ b/tftpy/__init__.py @@ -8,18 +8,19 @@ As a client of tftpy, this is the only module that you should need to import directly. The TftpClient and TftpServer classes can be reached through it. """ +from __future__ import absolute_import, division, print_function, unicode_literals import sys # Make sure that this is at least Python 2.3 required_version = (2, 3) if sys.version_info < required_version: - raise ImportError, "Requires at least Python 2.3" + raise ImportError("Requires at least Python 2.3") -from tftpy.TftpShared import * -from tftpy.TftpPacketTypes import * -from tftpy.TftpPacketFactory import * -from tftpy.TftpClient import * -from tftpy.TftpServer import * -from tftpy.TftpContexts import * -from tftpy.TftpStates import * +from .TftpShared import * +from .TftpPacketTypes import * +from .TftpPacketFactory import * +from .TftpClient import * +from .TftpServer import * +from .TftpContexts import * +from .TftpStates import * |