summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael P. Soulier <msoulier@digitaltorque.ca>2011-07-24 17:37:16 -0400
committerMichael P. Soulier <msoulier@digitaltorque.ca>2011-07-24 17:37:16 -0400
commit04aaa2ef9ff6a09d39d67a1ee42b359e244afd24 (patch)
tree6b2ec2ded5910e1964293b82be20cc1c35b5275f
parent40977c6f74496be16087767b8444af2b34f933d5 (diff)
downloadtftpy-04aaa2ef9ff6a09d39d67a1ee42b359e244afd24.tar.gz
Fixing issue #3, expanding unit tests.
-rw-r--r--t/test.py67
-rw-r--r--tftpy/TftpClient.py6
-rw-r--r--tftpy/TftpContexts.py16
-rw-r--r--tftpy/TftpShared.py3
-rw-r--r--tftpy/TftpStates.py55
5 files changed, 109 insertions, 38 deletions
diff --git a/t/test.py b/t/test.py
index 1052db1..d044ac8 100644
--- a/t/test.py
+++ b/t/test.py
@@ -4,6 +4,7 @@ import unittest
import logging
import tftpy
import os
+import time
log = tftpy.log
@@ -140,6 +141,72 @@ class TestTftpyState(unittest.TestCase):
def setUp(self):
tftpy.setLogLevel(logging.DEBUG)
+ def clientServerUploadOptions(self, options):
+ """Fire up a client and a server and do an upload."""
+ root = '/tmp'
+ home = os.path.dirname(os.path.abspath(__file__))
+ filename = '100KBFILE'
+ input_path = os.path.join(home, filename)
+ server = tftpy.TftpServer(root)
+ client = tftpy.TftpClient('localhost',
+ 20001,
+ options)
+ # Fork a server and run the client in this process.
+ child_pid = os.fork()
+ if child_pid:
+ # parent - let the server start
+ try:
+ time.sleep(1)
+ client.upload(filename,
+ input_path)
+ finally:
+ os.kill(child_pid, 15)
+ os.waitpid(child_pid, 0)
+
+ else:
+ server.listen('localhost', 20001)
+
+ def clientServerDownloadOptions(self, options):
+ """Fire up a client and a server and do a download."""
+ root = os.path.dirname(os.path.abspath(__file__))
+ server = tftpy.TftpServer(root)
+ client = tftpy.TftpClient('localhost',
+ 20001,
+ options)
+ # Fork a server and run the client in this process.
+ child_pid = os.fork()
+ if child_pid:
+ # parent - let the server start
+ try:
+ time.sleep(1)
+ client.download('100KBFILE',
+ '/tmp/out')
+ finally:
+ os.kill(child_pid, 15)
+ os.waitpid(child_pid, 0)
+
+ else:
+ server.listen('localhost', 20001)
+
+ def testClientServerNoOptions(self):
+ self.clientServerDownloadOptions({})
+
+ def testClientServerBlksize(self):
+ for blksize in [512, 1024, 2048, 4096]:
+ self.clientServerDownloadOptions({'blksize': blksize})
+
+ def testClientServerUploadNoOptions(self):
+ self.clientServerUploadOptions({})
+
+ def testClientServerUploadOptions(self):
+ for blksize in [512, 1024, 2048, 4096]:
+ self.clientServerUploadOptions({'blksize': blksize})
+
+ def testClientServerNoOptionsDelay(self):
+ tftpy.TftpStates.DELAY_BLOCK = 10
+ self.clientServerDownloadOptions({})
+ tftpy.TftpStates.DELAY_BLOCK = 0
+
def testServerNoOptions(self):
"""Test the server states."""
raddress = '127.0.0.2'
diff --git a/tftpy/TftpClient.py b/tftpy/TftpClient.py
index 16753f1..e9e46e7 100644
--- a/tftpy/TftpClient.py
+++ b/tftpy/TftpClient.py
@@ -74,10 +74,10 @@ class TftpClient(TftpSession):
setting, which is the amount of time that the client will wait for a
DAT packet to be ACKd by the server.
+ The input option is the full path to the file to upload, which can
+ optionally be '-' to read from stdin.
+
Note: If output is a hyphen then stdout is used."""
- # Open the input file.
- # FIXME: As of the state machine, this is now broken. Need to
- # implement with new state machine.
self.context = TftpContextClientUpload(self.host,
self.iport,
filename,
diff --git a/tftpy/TftpContexts.py b/tftpy/TftpContexts.py
index a76b686..c3a1bd4 100644
--- a/tftpy/TftpContexts.py
+++ b/tftpy/TftpContexts.py
@@ -50,15 +50,15 @@ class TftpMetrics(object):
for key in self.dups:
self.dupcount += self.dups[key]
- def add_dup(self, blocknumber):
- """This method adds a dup for a block number to the metrics."""
- log.debug("Recording a dup for block %d" % blocknumber)
- if self.dups.has_key(blocknumber):
- self.dups[blocknumber] += 1
+ def add_dup(self, pkt):
+ """This method adds a dup for a packet to the metrics."""
+ log.debug("Recording a dup of %s" % pkt)
+ s = str(pkt)
+ if self.dups.has_key(s):
+ self.dups[s] += 1
else:
- self.dups[blocknumber] = 1
- tftpassert(self.dups[blocknumber] < MAX_DUPS,
- "Max duplicates for block %d reached" % blocknumber)
+ self.dups[s] = 1
+ tftpassert(self.dups[s] < MAX_DUPS, "Max duplicates reached")
###############################################################################
# Context classes
diff --git a/tftpy/TftpShared.py b/tftpy/TftpShared.py
index 1039ed2..d09d8bd 100644
--- a/tftpy/TftpShared.py
+++ b/tftpy/TftpShared.py
@@ -11,6 +11,9 @@ MAX_DUPS = 20
TIMEOUT_RETRIES = 5
DEF_TFTP_PORT = 69
+# A hook for deliberately introducing delay in testing.
+DELAY_BLOCK = 0
+
# Initialize the logger.
logging.basicConfig()
# The logger used by this library. Feel free to clobber it with your own, if you like, as
diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py
index 716220a..c106220 100644
--- a/tftpy/TftpStates.py
+++ b/tftpy/TftpStates.py
@@ -124,30 +124,29 @@ class TftpState(object):
return sendoack
- def sendDAT(self, resend=False):
+ def sendDAT(self):
"""This method sends the next DAT packet based on the data in the
context. It returns a boolean indicating whether the transfer is
finished."""
finished = False
blocknumber = self.context.next_block
+ # Test hook
+ if DELAY_BLOCK and DELAY_BLOCK == blocknumber:
+ import time
+ log.debug("Deliberately delaying 10 seconds...")
+ time.sleep(10)
tftpassert( blocknumber > 0, "There is no block zero!" )
dat = None
- if resend:
- log.warn("Resending block number %d" % blocknumber)
- dat = self.context.last_pkt
- self.context.metrics.resent_bytes += len(dat.data)
- self.context.metrics.add_dup(dat)
- else:
- blksize = self.context.getBlocksize()
- buffer = self.context.fileobj.read(blksize)
- log.debug("Read %d bytes into buffer" % len(buffer))
- if len(buffer) < blksize:
- log.info("Reached EOF on file %s"
- % self.context.file_to_transfer)
- finished = True
- dat = TftpPacketDAT()
- dat.data = buffer
- dat.blocknumber = blocknumber
+ blksize = self.context.getBlocksize()
+ buffer = self.context.fileobj.read(blksize)
+ log.debug("Read %d bytes into buffer" % len(buffer))
+ if len(buffer) < blksize:
+ log.info("Reached EOF on file %s"
+ % self.context.file_to_transfer)
+ finished = True
+ dat = TftpPacketDAT()
+ dat.data = buffer
+ dat.blocknumber = blocknumber
self.context.metrics.bytes += len(dat.data)
log.debug("Sending DAT packet %d" % dat.blocknumber)
self.context.sock.sendto(dat.encode().buffer,
@@ -170,7 +169,7 @@ class TftpState(object):
self.context.sock.sendto(ackpkt.encode().buffer,
(self.context.host,
self.context.tidport))
- self.last_pkt = ackpkt
+ self.context.last_pkt = ackpkt
def sendError(self, errorcode):
"""This method uses the socket passed, and uses the errorcode to
@@ -181,7 +180,7 @@ class TftpState(object):
self.context.sock.sendto(errpkt.encode().buffer,
(self.context.host,
self.context.tidport))
- self.last_pkt = errpkt
+ self.context.last_pkt = errpkt
def sendOACK(self):
"""This method sends an OACK packet with the options from the current
@@ -192,18 +191,18 @@ class TftpState(object):
self.context.sock.sendto(pkt.encode().buffer,
(self.context.host,
self.context.tidport))
- self.last_pkt = pkt
+ self.context.last_pkt = pkt
def resendLast(self):
"Resend the last sent packet due to a timeout."
log.warn("Resending packet %s on sessions %s"
- % (self.last_pkt, self))
- self.context.metrics.resent_bytes += len(self.last_pkt.data)
- self.context.metrics.add_dup(self.last_pkt)
- self.context.sock.sendto(self.last_pkt.encode().buffer,
+ % (self.context.last_pkt, self))
+ self.context.metrics.resent_bytes += len(self.context.last_pkt.buffer)
+ self.context.metrics.add_dup(self.context.last_pkt)
+ self.context.sock.sendto(self.context.last_pkt.encode().buffer,
(self.context.host, self.context.tidport))
if self.context.packethook:
- self.context.packethook(self.last_pkt)
+ self.context.packethook(self.context.last_pkt)
def handleDat(self, pkt):
"""This method handles a DAT packet during a client download, or a
@@ -232,7 +231,7 @@ class TftpState(object):
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "There is no block zero!"
log.warn("Dropping duplicate block %d" % pkt.blocknumber)
- self.context.metrics.add_dup(pkt.blocknumber)
+ self.context.metrics.add_dup(pkt)
log.debug("ACKing block %d again, just in case" % pkt.blocknumber)
self.sendACK(pkt.blocknumber)
@@ -369,7 +368,9 @@ class TftpStateExpectACK(TftpState):
self.context.pending_complete = self.sendDAT()
elif pkt.blocknumber < self.context.next_block:
- self.context.metrics.add_dup(pkt.blocknumber)
+ log.debug("Received duplicate ACK for block %d"
+ % pkt.blocknumber)
+ self.context.metrics.add_dup(pkt)
else:
log.warn("Oooh, time warp. Received ACK to packet we "