summaryrefslogtreecommitdiff
path: root/tftpy/TftpStates.py
diff options
context:
space:
mode:
Diffstat (limited to 'tftpy/TftpStates.py')
-rw-r--r--tftpy/TftpStates.py94
1 files changed, 47 insertions, 47 deletions
diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py
index 142029c..5362fa1 100644
--- a/tftpy/TftpStates.py
+++ b/tftpy/TftpStates.py
@@ -9,9 +9,9 @@ the next packet in the transfer, and returns a state object until the transfer
is complete, at which point it returns None. That is, unless there is a fatal
error, in which case a TftpException is returned instead."""
-from TftpShared import *
-from TftpPacketTypes import *
-from TftpPacketFactory import *
+from .TftpShared import *
+from .TftpPacketTypes import *
+from .TftpPacketFactory import *
import socket, time, os, sys
###############################################################################
@@ -53,7 +53,7 @@ class TftpMetrics(object):
def add_dup(self, blocknumber):
"""This method adds a dup for a block number to the metrics."""
log.debug("Recording a dup for block %d" % blocknumber)
- if self.dups.has_key(blocknumber):
+ if blocknumber in self.dups:
self.dups[blocknumber] += 1
else:
self.dups[blocknumber] = 1
@@ -67,7 +67,7 @@ class TftpMetrics(object):
class TftpContext(object):
"""The base class of the contexts."""
- def __init__(self, host, port, timeout, dyn_file_func=None):
+ def __init__(self, host, port, timeout, dyn_file_func=None, write_mode=TftpServerWriteMode.Overwrite):
"""Constructor for the base context, setting shared instance
variables."""
self.file_to_transfer = None
@@ -94,6 +94,7 @@ class TftpContext(object):
# The last DAT packet we sent, if applicable, to make resending easy.
self.last_dat_pkt = None
self.dyn_file_func = dyn_file_func
+ self.write_mode = write_mode
def __del__(self):
"""Simple destructor to try to call housekeeping in the end method if
@@ -105,17 +106,17 @@ class TftpContext(object):
"""Compare current time with last_update time, and raise an exception
if we're over SOCK_TIMEOUT time."""
if now - self.last_update > SOCK_TIMEOUT:
- raise TftpException, "Timeout waiting for traffic"
+ raise TftpException("Timeout waiting for traffic")
def start(self):
- raise NotImplementedError, "Abstract method"
+ raise NotImplementedError("Abstract method")
def end(self):
"""Perform session cleanup, since the end method should always be
called explicitely by the calling code, this works better than the
destructor."""
log.debug("in TftpContext.end")
- if self.fileobj is not None and not self.fileobj.closed:
+ if not self.fileobj.closed:
log.debug("self.fileobj is open - closing")
self.fileobj.close()
@@ -132,7 +133,7 @@ class TftpContext(object):
host = property(gethost, sethost)
def setNextBlock(self, block):
- if block >= 2 ** 16:
+ if block > 2 ** 16:
log.debug("Block number rollover to 0 again")
block = 0
self.__eblock = block
@@ -151,12 +152,12 @@ class TftpContext(object):
log.debug("In cycle, receive attempt %d" % i)
try:
(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
- except socket.timeout, err:
+ except socket.timeout as err:
log.warn("Timeout waiting for traffic, retrying...")
continue
break
else:
- raise TftpException, "Hit max timeouts, giving up."
+ raise TftpException("Hit max timeouts, giving up.")
# Ok, we've received a packet. Log it.
log.debug("Received %d bytes from %s:%s"
@@ -190,18 +191,20 @@ class TftpContext(object):
class TftpContextServer(TftpContext):
"""The context for the server."""
- def __init__(self, host, port, timeout, root, dyn_file_func=None):
+ def __init__(self, host, port, timeout, root, dyn_file_func=None, write_mode=TftpServerWriteMode.Overwrite):
TftpContext.__init__(self,
host,
port,
timeout,
- dyn_file_func
+ dyn_file_func,
+ write_mode
)
# At this point we have no idea if this is a download or an upload. We
# need to let the start state determine that.
self.state = TftpStateServerStart(self)
self.root = root
self.dyn_file_func = dyn_file_func
+ self.write_mode = write_mode
def __str__(self):
return "%s:%s %s" % (self.host, self.port, self.state)
@@ -372,12 +375,12 @@ class TftpState(object):
def handle(self, pkt, raddress, rport):
"""An abstract method for handling a packet. It is expected to return
a TftpState object, either itself or a new state."""
- raise NotImplementedError, "Abstract method"
+ raise NotImplementedError("Abstract method")
def handleOACK(self, pkt):
"""This method handles an OACK from the server, syncing any accepted
options."""
- if pkt.options.keys() > 0:
+ if len(list(pkt.options.keys())) > 0:
if pkt.match_options(self.context.options):
log.info("Successful negotiation of options")
# Set options to OACK options
@@ -386,9 +389,9 @@ class TftpState(object):
log.info(" %s = %s" % (key, self.context.options[key]))
else:
log.error("Failed to negotiate options")
- raise TftpException, "Failed to negotiate options"
+ raise TftpException("Failed to negotiate options")
else:
- raise TftpException, "No options found in OACK"
+ raise TftpException("No options found in OACK")
def returnSupportedOptions(self, options):
"""This method takes a requested options list from a client, and
@@ -441,8 +444,7 @@ class TftpState(object):
# FIXME - only octet mode is supported at this time.
if pkt.mode != 'octet':
self.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, \
- "Only octet transfers are supported at this time."
+ raise TftpException("Only octet transfers are supported at this time.")
# test host/port of client end
if self.context.host != raddress or self.context.port != rport:
@@ -462,7 +464,7 @@ class TftpState(object):
# FIXME: Should we allow subdirectories?
if pkt.filename.find(os.sep) >= 0:
self.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, "%s found in filename, not permitted" % os.sep
+ raise TftpException("%s found in filename, not permitted" % os.sep)
self.context.file_to_transfer = pkt.filename
@@ -560,7 +562,7 @@ class TftpState(object):
if pkt.blocknumber == 0:
log.warn("There is no block zero!")
self.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, "There is no block zero!"
+ raise TftpException("There is no block zero!")
log.warn("Dropping duplicate block %d" % pkt.blocknumber)
self.context.metrics.add_dup(pkt.blocknumber)
log.debug("ACKing block %d again, just in case" % pkt.blocknumber)
@@ -571,7 +573,7 @@ class TftpState(object):
msg = "Whoa! Received future block %d but expected %d" \
% (pkt.blocknumber, self.context.next_block)
log.error(msg)
- raise TftpException, msg
+ raise TftpException(msg)
# Default is to ack
return TftpStateExpectDAT(self.context)
@@ -593,15 +595,9 @@ class TftpStateServerRecvRRQ(TftpState):
log.debug("No such file %s but using dyn_file_func" % path)
self.context.fileobj = \
self.context.dyn_file_func(self.context.file_to_transfer)
-
- if self.context.fileobj is None:
- log.debug("dyn_file_func returned 'None', treating as "
- "FileNotFound")
- self.sendError(TftpErrors.FileNotFound)
- raise TftpException, "File not found: %s" % path
else:
self.sendError(TftpErrors.FileNotFound)
- raise TftpException, "File not found: %s" % path
+ raise TftpException("File not found: %s" % path)
# Options negotiation.
if sendoack:
@@ -629,6 +625,12 @@ class TftpStateServerRecvWRQ(TftpState):
log.debug("In TftpStateServerRecvWRQ.handle")
sendoack = self.serverInitial(pkt, raddress, rport)
path = self.context.root + os.sep + self.context.file_to_transfer
+ if (self.context.write_mode == TftpServerWriteMode.WriteNew) and os.path.exists(path):
+ self.sendError(TftpErrors.FileAlreadyExists)
+ raise TftpException("File already exists: %s" % path)
+ elif self.context.write_mode == TftpServerWriteMode.DenyWrite:
+ self.sendError(TftpErrors.DiskFull)
+ raise TftpException("Configured for read only operation: %s" % path)
log.info("Opening file %s for writing" % path)
if os.path.exists(path):
# FIXME: correct behavior?
@@ -674,8 +676,7 @@ class TftpStateServerStart(TftpState):
rport)
else:
self.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, \
- "Invalid packet to begin up/download: %s" % pkt
+ raise TftpException("Invalid packet to begin up/download: %s" % pkt)
class TftpStateExpectACK(TftpState):
"""This class represents the state of the transfer when a DAT was just
@@ -708,8 +709,7 @@ class TftpStateExpectACK(TftpState):
return self
elif isinstance(pkt, TftpPacketERR):
log.error("Received ERR packet from peer: %s" % str(pkt))
- raise TftpException, \
- "Received ERR packet from peer: %s" % str(pkt)
+ raise TftpException("Received ERR packet from peer: %s" % str(pkt))
else:
log.warn("Discarding unsupported packet: %s" % str(pkt))
return self
@@ -725,19 +725,19 @@ class TftpStateExpectDAT(TftpState):
elif isinstance(pkt, TftpPacketACK):
# Umm, we ACK, you don't.
self.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, "Received ACK from peer when expecting DAT"
+ raise TftpException("Received ACK from peer when expecting DAT")
elif isinstance(pkt, TftpPacketWRQ):
self.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, "Received WRQ from peer when expecting DAT"
+ raise TftpException("Received WRQ from peer when expecting DAT")
elif isinstance(pkt, TftpPacketERR):
self.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, "Received ERR from peer: " + str(pkt)
+ raise TftpException("Received ERR from peer: " + str(pkt))
else:
self.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, "Received unknown packet type from peer: " + str(pkt)
+ raise TftpException("Received unknown packet type from peer: " + str(pkt))
class TftpStateSentWRQ(TftpState):
"""Just sent an WRQ packet for an upload."""
@@ -753,7 +753,7 @@ class TftpStateSentWRQ(TftpState):
log.info("Received OACK from server")
try:
self.handleOACK(pkt)
- except TftpException, err:
+ except TftpException as err:
log.error("Failed to negotiate options")
self.sendError(TftpErrors.FailedNegotiation)
raise
@@ -780,19 +780,19 @@ class TftpStateSentWRQ(TftpState):
elif isinstance(pkt, TftpPacketERR):
self.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, "Received ERR from server: " + str(pkt)
+ raise TftpException("Received ERR from server: " + str(pkt))
elif isinstance(pkt, TftpPacketRRQ):
self.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, "Received RRQ from server while in upload"
+ raise TftpException("Received RRQ from server while in upload")
elif isinstance(pkt, TftpPacketDAT):
self.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, "Received DAT from server while in upload"
+ raise TftpException("Received DAT from server while in upload")
else:
self.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, "Received unknown packet type from server: " + str(pkt)
+ raise TftpException("Received unknown packet type from server: " + str(pkt))
# By default, no state change.
return self
@@ -810,7 +810,7 @@ class TftpStateSentRRQ(TftpState):
log.info("Received OACK from server")
try:
self.handleOACK(pkt)
- except TftpException, err:
+ except TftpException as err:
log.error("Failed to negotiate options: %s" % str(err))
self.sendError(TftpErrors.FailedNegotiation)
raise
@@ -835,19 +835,19 @@ class TftpStateSentRRQ(TftpState):
elif isinstance(pkt, TftpPacketACK):
# Umm, we ACK, the server doesn't.
self.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, "Received ACK from server while in download"
+ raise TftpException("Received ACK from server while in download")
elif isinstance(pkt, TftpPacketWRQ):
self.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, "Received WRQ from server while in download"
+ raise TftpException("Received WRQ from server while in download")
elif isinstance(pkt, TftpPacketERR):
self.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, "Received ERR from server: " + str(pkt)
+ raise TftpException("Received ERR from server: " + str(pkt))
else:
self.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, "Received unknown packet type from server: " + str(pkt)
+ raise TftpException("Received unknown packet type from server: " + str(pkt))
# By default, no state change.
return self