diff options
60 files changed, 6244 insertions, 6193 deletions
diff --git a/contrib/async-test/test-leaf.py b/contrib/async-test/test-leaf.py index 8b7c3e3f5..4ea4a9b8c 100755 --- a/contrib/async-test/test-leaf.py +++ b/contrib/async-test/test-leaf.py @@ -7,16 +7,17 @@ from thrift.protocol import TBinaryProtocol from thrift.server import THttpServer from aggr import Aggr + class AggrHandler(Aggr.Iface): - def __init__(self): - self.values = [] + def __init__(self): + self.values = [] - def addValue(self, value): - self.values.append(value) + def addValue(self, value): + self.values.append(value) - def getValues(self, ): - time.sleep(1) - return self.values + def getValues(self, ): + time.sleep(1) + return self.values processor = Aggr.Processor(AggrHandler()) pfactory = TBinaryProtocol.TBinaryProtocolFactory() diff --git a/contrib/fb303/py/fb303/FacebookBase.py b/contrib/fb303/py/fb303/FacebookBase.py index 685ff20f3..07db10cd3 100644 --- a/contrib/fb303/py/fb303/FacebookBase.py +++ b/contrib/fb303/py/fb303/FacebookBase.py @@ -24,59 +24,60 @@ import FacebookService import thrift.reflection.limited from ttypes import fb_status + class FacebookBase(FacebookService.Iface): - def __init__(self, name): - self.name = name - self.alive = int(time.time()) - self.counters = {} + def __init__(self, name): + self.name = name + self.alive = int(time.time()) + self.counters = {} - def getName(self, ): - return self.name + def getName(self, ): + return self.name - def getVersion(self, ): - return '' + def getVersion(self, ): + return '' - def getStatus(self, ): - return fb_status.ALIVE + def getStatus(self, ): + return fb_status.ALIVE - def getCounters(self): - return self.counters + def getCounters(self): + return self.counters - def resetCounter(self, key): - self.counters[key] = 0 + def resetCounter(self, key): + self.counters[key] = 0 - def getCounter(self, key): - if self.counters.has_key(key): - return self.counters[key] - return 0 + def getCounter(self, key): + if self.counters.has_key(key): + return self.counters[key] + return 0 - def incrementCounter(self, key): - self.counters[key] = self.getCounter(key) + 1 + def incrementCounter(self, key): + self.counters[key] = self.getCounter(key) + 1 - def setOption(self, key, value): - pass + def setOption(self, key, value): + pass - def getOption(self, key): - return "" + def getOption(self, key): + return "" - def getOptions(self): - return {} + def getOptions(self): + return {} - def getOptions(self): - return {} + def getOptions(self): + return {} - def aliveSince(self): - return self.alive + def aliveSince(self): + return self.alive - def getCpuProfile(self, duration): - return "" + def getCpuProfile(self, duration): + return "" - def getLimitedReflection(self): - return thrift.reflection.limited.Service() + def getLimitedReflection(self): + return thrift.reflection.limited.Service() - def reinitialize(self): - pass + def reinitialize(self): + pass - def shutdown(self): - pass + def shutdown(self): + pass diff --git a/contrib/fb303/py/fb303_scripts/fb303_simple_mgmt.py b/contrib/fb303/py/fb303_scripts/fb303_simple_mgmt.py index 4f8ce9933..4b1c25728 100644 --- a/contrib/fb303/py/fb303_scripts/fb303_simple_mgmt.py +++ b/contrib/fb303/py/fb303_scripts/fb303_simple_mgmt.py @@ -19,7 +19,8 @@ # under the License. # -import sys, os +import sys +import os from optparse import OptionParser from thrift.Thrift import * @@ -31,11 +32,12 @@ from thrift.protocol import TBinaryProtocol from fb303 import * from fb303.ttypes import * + def service_ctrl( - command, - port, - trans_factory = None, - prot_factory = None): + command, + port, + trans_factory=None, + prot_factory=None): """ service_ctrl is a generic function to execute standard fb303 functions @@ -66,19 +68,19 @@ def service_ctrl( return 3 # scalar commands - if command in ["version","alive","name"]: + if command in ["version", "alive", "name"]: try: - result = fb303_wrapper(command, port, trans_factory, prot_factory) + result = fb303_wrapper(command, port, trans_factory, prot_factory) print result return 0 except: - print "failed to get ",command + print "failed to get ", command return 3 # counters if command in ["counters"]: try: - counters = fb303_wrapper('counters', port, trans_factory, prot_factory) + counters = fb303_wrapper('counters', port, trans_factory, prot_factory) for counter in counters: print "%s: %d" % (counter, counters[counter]) return 0 @@ -86,11 +88,10 @@ def service_ctrl( print "failed to get counters" return 3 - # Only root should be able to run the following commands if os.getuid() == 0: # async commands - if command in ["stop","reload"] : + if command in ["stop", "reload"]: try: fb303_wrapper(command, port, trans_factory, prot_factory) return 0 @@ -98,23 +99,21 @@ def service_ctrl( print "failed to tell the service to ", command return 3 else: - if command in ["stop","reload"]: + if command in ["stop", "reload"]: print "root privileges are required to stop or reload the service." return 4 print "The following commands are available:" - for command in ["counters","name","version","alive","status"]: + for command in ["counters", "name", "version", "alive", "status"]: print "\t%s" % command print "The following commands are available for users with root privileges:" - for command in ["stop","reload"]: + for command in ["stop", "reload"]: print "\t%s" % command + return 0 - return 0; - - -def fb303_wrapper(command, port, trans_factory = None, prot_factory = None): +def fb303_wrapper(command, port, trans_factory=None, prot_factory=None): sock = TSocket.TSocket('localhost', port) # use input transport factory if provided @@ -179,11 +178,11 @@ def main(): # parse command line options parser = OptionParser() - commands=["stop","counters","status","reload","version","name","alive"] + commands = ["stop", "counters", "status", "reload", "version", "name", "alive"] parser.add_option("-c", "--command", dest="command", help="execute this API", choices=commands, default="status") - parser.add_option("-p","--port",dest="port",help="the service's port", + parser.add_option("-p", "--port", dest="port", help="the service's port", default=9082) (options, args) = parser.parse_args() diff --git a/contrib/fb303/py/setup.py b/contrib/fb303/py/setup.py index 6710c8f61..4321ce258 100644 --- a/contrib/fb303/py/setup.py +++ b/contrib/fb303/py/setup.py @@ -24,26 +24,25 @@ try: from setuptools import setup, Extension except: from distutils.core import setup, Extension, Command - -setup(name = 'thrift_fb303', - version = '1.0.0-dev', - description = 'Python bindings for the Apache Thrift FB303', - author = ['Thrift Developers'], - author_email = ['dev@thrift.apache.org'], - url = 'http://thrift.apache.org', - license = 'Apache License 2.0', - packages = [ - 'fb303', - 'fb303_scripts', - ], - classifiers = [ - 'Development Status :: 5 - Production/Stable', - 'Environment :: Console', - 'Intended Audience :: Developers', - 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Topic :: Software Development :: Libraries', - 'Topic :: System :: Networking' - ], -) +setup(name='thrift_fb303', + version='1.0.0-dev', + description='Python bindings for the Apache Thrift FB303', + author=['Thrift Developers'], + author_email=['dev@thrift.apache.org'], + url='http://thrift.apache.org', + license='Apache License 2.0', + packages=[ + 'fb303', + 'fb303_scripts', + ], + classifiers=[ + 'Development Status :: 5 - Production/Stable', + 'Environment :: Console', + 'Intended Audience :: Developers', + 'Programming Language :: Python', + 'Programming Language :: Python :: 2', + 'Topic :: Software Development :: Libraries', + 'Topic :: System :: Networking' + ], + ) diff --git a/contrib/parse_profiling.py b/contrib/parse_profiling.py index 3d46fb832..0be5f29ed 100755 --- a/contrib/parse_profiling.py +++ b/contrib/parse_profiling.py @@ -46,6 +46,8 @@ class AddressInfo(object): g_addrs_by_filename = {} + + def get_address(filename, address): """ Retrieve an AddressInfo object for the specified object file and address. @@ -103,12 +105,12 @@ def translate_file_addresses(filename, addresses, options): idx = file_and_line.rfind(':') if idx < 0: msg = 'expected file and line number from addr2line; got %r' % \ - (file_and_line,) + (file_and_line,) msg += '\nfile=%r, address=%r' % (filename, address.address) raise Exception(msg) address.sourceFile = file_and_line[:idx] - address.sourceLine = file_and_line[idx+1:] + address.sourceLine = file_and_line[idx + 1:] (remaining_out, cmd_err) = proc.communicate() retcode = proc.wait() @@ -180,7 +182,7 @@ def process_file(in_file, out_file, options): virt_call_regex = re.compile(r'^\s*T_VIRTUAL_CALL: (\d+) calls on (.*):$') gen_prot_regex = re.compile( - r'^\s*T_GENERIC_PROTOCOL: (\d+) calls to (.*) with a (.*):$') + r'^\s*T_GENERIC_PROTOCOL: (\d+) calls to (.*) with a (.*):$') bt_regex = re.compile(r'^\s*#(\d+)\s*(.*) \[(0x[0-9A-Za-z]+)\]$') # Parse all of the input, and store it as Entry objects @@ -209,7 +211,7 @@ def process_file(in_file, out_file, options): # "_Z" to the type name to make it look like an external name. type_name = '_Z' + type_name header = 'T_VIRTUAL_CALL: %d calls on "%s"' % \ - (num_calls, type_name) + (num_calls, type_name) if current_entry is not None: entries.append(current_entry) current_entry = Entry(header) @@ -224,7 +226,7 @@ def process_file(in_file, out_file, options): type_name1 = '_Z' + type_name1 type_name2 = '_Z' + type_name2 header = 'T_GENERIC_PROTOCOL: %d calls to "%s" with a "%s"' % \ - (num_calls, type_name1, type_name2) + (num_calls, type_name1, type_name2) if current_entry is not None: entries.append(current_entry) current_entry = Entry(header) diff --git a/contrib/zeromq/TZmqClient.py b/contrib/zeromq/TZmqClient.py index d56069733..1bd60a1e5 100644 --- a/contrib/zeromq/TZmqClient.py +++ b/contrib/zeromq/TZmqClient.py @@ -20,44 +20,45 @@ import zmq from cStringIO import StringIO from thrift.transport.TTransport import TTransportBase, CReadableTransport + class TZmqClient(TTransportBase, CReadableTransport): - def __init__(self, ctx, endpoint, sock_type): - self._sock = ctx.socket(sock_type) - self._endpoint = endpoint - self._wbuf = StringIO() - self._rbuf = StringIO() - - def open(self): - self._sock.connect(self._endpoint) - - def read(self, size): - ret = self._rbuf.read(size) - if len(ret) != 0: - return ret - self._read_message() - return self._rbuf.read(size) - - def _read_message(self): - msg = self._sock.recv() - self._rbuf = StringIO(msg) - - def write(self, buf): - self._wbuf.write(buf) - - def flush(self): - msg = self._wbuf.getvalue() - self._wbuf = StringIO() - self._sock.send(msg) - - # Implement the CReadableTransport interface. - @property - def cstringio_buf(self): - return self._rbuf - - # NOTE: This will probably not actually work. - def cstringio_refill(self, prefix, reqlen): - while len(prefix) < reqlen: - self.read_message() - prefix += self._rbuf.getvalue() - self._rbuf = StringIO(prefix) - return self._rbuf + def __init__(self, ctx, endpoint, sock_type): + self._sock = ctx.socket(sock_type) + self._endpoint = endpoint + self._wbuf = StringIO() + self._rbuf = StringIO() + + def open(self): + self._sock.connect(self._endpoint) + + def read(self, size): + ret = self._rbuf.read(size) + if len(ret) != 0: + return ret + self._read_message() + return self._rbuf.read(size) + + def _read_message(self): + msg = self._sock.recv() + self._rbuf = StringIO(msg) + + def write(self, buf): + self._wbuf.write(buf) + + def flush(self): + msg = self._wbuf.getvalue() + self._wbuf = StringIO() + self._sock.send(msg) + + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self._rbuf + + # NOTE: This will probably not actually work. + def cstringio_refill(self, prefix, reqlen): + while len(prefix) < reqlen: + self.read_message() + prefix += self._rbuf.getvalue() + self._rbuf = StringIO(prefix) + return self._rbuf diff --git a/contrib/zeromq/TZmqServer.py b/contrib/zeromq/TZmqServer.py index c83cc8d5d..15c1543ac 100644 --- a/contrib/zeromq/TZmqServer.py +++ b/contrib/zeromq/TZmqServer.py @@ -21,58 +21,59 @@ import zmq import thrift.server.TServer import thrift.transport.TTransport + class TZmqServer(thrift.server.TServer.TServer): - def __init__(self, processor, ctx, endpoint, sock_type): - thrift.server.TServer.TServer.__init__(self, processor, None) - self.zmq_type = sock_type - self.socket = ctx.socket(sock_type) - self.socket.bind(endpoint) + def __init__(self, processor, ctx, endpoint, sock_type): + thrift.server.TServer.TServer.__init__(self, processor, None) + self.zmq_type = sock_type + self.socket = ctx.socket(sock_type) + self.socket.bind(endpoint) - def serveOne(self): - msg = self.socket.recv() - itrans = thrift.transport.TTransport.TMemoryBuffer(msg) - otrans = thrift.transport.TTransport.TMemoryBuffer() - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = self.outputProtocolFactory.getProtocol(otrans) + def serveOne(self): + msg = self.socket.recv() + itrans = thrift.transport.TTransport.TMemoryBuffer(msg) + otrans = thrift.transport.TTransport.TMemoryBuffer() + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) - try: - self.processor.process(iprot, oprot) - except Exception: - logging.exception("Exception while processing request") - # Fall through and send back a response, even if empty or incomplete. + try: + self.processor.process(iprot, oprot) + except Exception: + logging.exception("Exception while processing request") + # Fall through and send back a response, even if empty or incomplete. - if self.zmq_type == zmq.REP: - msg = otrans.getvalue() - self.socket.send(msg) + if self.zmq_type == zmq.REP: + msg = otrans.getvalue() + self.socket.send(msg) - def serve(self): - while True: - self.serveOne() + def serve(self): + while True: + self.serveOne() class TZmqMultiServer(object): - def __init__(self): - self.servers = [] + def __init__(self): + self.servers = [] - def serveOne(self, timeout = -1): - self._serveActive(self._setupPoll(), timeout) + def serveOne(self, timeout=-1): + self._serveActive(self._setupPoll(), timeout) - def serveForever(self): - poll_info = self._setupPoll() - while True: - self._serveActive(poll_info, -1) + def serveForever(self): + poll_info = self._setupPoll() + while True: + self._serveActive(poll_info, -1) - def _setupPoll(self): - server_map = {} - poller = zmq.Poller() - for server in self.servers: - server_map[server.socket] = server - poller.register(server.socket, zmq.POLLIN) - return (server_map, poller) + def _setupPoll(self): + server_map = {} + poller = zmq.Poller() + for server in self.servers: + server_map[server.socket] = server + poller.register(server.socket, zmq.POLLIN) + return (server_map, poller) - def _serveActive(self, poll_info, timeout): - (server_map, poller) = poll_info - ready = dict(poller.poll()) - for sock, state in ready.items(): - assert (state & zmq.POLLIN) != 0 - server_map[sock].serveOne() + def _serveActive(self, poll_info, timeout): + (server_map, poller) = poll_info + ready = dict(poller.poll()) + for sock, state in ready.items(): + assert (state & zmq.POLLIN) != 0 + server_map[sock].serveOne() diff --git a/contrib/zeromq/test-client.py b/contrib/zeromq/test-client.py index 1886d9cab..753b132d8 100755 --- a/contrib/zeromq/test-client.py +++ b/contrib/zeromq/test-client.py @@ -9,28 +9,28 @@ import storage.Storage def main(args): - endpoint = "tcp://127.0.0.1:9090" - socktype = zmq.REQ - incr = 0 - if len(args) > 1: - incr = int(args[1]) - if incr: - socktype = zmq.DOWNSTREAM - endpoint = "tcp://127.0.0.1:9091" + endpoint = "tcp://127.0.0.1:9090" + socktype = zmq.REQ + incr = 0 + if len(args) > 1: + incr = int(args[1]) + if incr: + socktype = zmq.DOWNSTREAM + endpoint = "tcp://127.0.0.1:9091" - ctx = zmq.Context() - transport = TZmqClient.TZmqClient(ctx, endpoint, socktype) - protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocolAccelerated(transport) - client = storage.Storage.Client(protocol) - transport.open() + ctx = zmq.Context() + transport = TZmqClient.TZmqClient(ctx, endpoint, socktype) + protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocolAccelerated(transport) + client = storage.Storage.Client(protocol) + transport.open() - if incr: - client.incr(incr) - time.sleep(0.05) - else: - value = client.get() - print value + if incr: + client.incr(incr) + time.sleep(0.05) + else: + value = client.get() + print value if __name__ == "__main__": - main(sys.argv) + main(sys.argv) diff --git a/contrib/zeromq/test-server.py b/contrib/zeromq/test-server.py index 5767b71fe..c7804d317 100755 --- a/contrib/zeromq/test-server.py +++ b/contrib/zeromq/test-server.py @@ -6,28 +6,28 @@ import storage.Storage class StorageHandler(storage.Storage.Iface): - def __init__(self): - self.value = 0 + def __init__(self): + self.value = 0 - def incr(self, amount): - self.value += amount + def incr(self, amount): + self.value += amount - def get(self): - return self.value + def get(self): + return self.value def main(): - handler = StorageHandler() - processor = storage.Storage.Processor(handler) + handler = StorageHandler() + processor = storage.Storage.Processor(handler) - ctx = zmq.Context() - reqrep_server = TZmqServer.TZmqServer(processor, ctx, "tcp://0.0.0.0:9090", zmq.REP) - oneway_server = TZmqServer.TZmqServer(processor, ctx, "tcp://0.0.0.0:9091", zmq.UPSTREAM) - multiserver = TZmqServer.TZmqMultiServer() - multiserver.servers.append(reqrep_server) - multiserver.servers.append(oneway_server) - multiserver.serveForever() + ctx = zmq.Context() + reqrep_server = TZmqServer.TZmqServer(processor, ctx, "tcp://0.0.0.0:9090", zmq.REP) + oneway_server = TZmqServer.TZmqServer(processor, ctx, "tcp://0.0.0.0:9091", zmq.UPSTREAM) + multiserver = TZmqServer.TZmqMultiServer() + multiserver.servers.append(reqrep_server) + multiserver.servers.append(oneway_server) + multiserver.serveForever() if __name__ == "__main__": - main() + main() diff --git a/lib/py/setup.py b/lib/py/setup.py index 090544ce9..f57c1a131 100644 --- a/lib/py/setup.py +++ b/lib/py/setup.py @@ -9,7 +9,7 @@ # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -24,7 +24,7 @@ import sys try: from setuptools import setup, Extension except: - from distutils.core import setup, Extension, Command + from distutils.core import setup, Extension from distutils.command.build_ext import build_ext from distutils.errors import CCompilerError, DistutilsExecError, DistutilsPlatformError @@ -41,63 +41,66 @@ if sys.platform == 'win32': else: ext_errors = (CCompilerError, DistutilsExecError, DistutilsPlatformError) + class BuildFailed(Exception): pass + class ve_build_ext(build_ext): def run(self): try: build_ext.run(self) - except DistutilsPlatformError as x: + except DistutilsPlatformError: raise BuildFailed() def build_extension(self, ext): try: build_ext.build_extension(self, ext) - except ext_errors as x: + except ext_errors: raise BuildFailed() + def run_setup(with_binary): if with_binary: extensions = dict( - ext_modules = [ - Extension('thrift.protocol.fastbinary', - sources = ['src/protocol/fastbinary.c'], - include_dirs = include_dirs, - ) + ext_modules=[ + Extension('thrift.protocol.fastbinary', + sources=['src/protocol/fastbinary.c'], + include_dirs=include_dirs, + ) ], cmdclass=dict(build_ext=ve_build_ext) ) else: extensions = dict() - setup(name = 'thrift', - version = '1.0.0-dev', - description = 'Python bindings for the Apache Thrift RPC system', - author = 'Thrift Developers', - author_email = 'dev@thrift.apache.org', - url = 'http://thrift.apache.org', - license = 'Apache License 2.0', - install_requires=['six>=1.7.2'], - packages = [ - 'thrift', - 'thrift.protocol', - 'thrift.transport', - 'thrift.server', - ], - package_dir = {'thrift' : 'src'}, - classifiers = [ - 'Development Status :: 5 - Production/Stable', - 'Environment :: Console', - 'Intended Audience :: Developers', - 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 3', - 'Topic :: Software Development :: Libraries', - 'Topic :: System :: Networking' - ], - **extensions - ) + setup(name='thrift', + version='1.0.0-dev', + description='Python bindings for the Apache Thrift RPC system', + author='Thrift Developers', + author_email='dev@thrift.apache.org', + url='http://thrift.apache.org', + license='Apache License 2.0', + install_requires=['six>=1.7.2'], + packages=[ + 'thrift', + 'thrift.protocol', + 'thrift.transport', + 'thrift.server', + ], + package_dir={'thrift': 'src'}, + classifiers=[ + 'Development Status :: 5 - Production/Stable', + 'Environment :: Console', + 'Intended Audience :: Developers', + 'Programming Language :: Python', + 'Programming Language :: Python :: 2', + 'Programming Language :: Python :: 3', + 'Topic :: Software Development :: Libraries', + 'Topic :: System :: Networking' + ], + **extensions + ) try: with_binary = False diff --git a/lib/py/src/TMultiplexedProcessor.py b/lib/py/src/TMultiplexedProcessor.py index a8d5565c3..581214b31 100644 --- a/lib/py/src/TMultiplexedProcessor.py +++ b/lib/py/src/TMultiplexedProcessor.py @@ -20,39 +20,36 @@ from thrift.Thrift import TProcessor, TMessageType, TException from thrift.protocol import TProtocolDecorator, TMultiplexedProtocol + class TMultiplexedProcessor(TProcessor): - def __init__(self): - self.services = {} + def __init__(self): + self.services = {} - def registerProcessor(self, serviceName, processor): - self.services[serviceName] = processor + def registerProcessor(self, serviceName, processor): + self.services[serviceName] = processor - def process(self, iprot, oprot): - (name, type, seqid) = iprot.readMessageBegin(); - if type != TMessageType.CALL & type != TMessageType.ONEWAY: - raise TException("TMultiplex protocol only supports CALL & ONEWAY") + def process(self, iprot, oprot): + (name, type, seqid) = iprot.readMessageBegin() + if type != TMessageType.CALL & type != TMessageType.ONEWAY: + raise TException("TMultiplex protocol only supports CALL & ONEWAY") - index = name.find(TMultiplexedProtocol.SEPARATOR) - if index < 0: - raise TException("Service name not found in message name: " + name + ". Did you forget to use TMultiplexProtocol in your client?") + index = name.find(TMultiplexedProtocol.SEPARATOR) + if index < 0: + raise TException("Service name not found in message name: " + name + ". Did you forget to use TMultiplexProtocol in your client?") - serviceName = name[0:index] - call = name[index+len(TMultiplexedProtocol.SEPARATOR):] - if not serviceName in self.services: - raise TException("Service name not found: " + serviceName + ". Did you forget to call registerProcessor()?") + serviceName = name[0:index] + call = name[index + len(TMultiplexedProtocol.SEPARATOR):] + if serviceName not in self.services: + raise TException("Service name not found: " + serviceName + ". Did you forget to call registerProcessor()?") - standardMessage = ( - call, - type, - seqid - ) - return self.services[serviceName].process(StoredMessageProtocol(iprot, standardMessage), oprot) + standardMessage = (call, type, seqid) + return self.services[serviceName].process(StoredMessageProtocol(iprot, standardMessage), oprot) class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator): - def __init__(self, protocol, messageBegin): - TProtocolDecorator.TProtocolDecorator.__init__(self, protocol) - self.messageBegin = messageBegin + def __init__(self, protocol, messageBegin): + TProtocolDecorator.TProtocolDecorator.__init__(self, protocol) + self.messageBegin = messageBegin - def readMessageBegin(self): - return self.messageBegin + def readMessageBegin(self): + return self.messageBegin diff --git a/lib/py/src/TSCons.py b/lib/py/src/TSCons.py index ed2601a7d..bc67d7069 100644 --- a/lib/py/src/TSCons.py +++ b/lib/py/src/TSCons.py @@ -20,18 +20,17 @@ from os import path from SCons.Builder import Builder from six.moves import map -from six.moves import zip def scons_env(env, add=''): - opath = path.dirname(path.abspath('$TARGET')) - lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE' - cppbuild = Builder(action=lstr) - env.Append(BUILDERS={'ThriftCpp': cppbuild}) + opath = path.dirname(path.abspath('$TARGET')) + lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE' + cppbuild = Builder(action=lstr) + env.Append(BUILDERS={'ThriftCpp': cppbuild}) def gen_cpp(env, dir, file): - scons_env(env) - suffixes = ['_types.h', '_types.cpp'] - targets = map(lambda s: 'gen-cpp/' + file + s, suffixes) - return env.ThriftCpp(targets, dir + file + '.thrift') + scons_env(env) + suffixes = ['_types.h', '_types.cpp'] + targets = map(lambda s: 'gen-cpp/' + file + s, suffixes) + return env.ThriftCpp(targets, dir + file + '.thrift') diff --git a/lib/py/src/TTornado.py b/lib/py/src/TTornado.py index e3b4df7b2..e01a49f25 100644 --- a/lib/py/src/TTornado.py +++ b/lib/py/src/TTornado.py @@ -18,10 +18,9 @@ # from __future__ import absolute_import +import logging import socket import struct -import logging -logger = logging.getLogger(__name__) from .transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer @@ -32,6 +31,8 @@ from tornado import gen, iostream, ioloop, tcpserver, concurrent __all__ = ['TTornadoServer', 'TTornadoStreamTransport'] +logger = logging.getLogger(__name__) + class _Lock(object): def __init__(self): diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py index 11ee79625..c4dabdca0 100644 --- a/lib/py/src/Thrift.py +++ b/lib/py/src/Thrift.py @@ -21,170 +21,172 @@ import sys class TType(object): - STOP = 0 - VOID = 1 - BOOL = 2 - BYTE = 3 - I08 = 3 - DOUBLE = 4 - I16 = 6 - I32 = 8 - I64 = 10 - STRING = 11 - UTF7 = 11 - STRUCT = 12 - MAP = 13 - SET = 14 - LIST = 15 - UTF8 = 16 - UTF16 = 17 - - _VALUES_TO_NAMES = ('STOP', - 'VOID', - 'BOOL', - 'BYTE', - 'DOUBLE', - None, - 'I16', - None, - 'I32', - None, - 'I64', - 'STRING', - 'STRUCT', - 'MAP', - 'SET', - 'LIST', - 'UTF8', - 'UTF16') + STOP = 0 + VOID = 1 + BOOL = 2 + BYTE = 3 + I08 = 3 + DOUBLE = 4 + I16 = 6 + I32 = 8 + I64 = 10 + STRING = 11 + UTF7 = 11 + STRUCT = 12 + MAP = 13 + SET = 14 + LIST = 15 + UTF8 = 16 + UTF16 = 17 + + _VALUES_TO_NAMES = ( + 'STOP', + 'VOID', + 'BOOL', + 'BYTE', + 'DOUBLE', + None, + 'I16', + None, + 'I32', + None, + 'I64', + 'STRING', + 'STRUCT', + 'MAP', + 'SET', + 'LIST', + 'UTF8', + 'UTF16', + ) class TMessageType(object): - CALL = 1 - REPLY = 2 - EXCEPTION = 3 - ONEWAY = 4 + CALL = 1 + REPLY = 2 + EXCEPTION = 3 + ONEWAY = 4 class TProcessor(object): - """Base class for procsessor, which works on two streams.""" + """Base class for procsessor, which works on two streams.""" - def process(iprot, oprot): - pass + def process(iprot, oprot): + pass class TException(Exception): - """Base class for all thrift exceptions.""" + """Base class for all thrift exceptions.""" - # BaseException.message is deprecated in Python v[2.6,3.0) - if (2, 6, 0) <= sys.version_info < (3, 0): - def _get_message(self): - return self._message + # BaseException.message is deprecated in Python v[2.6,3.0) + if (2, 6, 0) <= sys.version_info < (3, 0): + def _get_message(self): + return self._message - def _set_message(self, message): - self._message = message - message = property(_get_message, _set_message) + def _set_message(self, message): + self._message = message + message = property(_get_message, _set_message) - def __init__(self, message=None): - Exception.__init__(self, message) - self.message = message + def __init__(self, message=None): + Exception.__init__(self, message) + self.message = message class TApplicationException(TException): - """Application level thrift exceptions.""" - - UNKNOWN = 0 - UNKNOWN_METHOD = 1 - INVALID_MESSAGE_TYPE = 2 - WRONG_METHOD_NAME = 3 - BAD_SEQUENCE_ID = 4 - MISSING_RESULT = 5 - INTERNAL_ERROR = 6 - PROTOCOL_ERROR = 7 - INVALID_TRANSFORM = 8 - INVALID_PROTOCOL = 9 - UNSUPPORTED_CLIENT_TYPE = 10 - - def __init__(self, type=UNKNOWN, message=None): - TException.__init__(self, message) - self.type = type - - def __str__(self): - if self.message: - return self.message - elif self.type == self.UNKNOWN_METHOD: - return 'Unknown method' - elif self.type == self.INVALID_MESSAGE_TYPE: - return 'Invalid message type' - elif self.type == self.WRONG_METHOD_NAME: - return 'Wrong method name' - elif self.type == self.BAD_SEQUENCE_ID: - return 'Bad sequence ID' - elif self.type == self.MISSING_RESULT: - return 'Missing result' - elif self.type == self.INTERNAL_ERROR: - return 'Internal error' - elif self.type == self.PROTOCOL_ERROR: - return 'Protocol error' - elif self.type == self.INVALID_TRANSFORM: - return 'Invalid transform' - elif self.type == self.INVALID_PROTOCOL: - return 'Invalid protocol' - elif self.type == self.UNSUPPORTED_CLIENT_TYPE: - return 'Unsupported client type' - else: - return 'Default (unknown) TApplicationException' - - def read(self, iprot): - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRING: - self.message = iprot.readString() + """Application level thrift exceptions.""" + + UNKNOWN = 0 + UNKNOWN_METHOD = 1 + INVALID_MESSAGE_TYPE = 2 + WRONG_METHOD_NAME = 3 + BAD_SEQUENCE_ID = 4 + MISSING_RESULT = 5 + INTERNAL_ERROR = 6 + PROTOCOL_ERROR = 7 + INVALID_TRANSFORM = 8 + INVALID_PROTOCOL = 9 + UNSUPPORTED_CLIENT_TYPE = 10 + + def __init__(self, type=UNKNOWN, message=None): + TException.__init__(self, message) + self.type = type + + def __str__(self): + if self.message: + return self.message + elif self.type == self.UNKNOWN_METHOD: + return 'Unknown method' + elif self.type == self.INVALID_MESSAGE_TYPE: + return 'Invalid message type' + elif self.type == self.WRONG_METHOD_NAME: + return 'Wrong method name' + elif self.type == self.BAD_SEQUENCE_ID: + return 'Bad sequence ID' + elif self.type == self.MISSING_RESULT: + return 'Missing result' + elif self.type == self.INTERNAL_ERROR: + return 'Internal error' + elif self.type == self.PROTOCOL_ERROR: + return 'Protocol error' + elif self.type == self.INVALID_TRANSFORM: + return 'Invalid transform' + elif self.type == self.INVALID_PROTOCOL: + return 'Invalid protocol' + elif self.type == self.UNSUPPORTED_CLIENT_TYPE: + return 'Unsupported client type' else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.I32: - self.type = iprot.readI32() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - oprot.writeStructBegin('TApplicationException') - if self.message is not None: - oprot.writeFieldBegin('message', TType.STRING, 1) - oprot.writeString(self.message) - oprot.writeFieldEnd() - if self.type is not None: - oprot.writeFieldBegin('type', TType.I32, 2) - oprot.writeI32(self.type) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() + return 'Default (unknown) TApplicationException' + + def read(self, iprot): + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 1: + if ftype == TType.STRING: + self.message = iprot.readString() + else: + iprot.skip(ftype) + elif fid == 2: + if ftype == TType.I32: + self.type = iprot.readI32() + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + oprot.writeStructBegin('TApplicationException') + if self.message is not None: + oprot.writeFieldBegin('message', TType.STRING, 1) + oprot.writeString(self.message) + oprot.writeFieldEnd() + if self.type is not None: + oprot.writeFieldBegin('type', TType.I32, 2) + oprot.writeI32(self.type) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() class TFrozenDict(dict): - """A dictionary that is "frozen" like a frozenset""" + """A dictionary that is "frozen" like a frozenset""" - def __init__(self, *args, **kwargs): - super(TFrozenDict, self).__init__(*args, **kwargs) - # Sort the items so they will be in a consistent order. - # XOR in the hash of the class so we don't collide with - # the hash of a list of tuples. - self.__hashval = hash(TFrozenDict) ^ hash(tuple(sorted(self.items()))) + def __init__(self, *args, **kwargs): + super(TFrozenDict, self).__init__(*args, **kwargs) + # Sort the items so they will be in a consistent order. + # XOR in the hash of the class so we don't collide with + # the hash of a list of tuples. + self.__hashval = hash(TFrozenDict) ^ hash(tuple(sorted(self.items()))) - def __setitem__(self, *args): - raise TypeError("Can't modify frozen TFreezableDict") + def __setitem__(self, *args): + raise TypeError("Can't modify frozen TFreezableDict") - def __delitem__(self, *args): - raise TypeError("Can't modify frozen TFreezableDict") + def __delitem__(self, *args): + raise TypeError("Can't modify frozen TFreezableDict") - def __hash__(self): - return self.__hashval + def __hash__(self): + return self.__hashval diff --git a/lib/py/src/compat.py b/lib/py/src/compat.py index 06f672ae6..42403eae8 100644 --- a/lib/py/src/compat.py +++ b/lib/py/src/compat.py @@ -1,27 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + import sys if sys.version_info[0] == 2: - from cStringIO import StringIO as BufferIO + from cStringIO import StringIO as BufferIO - def binary_to_str(bin_val): - return bin_val + def binary_to_str(bin_val): + return bin_val - def str_to_binary(str_val): - return str_val + def str_to_binary(str_val): + return str_val else: - from io import BytesIO as BufferIO + from io import BytesIO as BufferIO - def binary_to_str(bin_val): - try: - return bin_val.decode('utf8') - except: - return bin_val + def binary_to_str(bin_val): + try: + return bin_val.decode('utf8') + except: + return bin_val - def str_to_binary(str_val): - try: - return bytes(str_val, 'utf8') - except: - return str_val + def str_to_binary(str_val): + try: + return bytes(str_val, 'utf8') + except: + return str_val diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py index d106f4e03..87caf0d16 100644 --- a/lib/py/src/protocol/TBase.py +++ b/lib/py/src/protocol/TBase.py @@ -21,78 +21,79 @@ from thrift.protocol import TBinaryProtocol from thrift.transport import TTransport try: - from thrift.protocol import fastbinary + from thrift.protocol import fastbinary except: - fastbinary = None + fastbinary = None class TBase(object): - __slots__ = () - - def __repr__(self): - L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - if not isinstance(other, self.__class__): - return False - for attr in self.__slots__: - my_val = getattr(self, attr) - other_val = getattr(other, attr) - if my_val != other_val: - return False - return True - - def __ne__(self, other): - return not (self == other) - - def read(self, iprot): - if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and - isinstance(iprot.trans, TTransport.CReadableTransport) and - self.thrift_spec is not None and - fastbinary is not None): - fastbinary.decode_binary(self, - iprot.trans, - (self.__class__, self.thrift_spec), - iprot.string_length_limit, - iprot.container_length_limit) - return - iprot.readStruct(self, self.thrift_spec) - - def write(self, oprot): - if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and - self.thrift_spec is not None and - fastbinary is not None): - oprot.trans.write( - fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStruct(self, self.thrift_spec) + __slots__ = () + + def __repr__(self): + L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + for attr in self.__slots__: + my_val = getattr(self, attr) + other_val = getattr(other, attr) + if my_val != other_val: + return False + return True + + def __ne__(self, other): + return not (self == other) + + def read(self, iprot): + if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and + isinstance(iprot.trans, TTransport.CReadableTransport) and + self.thrift_spec is not None and + fastbinary is not None): + fastbinary.decode_binary(self, + iprot.trans, + (self.__class__, self.thrift_spec), + iprot.string_length_limit, + iprot.container_length_limit) + return + iprot.readStruct(self, self.thrift_spec) + + def write(self, oprot): + if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and + self.thrift_spec is not None and + fastbinary is not None): + oprot.trans.write( + fastbinary.encode_binary( + self, (self.__class__, self.thrift_spec))) + return + oprot.writeStruct(self, self.thrift_spec) class TExceptionBase(TBase, Exception): - pass + pass class TFrozenBase(TBase): - def __setitem__(self, *args): - raise TypeError("Can't modify frozen struct") - - def __delitem__(self, *args): - raise TypeError("Can't modify frozen struct") - - def __hash__(self, *args): - return hash(self.__class__) ^ hash(self.__slots__) - - @classmethod - def read(cls, iprot): - if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and - isinstance(iprot.trans, TTransport.CReadableTransport) and - cls.thrift_spec is not None and - fastbinary is not None): - self = cls() - return fastbinary.decode_binary(None, - iprot.trans, - (self.__class__, self.thrift_spec), - iprot.string_length_limit, - iprot.container_length_limit) - return iprot.readStruct(cls, cls.thrift_spec, True) + def __setitem__(self, *args): + raise TypeError("Can't modify frozen struct") + + def __delitem__(self, *args): + raise TypeError("Can't modify frozen struct") + + def __hash__(self, *args): + return hash(self.__class__) ^ hash(self.__slots__) + + @classmethod + def read(cls, iprot): + if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and + isinstance(iprot.trans, TTransport.CReadableTransport) and + cls.thrift_spec is not None and + fastbinary is not None): + self = cls() + return fastbinary.decode_binary(None, + iprot.trans, + (self.__class__, self.thrift_spec), + iprot.string_length_limit, + iprot.container_length_limit) + return iprot.readStruct(cls, cls.thrift_spec, True) diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py index db4ea3182..7fce12f07 100644 --- a/lib/py/src/protocol/TBinaryProtocol.py +++ b/lib/py/src/protocol/TBinaryProtocol.py @@ -22,264 +22,264 @@ from struct import pack, unpack class TBinaryProtocol(TProtocolBase): - """Binary implementation of the Thrift protocol driver.""" + """Binary implementation of the Thrift protocol driver.""" - # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be - # positive, converting this into a long. If we hardcode the int value - # instead it'll stay in 32 bit-land. + # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be + # positive, converting this into a long. If we hardcode the int value + # instead it'll stay in 32 bit-land. - # VERSION_MASK = 0xffff0000 - VERSION_MASK = -65536 + # VERSION_MASK = 0xffff0000 + VERSION_MASK = -65536 - # VERSION_1 = 0x80010000 - VERSION_1 = -2147418112 + # VERSION_1 = 0x80010000 + VERSION_1 = -2147418112 - TYPE_MASK = 0x000000ff + TYPE_MASK = 0x000000ff - def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs): - TProtocolBase.__init__(self, trans) - self.strictRead = strictRead - self.strictWrite = strictWrite - self.string_length_limit = kwargs.get('string_length_limit', None) - self.container_length_limit = kwargs.get('container_length_limit', None) - - def _check_string_length(self, length): - self._check_length(self.string_length_limit, length) - - def _check_container_length(self, length): - self._check_length(self.container_length_limit, length) - - def writeMessageBegin(self, name, type, seqid): - if self.strictWrite: - self.writeI32(TBinaryProtocol.VERSION_1 | type) - self.writeString(name) - self.writeI32(seqid) - else: - self.writeString(name) - self.writeByte(type) - self.writeI32(seqid) - - def writeMessageEnd(self): - pass - - def writeStructBegin(self, name): - pass - - def writeStructEnd(self): - pass - - def writeFieldBegin(self, name, type, id): - self.writeByte(type) - self.writeI16(id) - - def writeFieldEnd(self): - pass - - def writeFieldStop(self): - self.writeByte(TType.STOP) - - def writeMapBegin(self, ktype, vtype, size): - self.writeByte(ktype) - self.writeByte(vtype) - self.writeI32(size) - - def writeMapEnd(self): - pass - - def writeListBegin(self, etype, size): - self.writeByte(etype) - self.writeI32(size) - - def writeListEnd(self): - pass - - def writeSetBegin(self, etype, size): - self.writeByte(etype) - self.writeI32(size) - - def writeSetEnd(self): - pass - - def writeBool(self, bool): - if bool: - self.writeByte(1) - else: - self.writeByte(0) - - def writeByte(self, byte): - buff = pack("!b", byte) - self.trans.write(buff) - - def writeI16(self, i16): - buff = pack("!h", i16) - self.trans.write(buff) - - def writeI32(self, i32): - buff = pack("!i", i32) - self.trans.write(buff) - - def writeI64(self, i64): - buff = pack("!q", i64) - self.trans.write(buff) - - def writeDouble(self, dub): - buff = pack("!d", dub) - self.trans.write(buff) - - def writeBinary(self, str): - self.writeI32(len(str)) - self.trans.write(str) - - def readMessageBegin(self): - sz = self.readI32() - if sz < 0: - version = sz & TBinaryProtocol.VERSION_MASK - if version != TBinaryProtocol.VERSION_1: - raise TProtocolException( - type=TProtocolException.BAD_VERSION, - message='Bad version in readMessageBegin: %d' % (sz)) - type = sz & TBinaryProtocol.TYPE_MASK - name = self.readString() - seqid = self.readI32() - else: - if self.strictRead: - raise TProtocolException(type=TProtocolException.BAD_VERSION, - message='No protocol version header') - name = self.trans.readAll(sz) - type = self.readByte() - seqid = self.readI32() - return (name, type, seqid) - - def readMessageEnd(self): - pass - - def readStructBegin(self): - pass - - def readStructEnd(self): - pass - - def readFieldBegin(self): - type = self.readByte() - if type == TType.STOP: - return (None, type, 0) - id = self.readI16() - return (None, type, id) - - def readFieldEnd(self): - pass - - def readMapBegin(self): - ktype = self.readByte() - vtype = self.readByte() - size = self.readI32() - self._check_container_length(size) - return (ktype, vtype, size) - - def readMapEnd(self): - pass - - def readListBegin(self): - etype = self.readByte() - size = self.readI32() - self._check_container_length(size) - return (etype, size) - - def readListEnd(self): - pass - - def readSetBegin(self): - etype = self.readByte() - size = self.readI32() - self._check_container_length(size) - return (etype, size) - - def readSetEnd(self): - pass - - def readBool(self): - byte = self.readByte() - if byte == 0: - return False - return True - - def readByte(self): - buff = self.trans.readAll(1) - val, = unpack('!b', buff) - return val - - def readI16(self): - buff = self.trans.readAll(2) - val, = unpack('!h', buff) - return val - - def readI32(self): - buff = self.trans.readAll(4) - val, = unpack('!i', buff) - return val - - def readI64(self): - buff = self.trans.readAll(8) - val, = unpack('!q', buff) - return val - - def readDouble(self): - buff = self.trans.readAll(8) - val, = unpack('!d', buff) - return val - - def readBinary(self): - size = self.readI32() - self._check_string_length(size) - s = self.trans.readAll(size) - return s + def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs): + TProtocolBase.__init__(self, trans) + self.strictRead = strictRead + self.strictWrite = strictWrite + self.string_length_limit = kwargs.get('string_length_limit', None) + self.container_length_limit = kwargs.get('container_length_limit', None) + + def _check_string_length(self, length): + self._check_length(self.string_length_limit, length) + + def _check_container_length(self, length): + self._check_length(self.container_length_limit, length) + + def writeMessageBegin(self, name, type, seqid): + if self.strictWrite: + self.writeI32(TBinaryProtocol.VERSION_1 | type) + self.writeString(name) + self.writeI32(seqid) + else: + self.writeString(name) + self.writeByte(type) + self.writeI32(seqid) + + def writeMessageEnd(self): + pass + + def writeStructBegin(self, name): + pass + + def writeStructEnd(self): + pass + + def writeFieldBegin(self, name, type, id): + self.writeByte(type) + self.writeI16(id) + + def writeFieldEnd(self): + pass + + def writeFieldStop(self): + self.writeByte(TType.STOP) + + def writeMapBegin(self, ktype, vtype, size): + self.writeByte(ktype) + self.writeByte(vtype) + self.writeI32(size) + + def writeMapEnd(self): + pass + + def writeListBegin(self, etype, size): + self.writeByte(etype) + self.writeI32(size) + + def writeListEnd(self): + pass + + def writeSetBegin(self, etype, size): + self.writeByte(etype) + self.writeI32(size) + + def writeSetEnd(self): + pass + + def writeBool(self, bool): + if bool: + self.writeByte(1) + else: + self.writeByte(0) + + def writeByte(self, byte): + buff = pack("!b", byte) + self.trans.write(buff) + + def writeI16(self, i16): + buff = pack("!h", i16) + self.trans.write(buff) + + def writeI32(self, i32): + buff = pack("!i", i32) + self.trans.write(buff) + + def writeI64(self, i64): + buff = pack("!q", i64) + self.trans.write(buff) + + def writeDouble(self, dub): + buff = pack("!d", dub) + self.trans.write(buff) + + def writeBinary(self, str): + self.writeI32(len(str)) + self.trans.write(str) + + def readMessageBegin(self): + sz = self.readI32() + if sz < 0: + version = sz & TBinaryProtocol.VERSION_MASK + if version != TBinaryProtocol.VERSION_1: + raise TProtocolException( + type=TProtocolException.BAD_VERSION, + message='Bad version in readMessageBegin: %d' % (sz)) + type = sz & TBinaryProtocol.TYPE_MASK + name = self.readString() + seqid = self.readI32() + else: + if self.strictRead: + raise TProtocolException(type=TProtocolException.BAD_VERSION, + message='No protocol version header') + name = self.trans.readAll(sz) + type = self.readByte() + seqid = self.readI32() + return (name, type, seqid) + + def readMessageEnd(self): + pass + + def readStructBegin(self): + pass + + def readStructEnd(self): + pass + + def readFieldBegin(self): + type = self.readByte() + if type == TType.STOP: + return (None, type, 0) + id = self.readI16() + return (None, type, id) + + def readFieldEnd(self): + pass + + def readMapBegin(self): + ktype = self.readByte() + vtype = self.readByte() + size = self.readI32() + self._check_container_length(size) + return (ktype, vtype, size) + + def readMapEnd(self): + pass + + def readListBegin(self): + etype = self.readByte() + size = self.readI32() + self._check_container_length(size) + return (etype, size) + + def readListEnd(self): + pass + + def readSetBegin(self): + etype = self.readByte() + size = self.readI32() + self._check_container_length(size) + return (etype, size) + + def readSetEnd(self): + pass + + def readBool(self): + byte = self.readByte() + if byte == 0: + return False + return True + + def readByte(self): + buff = self.trans.readAll(1) + val, = unpack('!b', buff) + return val + + def readI16(self): + buff = self.trans.readAll(2) + val, = unpack('!h', buff) + return val + + def readI32(self): + buff = self.trans.readAll(4) + val, = unpack('!i', buff) + return val + + def readI64(self): + buff = self.trans.readAll(8) + val, = unpack('!q', buff) + return val + + def readDouble(self): + buff = self.trans.readAll(8) + val, = unpack('!d', buff) + return val + + def readBinary(self): + size = self.readI32() + self._check_string_length(size) + s = self.trans.readAll(size) + return s class TBinaryProtocolFactory(object): - def __init__(self, strictRead=False, strictWrite=True, **kwargs): - self.strictRead = strictRead - self.strictWrite = strictWrite - self.string_length_limit = kwargs.get('string_length_limit', None) - self.container_length_limit = kwargs.get('container_length_limit', None) + def __init__(self, strictRead=False, strictWrite=True, **kwargs): + self.strictRead = strictRead + self.strictWrite = strictWrite + self.string_length_limit = kwargs.get('string_length_limit', None) + self.container_length_limit = kwargs.get('container_length_limit', None) - def getProtocol(self, trans): - prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite, - string_length_limit=self.string_length_limit, - container_length_limit=self.container_length_limit) - return prot + def getProtocol(self, trans): + prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite, + string_length_limit=self.string_length_limit, + container_length_limit=self.container_length_limit) + return prot class TBinaryProtocolAccelerated(TBinaryProtocol): - """C-Accelerated version of TBinaryProtocol. - - This class does not override any of TBinaryProtocol's methods, - but the generated code recognizes it directly and will call into - our C module to do the encoding, bypassing this object entirely. - We inherit from TBinaryProtocol so that the normal TBinaryProtocol - encoding can happen if the fastbinary module doesn't work for some - reason. (TODO(dreiss): Make this happen sanely in more cases.) - - In order to take advantage of the C module, just use - TBinaryProtocolAccelerated instead of TBinaryProtocol. - - NOTE: This code was contributed by an external developer. - The internal Thrift team has reviewed and tested it, - but we cannot guarantee that it is production-ready. - Please feel free to report bugs and/or success stories - to the public mailing list. - """ - pass + """C-Accelerated version of TBinaryProtocol. + + This class does not override any of TBinaryProtocol's methods, + but the generated code recognizes it directly and will call into + our C module to do the encoding, bypassing this object entirely. + We inherit from TBinaryProtocol so that the normal TBinaryProtocol + encoding can happen if the fastbinary module doesn't work for some + reason. (TODO(dreiss): Make this happen sanely in more cases.) + + In order to take advantage of the C module, just use + TBinaryProtocolAccelerated instead of TBinaryProtocol. + + NOTE: This code was contributed by an external developer. + The internal Thrift team has reviewed and tested it, + but we cannot guarantee that it is production-ready. + Please feel free to report bugs and/or success stories + to the public mailing list. + """ + pass class TBinaryProtocolAcceleratedFactory(object): - def __init__(self, - string_length_limit=None, - container_length_limit=None): - self.string_length_limit = string_length_limit - self.container_length_limit = container_length_limit - - def getProtocol(self, trans): - return TBinaryProtocolAccelerated( - trans, - string_length_limit=self.string_length_limit, - container_length_limit=self.container_length_limit) + def __init__(self, + string_length_limit=None, + container_length_limit=None): + self.string_length_limit = string_length_limit + self.container_length_limit = container_length_limit + + def getProtocol(self, trans): + return TBinaryProtocolAccelerated( + trans, + string_length_limit=self.string_length_limit, + container_length_limit=self.container_length_limit) diff --git a/lib/py/src/protocol/TCompactProtocol.py b/lib/py/src/protocol/TCompactProtocol.py index 3d9c0e6e3..8d3db1a9d 100644 --- a/lib/py/src/protocol/TCompactProtocol.py +++ b/lib/py/src/protocol/TCompactProtocol.py @@ -36,390 +36,391 @@ BOOL_READ = 8 def make_helper(v_from, container): - def helper(func): - def nested(self, *args, **kwargs): - assert self.state in (v_from, container), (self.state, v_from, container) - return func(self, *args, **kwargs) - return nested - return helper + def helper(func): + def nested(self, *args, **kwargs): + assert self.state in (v_from, container), (self.state, v_from, container) + return func(self, *args, **kwargs) + return nested + return helper writer = make_helper(VALUE_WRITE, CONTAINER_WRITE) reader = make_helper(VALUE_READ, CONTAINER_READ) def makeZigZag(n, bits): - checkIntegerLimits(n, bits) - return (n << 1) ^ (n >> (bits - 1)) + checkIntegerLimits(n, bits) + return (n << 1) ^ (n >> (bits - 1)) def fromZigZag(n): - return (n >> 1) ^ -(n & 1) + return (n >> 1) ^ -(n & 1) def writeVarint(trans, n): - out = bytearray() - while True: - if n & ~0x7f == 0: - out.append(n) - break - else: - out.append((n & 0xff) | 0x80) - n = n >> 7 - trans.write(bytes(out)) + out = bytearray() + while True: + if n & ~0x7f == 0: + out.append(n) + break + else: + out.append((n & 0xff) | 0x80) + n = n >> 7 + trans.write(bytes(out)) def readVarint(trans): - result = 0 - shift = 0 - while True: - x = trans.readAll(1) - byte = ord(x) - result |= (byte & 0x7f) << shift - if byte >> 7 == 0: - return result - shift += 7 + result = 0 + shift = 0 + while True: + x = trans.readAll(1) + byte = ord(x) + result |= (byte & 0x7f) << shift + if byte >> 7 == 0: + return result + shift += 7 class CompactType(object): - STOP = 0x00 - TRUE = 0x01 - FALSE = 0x02 - BYTE = 0x03 - I16 = 0x04 - I32 = 0x05 - I64 = 0x06 - DOUBLE = 0x07 - BINARY = 0x08 - LIST = 0x09 - SET = 0x0A - MAP = 0x0B - STRUCT = 0x0C - -CTYPES = {TType.STOP: CompactType.STOP, - TType.BOOL: CompactType.TRUE, # used for collection - TType.BYTE: CompactType.BYTE, - TType.I16: CompactType.I16, - TType.I32: CompactType.I32, - TType.I64: CompactType.I64, - TType.DOUBLE: CompactType.DOUBLE, - TType.STRING: CompactType.BINARY, - TType.STRUCT: CompactType.STRUCT, - TType.LIST: CompactType.LIST, - TType.SET: CompactType.SET, - TType.MAP: CompactType.MAP - } + STOP = 0x00 + TRUE = 0x01 + FALSE = 0x02 + BYTE = 0x03 + I16 = 0x04 + I32 = 0x05 + I64 = 0x06 + DOUBLE = 0x07 + BINARY = 0x08 + LIST = 0x09 + SET = 0x0A + MAP = 0x0B + STRUCT = 0x0C + +CTYPES = { + TType.STOP: CompactType.STOP, + TType.BOOL: CompactType.TRUE, # used for collection + TType.BYTE: CompactType.BYTE, + TType.I16: CompactType.I16, + TType.I32: CompactType.I32, + TType.I64: CompactType.I64, + TType.DOUBLE: CompactType.DOUBLE, + TType.STRING: CompactType.BINARY, + TType.STRUCT: CompactType.STRUCT, + TType.LIST: CompactType.LIST, + TType.SET: CompactType.SET, + TType.MAP: CompactType.MAP, +} TTYPES = {} for k, v in CTYPES.items(): - TTYPES[v] = k + TTYPES[v] = k TTYPES[CompactType.FALSE] = TType.BOOL del k del v class TCompactProtocol(TProtocolBase): - """Compact implementation of the Thrift protocol driver.""" - - PROTOCOL_ID = 0x82 - VERSION = 1 - VERSION_MASK = 0x1f - TYPE_MASK = 0xe0 - TYPE_BITS = 0x07 - TYPE_SHIFT_AMOUNT = 5 - - def __init__(self, trans, - string_length_limit=None, - container_length_limit=None): - TProtocolBase.__init__(self, trans) - self.state = CLEAR - self.__last_fid = 0 - self.__bool_fid = None - self.__bool_value = None - self.__structs = [] - self.__containers = [] - self.string_length_limit = string_length_limit - self.container_length_limit = container_length_limit - - def _check_string_length(self, length): - self._check_length(self.string_length_limit, length) - - def _check_container_length(self, length): - self._check_length(self.container_length_limit, length) - - def __writeVarint(self, n): - writeVarint(self.trans, n) - - def writeMessageBegin(self, name, type, seqid): - assert self.state == CLEAR - self.__writeUByte(self.PROTOCOL_ID) - self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT)) - self.__writeVarint(seqid) - self.__writeBinary(str_to_binary(name)) - self.state = VALUE_WRITE - - def writeMessageEnd(self): - assert self.state == VALUE_WRITE - self.state = CLEAR - - def writeStructBegin(self, name): - assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state - self.__structs.append((self.state, self.__last_fid)) - self.state = FIELD_WRITE - self.__last_fid = 0 - - def writeStructEnd(self): - assert self.state == FIELD_WRITE - self.state, self.__last_fid = self.__structs.pop() - - def writeFieldStop(self): - self.__writeByte(0) - - def __writeFieldHeader(self, type, fid): - delta = fid - self.__last_fid - if 0 < delta <= 15: - self.__writeUByte(delta << 4 | type) - else: - self.__writeByte(type) - self.__writeI16(fid) - self.__last_fid = fid - - def writeFieldBegin(self, name, type, fid): - assert self.state == FIELD_WRITE, self.state - if type == TType.BOOL: - self.state = BOOL_WRITE - self.__bool_fid = fid - else: - self.state = VALUE_WRITE - self.__writeFieldHeader(CTYPES[type], fid) - - def writeFieldEnd(self): - assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state - self.state = FIELD_WRITE - - def __writeUByte(self, byte): - self.trans.write(pack('!B', byte)) - - def __writeByte(self, byte): - self.trans.write(pack('!b', byte)) - - def __writeI16(self, i16): - self.__writeVarint(makeZigZag(i16, 16)) - - def __writeSize(self, i32): - self.__writeVarint(i32) - - def writeCollectionBegin(self, etype, size): - assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state - if size <= 14: - self.__writeUByte(size << 4 | CTYPES[etype]) - else: - self.__writeUByte(0xf0 | CTYPES[etype]) - self.__writeSize(size) - self.__containers.append(self.state) - self.state = CONTAINER_WRITE - writeSetBegin = writeCollectionBegin - writeListBegin = writeCollectionBegin - - def writeMapBegin(self, ktype, vtype, size): - assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state - if size == 0: - self.__writeByte(0) - else: - self.__writeSize(size) - self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype]) - self.__containers.append(self.state) - self.state = CONTAINER_WRITE - - def writeCollectionEnd(self): - assert self.state == CONTAINER_WRITE, self.state - self.state = self.__containers.pop() - writeMapEnd = writeCollectionEnd - writeSetEnd = writeCollectionEnd - writeListEnd = writeCollectionEnd - - def writeBool(self, bool): - if self.state == BOOL_WRITE: - if bool: - ctype = CompactType.TRUE - else: - ctype = CompactType.FALSE - self.__writeFieldHeader(ctype, self.__bool_fid) - elif self.state == CONTAINER_WRITE: - if bool: - self.__writeByte(CompactType.TRUE) - else: - self.__writeByte(CompactType.FALSE) - else: - raise AssertionError("Invalid state in compact protocol") - - writeByte = writer(__writeByte) - writeI16 = writer(__writeI16) - - @writer - def writeI32(self, i32): - self.__writeVarint(makeZigZag(i32, 32)) - - @writer - def writeI64(self, i64): - self.__writeVarint(makeZigZag(i64, 64)) - - @writer - def writeDouble(self, dub): - self.trans.write(pack('<d', dub)) - - def __writeBinary(self, s): - self.__writeSize(len(s)) - self.trans.write(s) - writeBinary = writer(__writeBinary) - - def readFieldBegin(self): - assert self.state == FIELD_READ, self.state - type = self.__readUByte() - if type & 0x0f == TType.STOP: - return (None, 0, 0) - delta = type >> 4 - if delta == 0: - fid = self.__readI16() - else: - fid = self.__last_fid + delta - self.__last_fid = fid - type = type & 0x0f - if type == CompactType.TRUE: - self.state = BOOL_READ - self.__bool_value = True - elif type == CompactType.FALSE: - self.state = BOOL_READ - self.__bool_value = False - else: - self.state = VALUE_READ - return (None, self.__getTType(type), fid) - - def readFieldEnd(self): - assert self.state in (VALUE_READ, BOOL_READ), self.state - self.state = FIELD_READ - - def __readUByte(self): - result, = unpack('!B', self.trans.readAll(1)) - return result - - def __readByte(self): - result, = unpack('!b', self.trans.readAll(1)) - return result - - def __readVarint(self): - return readVarint(self.trans) - - def __readZigZag(self): - return fromZigZag(self.__readVarint()) - - def __readSize(self): - result = self.__readVarint() - if result < 0: - raise TProtocolException("Length < 0") - return result - - def readMessageBegin(self): - assert self.state == CLEAR - proto_id = self.__readUByte() - if proto_id != self.PROTOCOL_ID: - raise TProtocolException(TProtocolException.BAD_VERSION, - 'Bad protocol id in the message: %d' % proto_id) - ver_type = self.__readUByte() - type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS - version = ver_type & self.VERSION_MASK - if version != self.VERSION: - raise TProtocolException(TProtocolException.BAD_VERSION, - 'Bad version: %d (expect %d)' % (version, self.VERSION)) - seqid = self.__readVarint() - name = binary_to_str(self.__readBinary()) - return (name, type, seqid) - - def readMessageEnd(self): - assert self.state == CLEAR - assert len(self.__structs) == 0 - - def readStructBegin(self): - assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state - self.__structs.append((self.state, self.__last_fid)) - self.state = FIELD_READ - self.__last_fid = 0 - - def readStructEnd(self): - assert self.state == FIELD_READ - self.state, self.__last_fid = self.__structs.pop() - - def readCollectionBegin(self): - assert self.state in (VALUE_READ, CONTAINER_READ), self.state - size_type = self.__readUByte() - size = size_type >> 4 - type = self.__getTType(size_type) - if size == 15: - size = self.__readSize() - self._check_container_length(size) - self.__containers.append(self.state) - self.state = CONTAINER_READ - return type, size - readSetBegin = readCollectionBegin - readListBegin = readCollectionBegin - - def readMapBegin(self): - assert self.state in (VALUE_READ, CONTAINER_READ), self.state - size = self.__readSize() - self._check_container_length(size) - types = 0 - if size > 0: - types = self.__readUByte() - vtype = self.__getTType(types) - ktype = self.__getTType(types >> 4) - self.__containers.append(self.state) - self.state = CONTAINER_READ - return (ktype, vtype, size) - - def readCollectionEnd(self): - assert self.state == CONTAINER_READ, self.state - self.state = self.__containers.pop() - readSetEnd = readCollectionEnd - readListEnd = readCollectionEnd - readMapEnd = readCollectionEnd - - def readBool(self): - if self.state == BOOL_READ: - return self.__bool_value == CompactType.TRUE - elif self.state == CONTAINER_READ: - return self.__readByte() == CompactType.TRUE - else: - raise AssertionError("Invalid state in compact protocol: %d" % - self.state) - - readByte = reader(__readByte) - __readI16 = __readZigZag - readI16 = reader(__readZigZag) - readI32 = reader(__readZigZag) - readI64 = reader(__readZigZag) - - @reader - def readDouble(self): - buff = self.trans.readAll(8) - val, = unpack('<d', buff) - return val - - def __readBinary(self): - size = self.__readSize() - self._check_string_length(size) - return self.trans.readAll(size) - readBinary = reader(__readBinary) - - def __getTType(self, byte): - return TTYPES[byte & 0x0f] + """Compact implementation of the Thrift protocol driver.""" + + PROTOCOL_ID = 0x82 + VERSION = 1 + VERSION_MASK = 0x1f + TYPE_MASK = 0xe0 + TYPE_BITS = 0x07 + TYPE_SHIFT_AMOUNT = 5 + + def __init__(self, trans, + string_length_limit=None, + container_length_limit=None): + TProtocolBase.__init__(self, trans) + self.state = CLEAR + self.__last_fid = 0 + self.__bool_fid = None + self.__bool_value = None + self.__structs = [] + self.__containers = [] + self.string_length_limit = string_length_limit + self.container_length_limit = container_length_limit + + def _check_string_length(self, length): + self._check_length(self.string_length_limit, length) + + def _check_container_length(self, length): + self._check_length(self.container_length_limit, length) + + def __writeVarint(self, n): + writeVarint(self.trans, n) + + def writeMessageBegin(self, name, type, seqid): + assert self.state == CLEAR + self.__writeUByte(self.PROTOCOL_ID) + self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT)) + self.__writeVarint(seqid) + self.__writeBinary(str_to_binary(name)) + self.state = VALUE_WRITE + + def writeMessageEnd(self): + assert self.state == VALUE_WRITE + self.state = CLEAR + + def writeStructBegin(self, name): + assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state + self.__structs.append((self.state, self.__last_fid)) + self.state = FIELD_WRITE + self.__last_fid = 0 + + def writeStructEnd(self): + assert self.state == FIELD_WRITE + self.state, self.__last_fid = self.__structs.pop() + + def writeFieldStop(self): + self.__writeByte(0) + + def __writeFieldHeader(self, type, fid): + delta = fid - self.__last_fid + if 0 < delta <= 15: + self.__writeUByte(delta << 4 | type) + else: + self.__writeByte(type) + self.__writeI16(fid) + self.__last_fid = fid + + def writeFieldBegin(self, name, type, fid): + assert self.state == FIELD_WRITE, self.state + if type == TType.BOOL: + self.state = BOOL_WRITE + self.__bool_fid = fid + else: + self.state = VALUE_WRITE + self.__writeFieldHeader(CTYPES[type], fid) + + def writeFieldEnd(self): + assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state + self.state = FIELD_WRITE + + def __writeUByte(self, byte): + self.trans.write(pack('!B', byte)) + + def __writeByte(self, byte): + self.trans.write(pack('!b', byte)) + + def __writeI16(self, i16): + self.__writeVarint(makeZigZag(i16, 16)) + + def __writeSize(self, i32): + self.__writeVarint(i32) + + def writeCollectionBegin(self, etype, size): + assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state + if size <= 14: + self.__writeUByte(size << 4 | CTYPES[etype]) + else: + self.__writeUByte(0xf0 | CTYPES[etype]) + self.__writeSize(size) + self.__containers.append(self.state) + self.state = CONTAINER_WRITE + writeSetBegin = writeCollectionBegin + writeListBegin = writeCollectionBegin + + def writeMapBegin(self, ktype, vtype, size): + assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state + if size == 0: + self.__writeByte(0) + else: + self.__writeSize(size) + self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype]) + self.__containers.append(self.state) + self.state = CONTAINER_WRITE + + def writeCollectionEnd(self): + assert self.state == CONTAINER_WRITE, self.state + self.state = self.__containers.pop() + writeMapEnd = writeCollectionEnd + writeSetEnd = writeCollectionEnd + writeListEnd = writeCollectionEnd + + def writeBool(self, bool): + if self.state == BOOL_WRITE: + if bool: + ctype = CompactType.TRUE + else: + ctype = CompactType.FALSE + self.__writeFieldHeader(ctype, self.__bool_fid) + elif self.state == CONTAINER_WRITE: + if bool: + self.__writeByte(CompactType.TRUE) + else: + self.__writeByte(CompactType.FALSE) + else: + raise AssertionError("Invalid state in compact protocol") + + writeByte = writer(__writeByte) + writeI16 = writer(__writeI16) + + @writer + def writeI32(self, i32): + self.__writeVarint(makeZigZag(i32, 32)) + + @writer + def writeI64(self, i64): + self.__writeVarint(makeZigZag(i64, 64)) + + @writer + def writeDouble(self, dub): + self.trans.write(pack('<d', dub)) + + def __writeBinary(self, s): + self.__writeSize(len(s)) + self.trans.write(s) + writeBinary = writer(__writeBinary) + + def readFieldBegin(self): + assert self.state == FIELD_READ, self.state + type = self.__readUByte() + if type & 0x0f == TType.STOP: + return (None, 0, 0) + delta = type >> 4 + if delta == 0: + fid = self.__readI16() + else: + fid = self.__last_fid + delta + self.__last_fid = fid + type = type & 0x0f + if type == CompactType.TRUE: + self.state = BOOL_READ + self.__bool_value = True + elif type == CompactType.FALSE: + self.state = BOOL_READ + self.__bool_value = False + else: + self.state = VALUE_READ + return (None, self.__getTType(type), fid) + + def readFieldEnd(self): + assert self.state in (VALUE_READ, BOOL_READ), self.state + self.state = FIELD_READ + + def __readUByte(self): + result, = unpack('!B', self.trans.readAll(1)) + return result + + def __readByte(self): + result, = unpack('!b', self.trans.readAll(1)) + return result + + def __readVarint(self): + return readVarint(self.trans) + + def __readZigZag(self): + return fromZigZag(self.__readVarint()) + + def __readSize(self): + result = self.__readVarint() + if result < 0: + raise TProtocolException("Length < 0") + return result + + def readMessageBegin(self): + assert self.state == CLEAR + proto_id = self.__readUByte() + if proto_id != self.PROTOCOL_ID: + raise TProtocolException(TProtocolException.BAD_VERSION, + 'Bad protocol id in the message: %d' % proto_id) + ver_type = self.__readUByte() + type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS + version = ver_type & self.VERSION_MASK + if version != self.VERSION: + raise TProtocolException(TProtocolException.BAD_VERSION, + 'Bad version: %d (expect %d)' % (version, self.VERSION)) + seqid = self.__readVarint() + name = binary_to_str(self.__readBinary()) + return (name, type, seqid) + + def readMessageEnd(self): + assert self.state == CLEAR + assert len(self.__structs) == 0 + + def readStructBegin(self): + assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state + self.__structs.append((self.state, self.__last_fid)) + self.state = FIELD_READ + self.__last_fid = 0 + + def readStructEnd(self): + assert self.state == FIELD_READ + self.state, self.__last_fid = self.__structs.pop() + + def readCollectionBegin(self): + assert self.state in (VALUE_READ, CONTAINER_READ), self.state + size_type = self.__readUByte() + size = size_type >> 4 + type = self.__getTType(size_type) + if size == 15: + size = self.__readSize() + self._check_container_length(size) + self.__containers.append(self.state) + self.state = CONTAINER_READ + return type, size + readSetBegin = readCollectionBegin + readListBegin = readCollectionBegin + + def readMapBegin(self): + assert self.state in (VALUE_READ, CONTAINER_READ), self.state + size = self.__readSize() + self._check_container_length(size) + types = 0 + if size > 0: + types = self.__readUByte() + vtype = self.__getTType(types) + ktype = self.__getTType(types >> 4) + self.__containers.append(self.state) + self.state = CONTAINER_READ + return (ktype, vtype, size) + + def readCollectionEnd(self): + assert self.state == CONTAINER_READ, self.state + self.state = self.__containers.pop() + readSetEnd = readCollectionEnd + readListEnd = readCollectionEnd + readMapEnd = readCollectionEnd + + def readBool(self): + if self.state == BOOL_READ: + return self.__bool_value == CompactType.TRUE + elif self.state == CONTAINER_READ: + return self.__readByte() == CompactType.TRUE + else: + raise AssertionError("Invalid state in compact protocol: %d" % + self.state) + + readByte = reader(__readByte) + __readI16 = __readZigZag + readI16 = reader(__readZigZag) + readI32 = reader(__readZigZag) + readI64 = reader(__readZigZag) + + @reader + def readDouble(self): + buff = self.trans.readAll(8) + val, = unpack('<d', buff) + return val + + def __readBinary(self): + size = self.__readSize() + self._check_string_length(size) + return self.trans.readAll(size) + readBinary = reader(__readBinary) + + def __getTType(self, byte): + return TTYPES[byte & 0x0f] class TCompactProtocolFactory(object): - def __init__(self, - string_length_limit=None, - container_length_limit=None): - self.string_length_limit = string_length_limit - self.container_length_limit = container_length_limit - - def getProtocol(self, trans): - return TCompactProtocol(trans, - self.string_length_limit, - self.container_length_limit) + def __init__(self, + string_length_limit=None, + container_length_limit=None): + self.string_length_limit = string_length_limit + self.container_length_limit = container_length_limit + + def getProtocol(self, trans): + return TCompactProtocol(trans, + self.string_length_limit, + self.container_length_limit) diff --git a/lib/py/src/protocol/TJSONProtocol.py b/lib/py/src/protocol/TJSONProtocol.py index f9e65fbf2..db2099a34 100644 --- a/lib/py/src/protocol/TJSONProtocol.py +++ b/lib/py/src/protocol/TJSONProtocol.py @@ -17,7 +17,8 @@ # under the License. # -from .TProtocol import TType, TProtocolBase, TProtocolException, checkIntegerLimits +from .TProtocol import (TType, TProtocolBase, TProtocolException, + checkIntegerLimits) import base64 import math import sys @@ -45,14 +46,14 @@ ZERO = b'0' ESCSEQ0 = ord('\\') ESCSEQ1 = ord('u') ESCAPE_CHAR_VALS = { - '"': '\\"', - '\\': '\\\\', - '\b': '\\b', - '\f': '\\f', - '\n': '\\n', - '\r': '\\r', - '\t': '\\t', - # '/': '\\/', + '"': '\\"', + '\\': '\\\\', + '\b': '\\b', + '\f': '\\f', + '\n': '\\n', + '\r': '\\r', + '\t': '\\t', + # '/': '\\/', } ESCAPE_CHARS = { b'"': '"', @@ -66,519 +67,527 @@ ESCAPE_CHARS = { } NUMERIC_CHAR = b'+-.0123456789Ee' -CTYPES = {TType.BOOL: 'tf', - TType.BYTE: 'i8', - TType.I16: 'i16', - TType.I32: 'i32', - TType.I64: 'i64', - TType.DOUBLE: 'dbl', - TType.STRING: 'str', - TType.STRUCT: 'rec', - TType.LIST: 'lst', - TType.SET: 'set', - TType.MAP: 'map'} +CTYPES = { + TType.BOOL: 'tf', + TType.BYTE: 'i8', + TType.I16: 'i16', + TType.I32: 'i32', + TType.I64: 'i64', + TType.DOUBLE: 'dbl', + TType.STRING: 'str', + TType.STRUCT: 'rec', + TType.LIST: 'lst', + TType.SET: 'set', + TType.MAP: 'map', +} JTYPES = {} for key in CTYPES.keys(): - JTYPES[CTYPES[key]] = key + JTYPES[CTYPES[key]] = key class JSONBaseContext(object): - def __init__(self, protocol): - self.protocol = protocol - self.first = True + def __init__(self, protocol): + self.protocol = protocol + self.first = True - def doIO(self, function): - pass + def doIO(self, function): + pass - def write(self): - pass + def write(self): + pass - def read(self): - pass + def read(self): + pass - def escapeNum(self): - return False + def escapeNum(self): + return False - def __str__(self): - return self.__class__.__name__ + def __str__(self): + return self.__class__.__name__ class JSONListContext(JSONBaseContext): - def doIO(self, function): - if self.first is True: - self.first = False - else: - function(COMMA) + def doIO(self, function): + if self.first is True: + self.first = False + else: + function(COMMA) - def write(self): - self.doIO(self.protocol.trans.write) + def write(self): + self.doIO(self.protocol.trans.write) - def read(self): - self.doIO(self.protocol.readJSONSyntaxChar) + def read(self): + self.doIO(self.protocol.readJSONSyntaxChar) class JSONPairContext(JSONBaseContext): - def __init__(self, protocol): - super(JSONPairContext, self).__init__(protocol) - self.colon = True + def __init__(self, protocol): + super(JSONPairContext, self).__init__(protocol) + self.colon = True - def doIO(self, function): - if self.first: - self.first = False - self.colon = True - else: - function(COLON if self.colon else COMMA) - self.colon = not self.colon + def doIO(self, function): + if self.first: + self.first = False + self.colon = True + else: + function(COLON if self.colon else COMMA) + self.colon = not self.colon - def write(self): - self.doIO(self.protocol.trans.write) + def write(self): + self.doIO(self.protocol.trans.write) - def read(self): - self.doIO(self.protocol.readJSONSyntaxChar) + def read(self): + self.doIO(self.protocol.readJSONSyntaxChar) - def escapeNum(self): - return self.colon + def escapeNum(self): + return self.colon - def __str__(self): - return '%s, colon=%s' % (self.__class__.__name__, self.colon) + def __str__(self): + return '%s, colon=%s' % (self.__class__.__name__, self.colon) class LookaheadReader(): - hasData = False - data = '' + hasData = False + data = '' - def __init__(self, protocol): - self.protocol = protocol + def __init__(self, protocol): + self.protocol = protocol - def read(self): - if self.hasData is True: - self.hasData = False - else: - self.data = self.protocol.trans.read(1) - return self.data + def read(self): + if self.hasData is True: + self.hasData = False + else: + self.data = self.protocol.trans.read(1) + return self.data - def peek(self): - if self.hasData is False: - self.data = self.protocol.trans.read(1) - self.hasData = True - return self.data + def peek(self): + if self.hasData is False: + self.data = self.protocol.trans.read(1) + self.hasData = True + return self.data class TJSONProtocolBase(TProtocolBase): - def __init__(self, trans): - TProtocolBase.__init__(self, trans) - self.resetWriteContext() - self.resetReadContext() - - # We don't have length limit implementation for JSON protocols - @property - def string_length_limit(senf): - return None - - @property - def container_length_limit(senf): - return None - - def resetWriteContext(self): - self.context = JSONBaseContext(self) - self.contextStack = [self.context] - - def resetReadContext(self): - self.resetWriteContext() - self.reader = LookaheadReader(self) - - def pushContext(self, ctx): - self.contextStack.append(ctx) - self.context = ctx - - def popContext(self): - self.contextStack.pop() - if self.contextStack: - self.context = self.contextStack[-1] - else: - self.context = JSONBaseContext(self) - - def writeJSONString(self, string): - self.context.write() - json_str = ['"'] - for s in string: - escaped = ESCAPE_CHAR_VALS.get(s, s) - json_str.append(escaped) - json_str.append('"') - self.trans.write(str_to_binary(''.join(json_str))) - - def writeJSONNumber(self, number, formatter='{0}'): - self.context.write() - jsNumber = str(formatter.format(number)).encode('ascii') - if self.context.escapeNum(): - self.trans.write(QUOTE) - self.trans.write(jsNumber) - self.trans.write(QUOTE) - else: - self.trans.write(jsNumber) - - def writeJSONBase64(self, binary): - self.context.write() - self.trans.write(QUOTE) - self.trans.write(base64.b64encode(binary)) - self.trans.write(QUOTE) - - def writeJSONObjectStart(self): - self.context.write() - self.trans.write(LBRACE) - self.pushContext(JSONPairContext(self)) - - def writeJSONObjectEnd(self): - self.popContext() - self.trans.write(RBRACE) - - def writeJSONArrayStart(self): - self.context.write() - self.trans.write(LBRACKET) - self.pushContext(JSONListContext(self)) - - def writeJSONArrayEnd(self): - self.popContext() - self.trans.write(RBRACKET) - - def readJSONSyntaxChar(self, character): - current = self.reader.read() - if character != current: - raise TProtocolException(TProtocolException.INVALID_DATA, - "Unexpected character: %s" % current) - - def _isHighSurrogate(self, codeunit): - return codeunit >= 0xd800 and codeunit <= 0xdbff - - def _isLowSurrogate(self, codeunit): - return codeunit >= 0xdc00 and codeunit <= 0xdfff - - def _toChar(self, high, low=None): - if not low: - if sys.version_info[0] == 2: - return ("\\u%04x" % high).decode('unicode-escape').encode('utf-8') - else: - return chr(high) - else: - codepoint = (1 << 16) + ((high & 0x3ff) << 10) - codepoint += low & 0x3ff - if sys.version_info[0] == 2: - s = "\\U%08x" % codepoint - return s.decode('unicode-escape').encode('utf-8') - else: - return chr(codepoint) - - def readJSONString(self, skipContext): - highSurrogate = None - string = [] - if skipContext is False: - self.context.read() - self.readJSONSyntaxChar(QUOTE) - while True: - character = self.reader.read() - if character == QUOTE: - break - if ord(character) == ESCSEQ0: - character = self.reader.read() - if ord(character) == ESCSEQ1: - character = self.trans.read(4).decode('ascii') - codeunit = int(character, 16) - if self._isHighSurrogate(codeunit): - if highSurrogate: - raise TProtocolException(TProtocolException.INVALID_DATA, - "Expected low surrogate char") - highSurrogate = codeunit - continue - elif self._isLowSurrogate(codeunit): - if not highSurrogate: - raise TProtocolException(TProtocolException.INVALID_DATA, - "Expected high surrogate char") - character = self._toChar(highSurrogate, codeunit) - highSurrogate = None - else: - character = self._toChar(codeunit) + def __init__(self, trans): + TProtocolBase.__init__(self, trans) + self.resetWriteContext() + self.resetReadContext() + + # We don't have length limit implementation for JSON protocols + @property + def string_length_limit(senf): + return None + + @property + def container_length_limit(senf): + return None + + def resetWriteContext(self): + self.context = JSONBaseContext(self) + self.contextStack = [self.context] + + def resetReadContext(self): + self.resetWriteContext() + self.reader = LookaheadReader(self) + + def pushContext(self, ctx): + self.contextStack.append(ctx) + self.context = ctx + + def popContext(self): + self.contextStack.pop() + if self.contextStack: + self.context = self.contextStack[-1] + else: + self.context = JSONBaseContext(self) + + def writeJSONString(self, string): + self.context.write() + json_str = ['"'] + for s in string: + escaped = ESCAPE_CHAR_VALS.get(s, s) + json_str.append(escaped) + json_str.append('"') + self.trans.write(str_to_binary(''.join(json_str))) + + def writeJSONNumber(self, number, formatter='{0}'): + self.context.write() + jsNumber = str(formatter.format(number)).encode('ascii') + if self.context.escapeNum(): + self.trans.write(QUOTE) + self.trans.write(jsNumber) + self.trans.write(QUOTE) else: - if character not in ESCAPE_CHARS: + self.trans.write(jsNumber) + + def writeJSONBase64(self, binary): + self.context.write() + self.trans.write(QUOTE) + self.trans.write(base64.b64encode(binary)) + self.trans.write(QUOTE) + + def writeJSONObjectStart(self): + self.context.write() + self.trans.write(LBRACE) + self.pushContext(JSONPairContext(self)) + + def writeJSONObjectEnd(self): + self.popContext() + self.trans.write(RBRACE) + + def writeJSONArrayStart(self): + self.context.write() + self.trans.write(LBRACKET) + self.pushContext(JSONListContext(self)) + + def writeJSONArrayEnd(self): + self.popContext() + self.trans.write(RBRACKET) + + def readJSONSyntaxChar(self, character): + current = self.reader.read() + if character != current: raise TProtocolException(TProtocolException.INVALID_DATA, - "Expected control char") - character = ESCAPE_CHARS[character] - elif character in ESCAPE_CHAR_VALS: - raise TProtocolException(TProtocolException.INVALID_DATA, - "Unescaped control char") - elif sys.version_info[0] > 2: - utf8_bytes = bytearray([ord(character)]) - while ord(self.reader.peek()) >= 0x80: - utf8_bytes.append(ord(self.reader.read())) - character = utf8_bytes.decode('utf8') - string.append(character) - - if highSurrogate: - raise TProtocolException(TProtocolException.INVALID_DATA, - "Expected low surrogate char") - return ''.join(string) - - def isJSONNumeric(self, character): - return (True if NUMERIC_CHAR.find(character) != - 1 else False) - - def readJSONQuotes(self): - if (self.context.escapeNum()): - self.readJSONSyntaxChar(QUOTE) - - def readJSONNumericChars(self): - numeric = [] - while True: - character = self.reader.peek() - if self.isJSONNumeric(character) is False: - break - numeric.append(self.reader.read()) - return b''.join(numeric).decode('ascii') - - def readJSONInteger(self): - self.context.read() - self.readJSONQuotes() - numeric = self.readJSONNumericChars() - self.readJSONQuotes() - try: - return int(numeric) - except ValueError: - raise TProtocolException(TProtocolException.INVALID_DATA, - "Bad data encounted in numeric data") - - def readJSONDouble(self): - self.context.read() - if self.reader.peek() == QUOTE: - string = self.readJSONString(True) - try: - double = float(string) - if (self.context.escapeNum is False and - not math.isinf(double) and - not math.isnan(double)): - raise TProtocolException(TProtocolException.INVALID_DATA, - "Numeric data unexpectedly quoted") - return double - except ValueError: - raise TProtocolException(TProtocolException.INVALID_DATA, - "Bad data encounted in numeric data") - else: - if self.context.escapeNum() is True: + "Unexpected character: %s" % current) + + def _isHighSurrogate(self, codeunit): + return codeunit >= 0xd800 and codeunit <= 0xdbff + + def _isLowSurrogate(self, codeunit): + return codeunit >= 0xdc00 and codeunit <= 0xdfff + + def _toChar(self, high, low=None): + if not low: + if sys.version_info[0] == 2: + return ("\\u%04x" % high).decode('unicode-escape') \ + .encode('utf-8') + else: + return chr(high) + else: + codepoint = (1 << 16) + ((high & 0x3ff) << 10) + codepoint += low & 0x3ff + if sys.version_info[0] == 2: + s = "\\U%08x" % codepoint + return s.decode('unicode-escape').encode('utf-8') + else: + return chr(codepoint) + + def readJSONString(self, skipContext): + highSurrogate = None + string = [] + if skipContext is False: + self.context.read() self.readJSONSyntaxChar(QUOTE) - try: - return float(self.readJSONNumericChars()) - except ValueError: - raise TProtocolException(TProtocolException.INVALID_DATA, - "Bad data encounted in numeric data") - - def readJSONBase64(self): - string = self.readJSONString(False) - size = len(string) - m = size % 4 - # Force padding since b64encode method does not allow it - if m != 0: - for i in range(4 - m): - string += '=' - return base64.b64decode(string) - - def readJSONObjectStart(self): - self.context.read() - self.readJSONSyntaxChar(LBRACE) - self.pushContext(JSONPairContext(self)) - - def readJSONObjectEnd(self): - self.readJSONSyntaxChar(RBRACE) - self.popContext() - - def readJSONArrayStart(self): - self.context.read() - self.readJSONSyntaxChar(LBRACKET) - self.pushContext(JSONListContext(self)) - - def readJSONArrayEnd(self): - self.readJSONSyntaxChar(RBRACKET) - self.popContext() + while True: + character = self.reader.read() + if character == QUOTE: + break + if ord(character) == ESCSEQ0: + character = self.reader.read() + if ord(character) == ESCSEQ1: + character = self.trans.read(4).decode('ascii') + codeunit = int(character, 16) + if self._isHighSurrogate(codeunit): + if highSurrogate: + raise TProtocolException( + TProtocolException.INVALID_DATA, + "Expected low surrogate char") + highSurrogate = codeunit + continue + elif self._isLowSurrogate(codeunit): + if not highSurrogate: + raise TProtocolException( + TProtocolException.INVALID_DATA, + "Expected high surrogate char") + character = self._toChar(highSurrogate, codeunit) + highSurrogate = None + else: + character = self._toChar(codeunit) + else: + if character not in ESCAPE_CHARS: + raise TProtocolException( + TProtocolException.INVALID_DATA, + "Expected control char") + character = ESCAPE_CHARS[character] + elif character in ESCAPE_CHAR_VALS: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Unescaped control char") + elif sys.version_info[0] > 2: + utf8_bytes = bytearray([ord(character)]) + while ord(self.reader.peek()) >= 0x80: + utf8_bytes.append(ord(self.reader.read())) + character = utf8_bytes.decode('utf8') + string.append(character) + + if highSurrogate: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Expected low surrogate char") + return ''.join(string) + + def isJSONNumeric(self, character): + return (True if NUMERIC_CHAR.find(character) != - 1 else False) + + def readJSONQuotes(self): + if (self.context.escapeNum()): + self.readJSONSyntaxChar(QUOTE) + + def readJSONNumericChars(self): + numeric = [] + while True: + character = self.reader.peek() + if self.isJSONNumeric(character) is False: + break + numeric.append(self.reader.read()) + return b''.join(numeric).decode('ascii') + + def readJSONInteger(self): + self.context.read() + self.readJSONQuotes() + numeric = self.readJSONNumericChars() + self.readJSONQuotes() + try: + return int(numeric) + except ValueError: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data") + + def readJSONDouble(self): + self.context.read() + if self.reader.peek() == QUOTE: + string = self.readJSONString(True) + try: + double = float(string) + if (self.context.escapeNum is False and + not math.isinf(double) and + not math.isnan(double)): + raise TProtocolException( + TProtocolException.INVALID_DATA, + "Numeric data unexpectedly quoted") + return double + except ValueError: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data") + else: + if self.context.escapeNum() is True: + self.readJSONSyntaxChar(QUOTE) + try: + return float(self.readJSONNumericChars()) + except ValueError: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data") + + def readJSONBase64(self): + string = self.readJSONString(False) + size = len(string) + m = size % 4 + # Force padding since b64encode method does not allow it + if m != 0: + for i in range(4 - m): + string += '=' + return base64.b64decode(string) + + def readJSONObjectStart(self): + self.context.read() + self.readJSONSyntaxChar(LBRACE) + self.pushContext(JSONPairContext(self)) + + def readJSONObjectEnd(self): + self.readJSONSyntaxChar(RBRACE) + self.popContext() + + def readJSONArrayStart(self): + self.context.read() + self.readJSONSyntaxChar(LBRACKET) + self.pushContext(JSONListContext(self)) + + def readJSONArrayEnd(self): + self.readJSONSyntaxChar(RBRACKET) + self.popContext() class TJSONProtocol(TJSONProtocolBase): - def readMessageBegin(self): - self.resetReadContext() - self.readJSONArrayStart() - if self.readJSONInteger() != VERSION: - raise TProtocolException(TProtocolException.BAD_VERSION, - "Message contained bad version.") - name = self.readJSONString(False) - typen = self.readJSONInteger() - seqid = self.readJSONInteger() - return (name, typen, seqid) - - def readMessageEnd(self): - self.readJSONArrayEnd() - - def readStructBegin(self): - self.readJSONObjectStart() - - def readStructEnd(self): - self.readJSONObjectEnd() - - def readFieldBegin(self): - character = self.reader.peek() - ttype = 0 - id = 0 - if character == RBRACE: - ttype = TType.STOP - else: - id = self.readJSONInteger() - self.readJSONObjectStart() - ttype = JTYPES[self.readJSONString(False)] - return (None, ttype, id) - - def readFieldEnd(self): - self.readJSONObjectEnd() - - def readMapBegin(self): - self.readJSONArrayStart() - keyType = JTYPES[self.readJSONString(False)] - valueType = JTYPES[self.readJSONString(False)] - size = self.readJSONInteger() - self.readJSONObjectStart() - return (keyType, valueType, size) - - def readMapEnd(self): - self.readJSONObjectEnd() - self.readJSONArrayEnd() - - def readCollectionBegin(self): - self.readJSONArrayStart() - elemType = JTYPES[self.readJSONString(False)] - size = self.readJSONInteger() - return (elemType, size) - readListBegin = readCollectionBegin - readSetBegin = readCollectionBegin - - def readCollectionEnd(self): - self.readJSONArrayEnd() - readSetEnd = readCollectionEnd - readListEnd = readCollectionEnd - - def readBool(self): - return (False if self.readJSONInteger() == 0 else True) - - def readNumber(self): - return self.readJSONInteger() - readByte = readNumber - readI16 = readNumber - readI32 = readNumber - readI64 = readNumber - - def readDouble(self): - return self.readJSONDouble() - - def readString(self): - return self.readJSONString(False) - - def readBinary(self): - return self.readJSONBase64() - - def writeMessageBegin(self, name, request_type, seqid): - self.resetWriteContext() - self.writeJSONArrayStart() - self.writeJSONNumber(VERSION) - self.writeJSONString(name) - self.writeJSONNumber(request_type) - self.writeJSONNumber(seqid) - - def writeMessageEnd(self): - self.writeJSONArrayEnd() - - def writeStructBegin(self, name): - self.writeJSONObjectStart() - - def writeStructEnd(self): - self.writeJSONObjectEnd() - - def writeFieldBegin(self, name, ttype, id): - self.writeJSONNumber(id) - self.writeJSONObjectStart() - self.writeJSONString(CTYPES[ttype]) - - def writeFieldEnd(self): - self.writeJSONObjectEnd() - - def writeFieldStop(self): - pass - - def writeMapBegin(self, ktype, vtype, size): - self.writeJSONArrayStart() - self.writeJSONString(CTYPES[ktype]) - self.writeJSONString(CTYPES[vtype]) - self.writeJSONNumber(size) - self.writeJSONObjectStart() - - def writeMapEnd(self): - self.writeJSONObjectEnd() - self.writeJSONArrayEnd() - - def writeListBegin(self, etype, size): - self.writeJSONArrayStart() - self.writeJSONString(CTYPES[etype]) - self.writeJSONNumber(size) - - def writeListEnd(self): - self.writeJSONArrayEnd() - - def writeSetBegin(self, etype, size): - self.writeJSONArrayStart() - self.writeJSONString(CTYPES[etype]) - self.writeJSONNumber(size) + def readMessageBegin(self): + self.resetReadContext() + self.readJSONArrayStart() + if self.readJSONInteger() != VERSION: + raise TProtocolException(TProtocolException.BAD_VERSION, + "Message contained bad version.") + name = self.readJSONString(False) + typen = self.readJSONInteger() + seqid = self.readJSONInteger() + return (name, typen, seqid) + + def readMessageEnd(self): + self.readJSONArrayEnd() + + def readStructBegin(self): + self.readJSONObjectStart() + + def readStructEnd(self): + self.readJSONObjectEnd() + + def readFieldBegin(self): + character = self.reader.peek() + ttype = 0 + id = 0 + if character == RBRACE: + ttype = TType.STOP + else: + id = self.readJSONInteger() + self.readJSONObjectStart() + ttype = JTYPES[self.readJSONString(False)] + return (None, ttype, id) + + def readFieldEnd(self): + self.readJSONObjectEnd() + + def readMapBegin(self): + self.readJSONArrayStart() + keyType = JTYPES[self.readJSONString(False)] + valueType = JTYPES[self.readJSONString(False)] + size = self.readJSONInteger() + self.readJSONObjectStart() + return (keyType, valueType, size) + + def readMapEnd(self): + self.readJSONObjectEnd() + self.readJSONArrayEnd() + + def readCollectionBegin(self): + self.readJSONArrayStart() + elemType = JTYPES[self.readJSONString(False)] + size = self.readJSONInteger() + return (elemType, size) + readListBegin = readCollectionBegin + readSetBegin = readCollectionBegin + + def readCollectionEnd(self): + self.readJSONArrayEnd() + readSetEnd = readCollectionEnd + readListEnd = readCollectionEnd + + def readBool(self): + return (False if self.readJSONInteger() == 0 else True) + + def readNumber(self): + return self.readJSONInteger() + readByte = readNumber + readI16 = readNumber + readI32 = readNumber + readI64 = readNumber + + def readDouble(self): + return self.readJSONDouble() + + def readString(self): + return self.readJSONString(False) + + def readBinary(self): + return self.readJSONBase64() + + def writeMessageBegin(self, name, request_type, seqid): + self.resetWriteContext() + self.writeJSONArrayStart() + self.writeJSONNumber(VERSION) + self.writeJSONString(name) + self.writeJSONNumber(request_type) + self.writeJSONNumber(seqid) + + def writeMessageEnd(self): + self.writeJSONArrayEnd() + + def writeStructBegin(self, name): + self.writeJSONObjectStart() + + def writeStructEnd(self): + self.writeJSONObjectEnd() - def writeSetEnd(self): - self.writeJSONArrayEnd() + def writeFieldBegin(self, name, ttype, id): + self.writeJSONNumber(id) + self.writeJSONObjectStart() + self.writeJSONString(CTYPES[ttype]) - def writeBool(self, boolean): - self.writeJSONNumber(1 if boolean is True else 0) - - def writeByte(self, byte): - checkIntegerLimits(byte, 8) - self.writeJSONNumber(byte) + def writeFieldEnd(self): + self.writeJSONObjectEnd() - def writeI16(self, i16): - checkIntegerLimits(i16, 16) - self.writeJSONNumber(i16) + def writeFieldStop(self): + pass - def writeI32(self, i32): - checkIntegerLimits(i32, 32) - self.writeJSONNumber(i32) + def writeMapBegin(self, ktype, vtype, size): + self.writeJSONArrayStart() + self.writeJSONString(CTYPES[ktype]) + self.writeJSONString(CTYPES[vtype]) + self.writeJSONNumber(size) + self.writeJSONObjectStart() - def writeI64(self, i64): - checkIntegerLimits(i64, 64) - self.writeJSONNumber(i64) + def writeMapEnd(self): + self.writeJSONObjectEnd() + self.writeJSONArrayEnd() - def writeDouble(self, dbl): - # 17 significant digits should be just enough for any double precision value. - self.writeJSONNumber(dbl, '{0:.17g}') + def writeListBegin(self, etype, size): + self.writeJSONArrayStart() + self.writeJSONString(CTYPES[etype]) + self.writeJSONNumber(size) + + def writeListEnd(self): + self.writeJSONArrayEnd() - def writeString(self, string): - self.writeJSONString(string) + def writeSetBegin(self, etype, size): + self.writeJSONArrayStart() + self.writeJSONString(CTYPES[etype]) + self.writeJSONNumber(size) - def writeBinary(self, binary): - self.writeJSONBase64(binary) + def writeSetEnd(self): + self.writeJSONArrayEnd() + + def writeBool(self, boolean): + self.writeJSONNumber(1 if boolean is True else 0) + + def writeByte(self, byte): + checkIntegerLimits(byte, 8) + self.writeJSONNumber(byte) + + def writeI16(self, i16): + checkIntegerLimits(i16, 16) + self.writeJSONNumber(i16) + + def writeI32(self, i32): + checkIntegerLimits(i32, 32) + self.writeJSONNumber(i32) + + def writeI64(self, i64): + checkIntegerLimits(i64, 64) + self.writeJSONNumber(i64) + + def writeDouble(self, dbl): + # 17 significant digits should be just enough for any double precision + # value. + self.writeJSONNumber(dbl, '{0:.17g}') + + def writeString(self, string): + self.writeJSONString(string) + + def writeBinary(self, binary): + self.writeJSONBase64(binary) class TJSONProtocolFactory(object): - def getProtocol(self, trans): - return TJSONProtocol(trans) + def getProtocol(self, trans): + return TJSONProtocol(trans) - @property - def string_length_limit(senf): - return None + @property + def string_length_limit(senf): + return None - @property - def container_length_limit(senf): - return None + @property + def container_length_limit(senf): + return None class TSimpleJSONProtocol(TJSONProtocolBase): diff --git a/lib/py/src/protocol/TMultiplexedProtocol.py b/lib/py/src/protocol/TMultiplexedProtocol.py index d25f367b5..309f896d0 100644 --- a/lib/py/src/protocol/TMultiplexedProtocol.py +++ b/lib/py/src/protocol/TMultiplexedProtocol.py @@ -22,18 +22,19 @@ from thrift.protocol import TProtocolDecorator SEPARATOR = ":" + class TMultiplexedProtocol(TProtocolDecorator.TProtocolDecorator): - def __init__(self, protocol, serviceName): - TProtocolDecorator.TProtocolDecorator.__init__(self, protocol) - self.serviceName = serviceName + def __init__(self, protocol, serviceName): + TProtocolDecorator.TProtocolDecorator.__init__(self, protocol) + self.serviceName = serviceName - def writeMessageBegin(self, name, type, seqid): - if (type == TMessageType.CALL or - type == TMessageType.ONEWAY): - self.protocol.writeMessageBegin( - self.serviceName + SEPARATOR + name, - type, - seqid - ) - else: - self.protocol.writeMessageBegin(name, type, seqid) + def writeMessageBegin(self, name, type, seqid): + if (type == TMessageType.CALL or + type == TMessageType.ONEWAY): + self.protocol.writeMessageBegin( + self.serviceName + SEPARATOR + name, + type, + seqid + ) + else: + self.protocol.writeMessageBegin(name, type, seqid) diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py index d9aa2e82b..ed6938bb6 100644 --- a/lib/py/src/protocol/TProtocol.py +++ b/lib/py/src/protocol/TProtocol.py @@ -28,373 +28,373 @@ from six.moves import zip class TProtocolException(TException): - """Custom Protocol Exception class""" + """Custom Protocol Exception class""" - UNKNOWN = 0 - INVALID_DATA = 1 - NEGATIVE_SIZE = 2 - SIZE_LIMIT = 3 - BAD_VERSION = 4 - NOT_IMPLEMENTED = 5 - DEPTH_LIMIT = 6 + UNKNOWN = 0 + INVALID_DATA = 1 + NEGATIVE_SIZE = 2 + SIZE_LIMIT = 3 + BAD_VERSION = 4 + NOT_IMPLEMENTED = 5 + DEPTH_LIMIT = 6 - def __init__(self, type=UNKNOWN, message=None): - TException.__init__(self, message) - self.type = type + def __init__(self, type=UNKNOWN, message=None): + TException.__init__(self, message) + self.type = type class TProtocolBase(object): - """Base class for Thrift protocol driver.""" + """Base class for Thrift protocol driver.""" - def __init__(self, trans): - self.trans = trans + def __init__(self, trans): + self.trans = trans - @staticmethod - def _check_length(limit, length): - if length < 0: - raise TTransportException(TTransportException.NEGATIVE_SIZE, - 'Negative length: %d' % length) - if limit is not None and length > limit: - raise TTransportException(TTransportException.SIZE_LIMIT, - 'Length exceeded max allowed: %d' % limit) + @staticmethod + def _check_length(limit, length): + if length < 0: + raise TTransportException(TTransportException.NEGATIVE_SIZE, + 'Negative length: %d' % length) + if limit is not None and length > limit: + raise TTransportException(TTransportException.SIZE_LIMIT, + 'Length exceeded max allowed: %d' % limit) - def writeMessageBegin(self, name, ttype, seqid): - pass + def writeMessageBegin(self, name, ttype, seqid): + pass - def writeMessageEnd(self): - pass + def writeMessageEnd(self): + pass - def writeStructBegin(self, name): - pass + def writeStructBegin(self, name): + pass - def writeStructEnd(self): - pass + def writeStructEnd(self): + pass - def writeFieldBegin(self, name, ttype, fid): - pass + def writeFieldBegin(self, name, ttype, fid): + pass - def writeFieldEnd(self): - pass + def writeFieldEnd(self): + pass - def writeFieldStop(self): - pass + def writeFieldStop(self): + pass - def writeMapBegin(self, ktype, vtype, size): - pass + def writeMapBegin(self, ktype, vtype, size): + pass - def writeMapEnd(self): - pass + def writeMapEnd(self): + pass - def writeListBegin(self, etype, size): - pass + def writeListBegin(self, etype, size): + pass - def writeListEnd(self): - pass + def writeListEnd(self): + pass - def writeSetBegin(self, etype, size): - pass + def writeSetBegin(self, etype, size): + pass - def writeSetEnd(self): - pass + def writeSetEnd(self): + pass - def writeBool(self, bool_val): - pass + def writeBool(self, bool_val): + pass - def writeByte(self, byte): - pass + def writeByte(self, byte): + pass - def writeI16(self, i16): - pass + def writeI16(self, i16): + pass - def writeI32(self, i32): - pass + def writeI32(self, i32): + pass - def writeI64(self, i64): - pass + def writeI64(self, i64): + pass - def writeDouble(self, dub): - pass + def writeDouble(self, dub): + pass - def writeString(self, str_val): - self.writeBinary(str_to_binary(str_val)) + def writeString(self, str_val): + self.writeBinary(str_to_binary(str_val)) - def writeBinary(self, str_val): - pass + def writeBinary(self, str_val): + pass - def writeUtf8(self, str_val): - self.writeString(str_val.encode('utf8')) + def writeUtf8(self, str_val): + self.writeString(str_val.encode('utf8')) - def readMessageBegin(self): - pass + def readMessageBegin(self): + pass - def readMessageEnd(self): - pass + def readMessageEnd(self): + pass - def readStructBegin(self): - pass + def readStructBegin(self): + pass - def readStructEnd(self): - pass + def readStructEnd(self): + pass - def readFieldBegin(self): - pass + def readFieldBegin(self): + pass - def readFieldEnd(self): - pass + def readFieldEnd(self): + pass - def readMapBegin(self): - pass + def readMapBegin(self): + pass - def readMapEnd(self): - pass + def readMapEnd(self): + pass - def readListBegin(self): - pass + def readListBegin(self): + pass - def readListEnd(self): - pass + def readListEnd(self): + pass - def readSetBegin(self): - pass + def readSetBegin(self): + pass - def readSetEnd(self): - pass + def readSetEnd(self): + pass - def readBool(self): - pass + def readBool(self): + pass - def readByte(self): - pass + def readByte(self): + pass - def readI16(self): - pass + def readI16(self): + pass - def readI32(self): - pass + def readI32(self): + pass - def readI64(self): - pass + def readI64(self): + pass - def readDouble(self): - pass + def readDouble(self): + pass - def readString(self): - return binary_to_str(self.readBinary()) + def readString(self): + return binary_to_str(self.readBinary()) - def readBinary(self): - pass + def readBinary(self): + pass - def readUtf8(self): - return self.readString().decode('utf8') + def readUtf8(self): + return self.readString().decode('utf8') - def skip(self, ttype): - if ttype == TType.STOP: - return - elif ttype == TType.BOOL: - self.readBool() - elif ttype == TType.BYTE: - self.readByte() - elif ttype == TType.I16: - self.readI16() - elif ttype == TType.I32: - self.readI32() - elif ttype == TType.I64: - self.readI64() - elif ttype == TType.DOUBLE: - self.readDouble() - elif ttype == TType.STRING: - self.readString() - elif ttype == TType.STRUCT: - name = self.readStructBegin() - while True: - (name, ttype, id) = self.readFieldBegin() + def skip(self, ttype): if ttype == TType.STOP: - break - self.skip(ttype) - self.readFieldEnd() - self.readStructEnd() - elif ttype == TType.MAP: - (ktype, vtype, size) = self.readMapBegin() - for i in range(size): - self.skip(ktype) - self.skip(vtype) - self.readMapEnd() - elif ttype == TType.SET: - (etype, size) = self.readSetBegin() - for i in range(size): - self.skip(etype) - self.readSetEnd() - elif ttype == TType.LIST: - (etype, size) = self.readListBegin() - for i in range(size): - self.skip(etype) - self.readListEnd() - - # tuple of: ( 'reader method' name, is_container bool, 'writer_method' name ) - _TTYPE_HANDLERS = ( - (None, None, False), # 0 TType.STOP - (None, None, False), # 1 TType.VOID # TODO: handle void? - ('readBool', 'writeBool', False), # 2 TType.BOOL - ('readByte', 'writeByte', False), # 3 TType.BYTE and I08 - ('readDouble', 'writeDouble', False), # 4 TType.DOUBLE - (None, None, False), # 5 undefined - ('readI16', 'writeI16', False), # 6 TType.I16 - (None, None, False), # 7 undefined - ('readI32', 'writeI32', False), # 8 TType.I32 - (None, None, False), # 9 undefined - ('readI64', 'writeI64', False), # 10 TType.I64 - ('readString', 'writeString', False), # 11 TType.STRING and UTF7 - ('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT - ('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP - ('readContainerSet', 'writeContainerSet', True), # 14 TType.SET - ('readContainerList', 'writeContainerList', True), # 15 TType.LIST - (None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types? - (None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types? - ) - - def _ttype_handlers(self, ttype, spec): - if spec == 'BINARY': - if ttype != TType.STRING: - raise TProtocolException(type=TProtocolException.INVALID_DATA, - message='Invalid binary field type %d' % ttype) - return ('readBinary', 'writeBinary', False) - if sys.version_info[0] == 2 and spec == 'UTF8': - if ttype != TType.STRING: - raise TProtocolException(type=TProtocolException.INVALID_DATA, - message='Invalid string field type %d' % ttype) - return ('readUtf8', 'writeUtf8', False) - return self._TTYPE_HANDLERS[ttype] if ttype < len(self._TTYPE_HANDLERS) else (None, None, False) - - def _read_by_ttype(self, ttype, spec, espec): - reader_name, _, is_container = self._ttype_handlers(ttype, spec) - if reader_name is None: - raise TProtocolException(type=TProtocolException.INVALID_DATA, - message='Invalid type %d' % (ttype)) - reader_func = getattr(self, reader_name) - read = (lambda: reader_func(espec)) if is_container else reader_func - while True: - yield read() - - def readFieldByTType(self, ttype, spec): - return self._read_by_ttype(ttype, spec, spec).next() - - def readContainerList(self, spec): - ttype, tspec, is_immutable = spec - (list_type, list_len) = self.readListBegin() - # TODO: compare types we just decoded with thrift_spec - elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len) - results = (tuple if is_immutable else list)(elems) - self.readListEnd() - return results - - def readContainerSet(self, spec): - ttype, tspec, is_immutable = spec - (set_type, set_len) = self.readSetBegin() - # TODO: compare types we just decoded with thrift_spec - elems = islice(self._read_by_ttype(ttype, spec, tspec), set_len) - results = (frozenset if is_immutable else set)(elems) - self.readSetEnd() - return results - - def readContainerStruct(self, spec): - (obj_class, obj_spec) = spec - obj = obj_class() - obj.read(self) - return obj - - def readContainerMap(self, spec): - ktype, kspec, vtype, vspec, is_immutable = spec - (map_ktype, map_vtype, map_len) = self.readMapBegin() - # TODO: compare types we just decoded with thrift_spec and - # abort/skip if types disagree - keys = self._read_by_ttype(ktype, spec, kspec) - vals = self._read_by_ttype(vtype, spec, vspec) - keyvals = islice(zip(keys, vals), map_len) - results = (TFrozenDict if is_immutable else dict)(keyvals) - self.readMapEnd() - return results - - def readStruct(self, obj, thrift_spec, is_immutable=False): - if is_immutable: - fields = {} - self.readStructBegin() - while True: - (fname, ftype, fid) = self.readFieldBegin() - if ftype == TType.STOP: - break - try: - field = thrift_spec[fid] - except IndexError: - self.skip(ftype) - else: - if field is not None and ftype == field[1]: - fname = field[2] - fspec = field[3] - val = self.readFieldByTType(ftype, fspec) - if is_immutable: - fields[fname] = val - else: - setattr(obj, fname, val) - else: - self.skip(ftype) - self.readFieldEnd() - self.readStructEnd() - if is_immutable: - return obj(**fields) - - def writeContainerStruct(self, val, spec): - val.write(self) - - def writeContainerList(self, val, spec): - ttype, tspec, _ = spec - self.writeListBegin(ttype, len(val)) - for _ in self._write_by_ttype(ttype, val, spec, tspec): - pass - self.writeListEnd() - - def writeContainerSet(self, val, spec): - ttype, tspec, _ = spec - self.writeSetBegin(ttype, len(val)) - for _ in self._write_by_ttype(ttype, val, spec, tspec): - pass - self.writeSetEnd() - - def writeContainerMap(self, val, spec): - ktype, kspec, vtype, vspec, _ = spec - self.writeMapBegin(ktype, vtype, len(val)) - for _ in zip(self._write_by_ttype(ktype, six.iterkeys(val), spec, kspec), - self._write_by_ttype(vtype, six.itervalues(val), spec, vspec)): - pass - self.writeMapEnd() - - def writeStruct(self, obj, thrift_spec): - self.writeStructBegin(obj.__class__.__name__) - for field in thrift_spec: - if field is None: - continue - fname = field[2] - val = getattr(obj, fname) - if val is None: - # skip writing out unset fields - continue - fid = field[0] - ftype = field[1] - fspec = field[3] - self.writeFieldBegin(fname, ftype, fid) - self.writeFieldByTType(ftype, val, fspec) - self.writeFieldEnd() - self.writeFieldStop() - self.writeStructEnd() - - def _write_by_ttype(self, ttype, vals, spec, espec): - _, writer_name, is_container = self._ttype_handlers(ttype, spec) - writer_func = getattr(self, writer_name) - write = (lambda v: writer_func(v, espec)) if is_container else writer_func - for v in vals: - yield write(v) - - def writeFieldByTType(self, ttype, val, spec): - self._write_by_ttype(ttype, [val], spec, spec).next() + return + elif ttype == TType.BOOL: + self.readBool() + elif ttype == TType.BYTE: + self.readByte() + elif ttype == TType.I16: + self.readI16() + elif ttype == TType.I32: + self.readI32() + elif ttype == TType.I64: + self.readI64() + elif ttype == TType.DOUBLE: + self.readDouble() + elif ttype == TType.STRING: + self.readString() + elif ttype == TType.STRUCT: + name = self.readStructBegin() + while True: + (name, ttype, id) = self.readFieldBegin() + if ttype == TType.STOP: + break + self.skip(ttype) + self.readFieldEnd() + self.readStructEnd() + elif ttype == TType.MAP: + (ktype, vtype, size) = self.readMapBegin() + for i in range(size): + self.skip(ktype) + self.skip(vtype) + self.readMapEnd() + elif ttype == TType.SET: + (etype, size) = self.readSetBegin() + for i in range(size): + self.skip(etype) + self.readSetEnd() + elif ttype == TType.LIST: + (etype, size) = self.readListBegin() + for i in range(size): + self.skip(etype) + self.readListEnd() + + # tuple of: ( 'reader method' name, is_container bool, 'writer_method' name ) + _TTYPE_HANDLERS = ( + (None, None, False), # 0 TType.STOP + (None, None, False), # 1 TType.VOID # TODO: handle void? + ('readBool', 'writeBool', False), # 2 TType.BOOL + ('readByte', 'writeByte', False), # 3 TType.BYTE and I08 + ('readDouble', 'writeDouble', False), # 4 TType.DOUBLE + (None, None, False), # 5 undefined + ('readI16', 'writeI16', False), # 6 TType.I16 + (None, None, False), # 7 undefined + ('readI32', 'writeI32', False), # 8 TType.I32 + (None, None, False), # 9 undefined + ('readI64', 'writeI64', False), # 10 TType.I64 + ('readString', 'writeString', False), # 11 TType.STRING and UTF7 + ('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT + ('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP + ('readContainerSet', 'writeContainerSet', True), # 14 TType.SET + ('readContainerList', 'writeContainerList', True), # 15 TType.LIST + (None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types? + (None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types? + ) + + def _ttype_handlers(self, ttype, spec): + if spec == 'BINARY': + if ttype != TType.STRING: + raise TProtocolException(type=TProtocolException.INVALID_DATA, + message='Invalid binary field type %d' % ttype) + return ('readBinary', 'writeBinary', False) + if sys.version_info[0] == 2 and spec == 'UTF8': + if ttype != TType.STRING: + raise TProtocolException(type=TProtocolException.INVALID_DATA, + message='Invalid string field type %d' % ttype) + return ('readUtf8', 'writeUtf8', False) + return self._TTYPE_HANDLERS[ttype] if ttype < len(self._TTYPE_HANDLERS) else (None, None, False) + + def _read_by_ttype(self, ttype, spec, espec): + reader_name, _, is_container = self._ttype_handlers(ttype, spec) + if reader_name is None: + raise TProtocolException(type=TProtocolException.INVALID_DATA, + message='Invalid type %d' % (ttype)) + reader_func = getattr(self, reader_name) + read = (lambda: reader_func(espec)) if is_container else reader_func + while True: + yield read() + + def readFieldByTType(self, ttype, spec): + return self._read_by_ttype(ttype, spec, spec).next() + + def readContainerList(self, spec): + ttype, tspec, is_immutable = spec + (list_type, list_len) = self.readListBegin() + # TODO: compare types we just decoded with thrift_spec + elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len) + results = (tuple if is_immutable else list)(elems) + self.readListEnd() + return results + + def readContainerSet(self, spec): + ttype, tspec, is_immutable = spec + (set_type, set_len) = self.readSetBegin() + # TODO: compare types we just decoded with thrift_spec + elems = islice(self._read_by_ttype(ttype, spec, tspec), set_len) + results = (frozenset if is_immutable else set)(elems) + self.readSetEnd() + return results + + def readContainerStruct(self, spec): + (obj_class, obj_spec) = spec + obj = obj_class() + obj.read(self) + return obj + + def readContainerMap(self, spec): + ktype, kspec, vtype, vspec, is_immutable = spec + (map_ktype, map_vtype, map_len) = self.readMapBegin() + # TODO: compare types we just decoded with thrift_spec and + # abort/skip if types disagree + keys = self._read_by_ttype(ktype, spec, kspec) + vals = self._read_by_ttype(vtype, spec, vspec) + keyvals = islice(zip(keys, vals), map_len) + results = (TFrozenDict if is_immutable else dict)(keyvals) + self.readMapEnd() + return results + + def readStruct(self, obj, thrift_spec, is_immutable=False): + if is_immutable: + fields = {} + self.readStructBegin() + while True: + (fname, ftype, fid) = self.readFieldBegin() + if ftype == TType.STOP: + break + try: + field = thrift_spec[fid] + except IndexError: + self.skip(ftype) + else: + if field is not None and ftype == field[1]: + fname = field[2] + fspec = field[3] + val = self.readFieldByTType(ftype, fspec) + if is_immutable: + fields[fname] = val + else: + setattr(obj, fname, val) + else: + self.skip(ftype) + self.readFieldEnd() + self.readStructEnd() + if is_immutable: + return obj(**fields) + + def writeContainerStruct(self, val, spec): + val.write(self) + + def writeContainerList(self, val, spec): + ttype, tspec, _ = spec + self.writeListBegin(ttype, len(val)) + for _ in self._write_by_ttype(ttype, val, spec, tspec): + pass + self.writeListEnd() + + def writeContainerSet(self, val, spec): + ttype, tspec, _ = spec + self.writeSetBegin(ttype, len(val)) + for _ in self._write_by_ttype(ttype, val, spec, tspec): + pass + self.writeSetEnd() + + def writeContainerMap(self, val, spec): + ktype, kspec, vtype, vspec, _ = spec + self.writeMapBegin(ktype, vtype, len(val)) + for _ in zip(self._write_by_ttype(ktype, six.iterkeys(val), spec, kspec), + self._write_by_ttype(vtype, six.itervalues(val), spec, vspec)): + pass + self.writeMapEnd() + + def writeStruct(self, obj, thrift_spec): + self.writeStructBegin(obj.__class__.__name__) + for field in thrift_spec: + if field is None: + continue + fname = field[2] + val = getattr(obj, fname) + if val is None: + # skip writing out unset fields + continue + fid = field[0] + ftype = field[1] + fspec = field[3] + self.writeFieldBegin(fname, ftype, fid) + self.writeFieldByTType(ftype, val, fspec) + self.writeFieldEnd() + self.writeFieldStop() + self.writeStructEnd() + + def _write_by_ttype(self, ttype, vals, spec, espec): + _, writer_name, is_container = self._ttype_handlers(ttype, spec) + writer_func = getattr(self, writer_name) + write = (lambda v: writer_func(v, espec)) if is_container else writer_func + for v in vals: + yield write(v) + + def writeFieldByTType(self, ttype, val, spec): + self._write_by_ttype(ttype, [val], spec, spec).next() def checkIntegerLimits(i, bits): @@ -408,10 +408,10 @@ def checkIntegerLimits(i, bits): raise TProtocolException(TProtocolException.INVALID_DATA, "i32 requires -2147483648 <= number <= 2147483647") elif bits == 64 and (i < -9223372036854775808 or i > 9223372036854775807): - raise TProtocolException(TProtocolException.INVALID_DATA, - "i64 requires -9223372036854775808 <= number <= 9223372036854775807") + raise TProtocolException(TProtocolException.INVALID_DATA, + "i64 requires -9223372036854775808 <= number <= 9223372036854775807") class TProtocolFactory(object): - def getProtocol(self, trans): - pass + def getProtocol(self, trans): + pass diff --git a/lib/py/src/protocol/TProtocolDecorator.py b/lib/py/src/protocol/TProtocolDecorator.py index bf50bfad8..8b270a466 100644 --- a/lib/py/src/protocol/TProtocolDecorator.py +++ b/lib/py/src/protocol/TProtocolDecorator.py @@ -17,26 +17,34 @@ # under the License. # +import types + from thrift.protocol.TProtocol import TProtocolBase -from types import * + class TProtocolDecorator(): - def __init__(self, protocol): - TProtocolBase(protocol) - self.protocol = protocol + def __init__(self, protocol): + TProtocolBase(protocol) + self.protocol = protocol - def __getattr__(self, name): - if hasattr(self.protocol, name): - member = getattr(self.protocol, name) - if type(member) in [MethodType, FunctionType, LambdaType, BuiltinFunctionType, BuiltinMethodType]: - return lambda *args, **kwargs: self._wrap(member, args, kwargs) - else: - return member - raise AttributeError(name) + def __getattr__(self, name): + if hasattr(self.protocol, name): + member = getattr(self.protocol, name) + if type(member) in [ + types.MethodType, + types.FunctionType, + types.LambdaType, + types.BuiltinFunctionType, + types.BuiltinMethodType, + ]: + return lambda *args, **kwargs: self._wrap(member, args, kwargs) + else: + return member + raise AttributeError(name) - def _wrap(self, func, args, kwargs): - if type(func) == MethodType: - result = func(*args, **kwargs) - else: - result = func(self.protocol, *args, **kwargs) - return result + def _wrap(self, func, args, kwargs): + if isinstance(func, types.MethodType): + result = func(*args, **kwargs) + else: + result = func(self.protocol, *args, **kwargs) + return result diff --git a/lib/py/src/protocol/__init__.py b/lib/py/src/protocol/__init__.py index 7eefb458a..7148f66b3 100644 --- a/lib/py/src/protocol/__init__.py +++ b/lib/py/src/protocol/__init__.py @@ -17,4 +17,5 @@ # under the License. # -__all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol', 'TJSONProtocol', 'TProtocol'] +__all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol', + 'TJSONProtocol', 'TProtocol'] diff --git a/lib/py/src/server/THttpServer.py b/lib/py/src/server/THttpServer.py index bf3b0e342..1b501a7aa 100644 --- a/lib/py/src/server/THttpServer.py +++ b/lib/py/src/server/THttpServer.py @@ -24,64 +24,64 @@ from thrift.transport import TTransport class ResponseException(Exception): - """Allows handlers to override the HTTP response + """Allows handlers to override the HTTP response - Normally, THttpServer always sends a 200 response. If a handler wants - to override this behavior (e.g., to simulate a misconfigured or - overloaded web server during testing), it can raise a ResponseException. - The function passed to the constructor will be called with the - RequestHandler as its only argument. - """ - def __init__(self, handler): - self.handler = handler + Normally, THttpServer always sends a 200 response. If a handler wants + to override this behavior (e.g., to simulate a misconfigured or + overloaded web server during testing), it can raise a ResponseException. + The function passed to the constructor will be called with the + RequestHandler as its only argument. + """ + def __init__(self, handler): + self.handler = handler class THttpServer(TServer.TServer): - """A simple HTTP-based Thrift server - - This class is not very performant, but it is useful (for example) for - acting as a mock version of an Apache-based PHP Thrift endpoint. - """ - def __init__(self, - processor, - server_address, - inputProtocolFactory, - outputProtocolFactory=None, - server_class=BaseHTTPServer.HTTPServer): - """Set up protocol factories and HTTP server. + """A simple HTTP-based Thrift server - See BaseHTTPServer for server_address. - See TServer for protocol factories. + This class is not very performant, but it is useful (for example) for + acting as a mock version of an Apache-based PHP Thrift endpoint. """ - if outputProtocolFactory is None: - outputProtocolFactory = inputProtocolFactory + def __init__(self, + processor, + server_address, + inputProtocolFactory, + outputProtocolFactory=None, + server_class=BaseHTTPServer.HTTPServer): + """Set up protocol factories and HTTP server. + + See BaseHTTPServer for server_address. + See TServer for protocol factories. + """ + if outputProtocolFactory is None: + outputProtocolFactory = inputProtocolFactory - TServer.TServer.__init__(self, processor, None, None, None, - inputProtocolFactory, outputProtocolFactory) + TServer.TServer.__init__(self, processor, None, None, None, + inputProtocolFactory, outputProtocolFactory) - thttpserver = self + thttpserver = self - class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler): - def do_POST(self): - # Don't care about the request path. - itrans = TTransport.TFileObjectTransport(self.rfile) - otrans = TTransport.TFileObjectTransport(self.wfile) - itrans = TTransport.TBufferedTransport( - itrans, int(self.headers['Content-Length'])) - otrans = TTransport.TMemoryBuffer() - iprot = thttpserver.inputProtocolFactory.getProtocol(itrans) - oprot = thttpserver.outputProtocolFactory.getProtocol(otrans) - try: - thttpserver.processor.process(iprot, oprot) - except ResponseException as exn: - exn.handler(self) - else: - self.send_response(200) - self.send_header("content-type", "application/x-thrift") - self.end_headers() - self.wfile.write(otrans.getvalue()) + class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler): + def do_POST(self): + # Don't care about the request path. + itrans = TTransport.TFileObjectTransport(self.rfile) + otrans = TTransport.TFileObjectTransport(self.wfile) + itrans = TTransport.TBufferedTransport( + itrans, int(self.headers['Content-Length'])) + otrans = TTransport.TMemoryBuffer() + iprot = thttpserver.inputProtocolFactory.getProtocol(itrans) + oprot = thttpserver.outputProtocolFactory.getProtocol(otrans) + try: + thttpserver.processor.process(iprot, oprot) + except ResponseException as exn: + exn.handler(self) + else: + self.send_response(200) + self.send_header("content-type", "application/x-thrift") + self.end_headers() + self.wfile.write(otrans.getvalue()) - self.httpd = server_class(server_address, RequestHander) + self.httpd = server_class(server_address, RequestHander) - def serve(self): - self.httpd.serve_forever() + def serve(self): + self.httpd.serve_forever() diff --git a/lib/py/src/server/TNonblockingServer.py b/lib/py/src/server/TNonblockingServer.py index a930a8091..87031c137 100644 --- a/lib/py/src/server/TNonblockingServer.py +++ b/lib/py/src/server/TNonblockingServer.py @@ -24,13 +24,12 @@ only from the main thread. The thread poool should be sized for concurrent tasks, not maximum connections """ -import threading -import socket -import select -import struct import logging -logger = logging.getLogger(__name__) +import select +import socket +import struct +import threading from six.moves import queue @@ -39,6 +38,8 @@ from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory __all__ = ['TNonblockingServer'] +logger = logging.getLogger(__name__) + class Worker(threading.Thread): """Worker is a small helper to process incoming connection.""" @@ -127,7 +128,7 @@ class Connection(object): self.len, = struct.unpack('!i', self.message) if self.len < 0: logger.error("negative frame size, it seems client " - "doesn't use FramedTransport") + "doesn't use FramedTransport") self.close() elif self.len == 0: logger.error("empty frame, it's really strange") @@ -149,7 +150,7 @@ class Connection(object): read = self.socket.recv(self.len - len(self.message)) if len(read) == 0: logger.error("can't read frame from socket (get %d of " - "%d bytes)" % (len(self.message), self.len)) + "%d bytes)" % (len(self.message), self.len)) self.close() return self.message += read diff --git a/lib/py/src/server/TProcessPoolServer.py b/lib/py/src/server/TProcessPoolServer.py index b2c2308a9..fe6dc8162 100644 --- a/lib/py/src/server/TProcessPoolServer.py +++ b/lib/py/src/server/TProcessPoolServer.py @@ -19,13 +19,14 @@ import logging -logger = logging.getLogger(__name__) -from multiprocessing import Process, Value, Condition, reduction +from multiprocessing import Process, Value, Condition from .TServer import TServer from thrift.transport.TTransport import TTransportException +logger = logging.getLogger(__name__) + class TProcessPoolServer(TServer): """Server with a fixed size pool of worker subprocesses to service requests @@ -59,7 +60,7 @@ class TProcessPoolServer(TServer): try: client = self.serverTransport.accept() if not client: - continue + continue self.serveClient(client) except (KeyboardInterrupt, SystemExit): return 0 @@ -76,7 +77,7 @@ class TProcessPoolServer(TServer): try: while True: self.processor.process(iprot, oprot) - except TTransportException as tx: + except TTransportException: pass except Exception as x: logger.exception(x) diff --git a/lib/py/src/server/TServer.py b/lib/py/src/server/TServer.py index 30f063b43..d5d9c98a9 100644 --- a/lib/py/src/server/TServer.py +++ b/lib/py/src/server/TServer.py @@ -18,262 +18,259 @@ # from six.moves import queue +import logging import os -import sys import threading -import traceback - -import logging -logger = logging.getLogger(__name__) -from thrift.Thrift import TProcessor from thrift.protocol import TBinaryProtocol from thrift.transport import TTransport +logger = logging.getLogger(__name__) + class TServer(object): - """Base interface for a server, which must have a serve() method. - - Three constructors for all servers: - 1) (processor, serverTransport) - 2) (processor, serverTransport, transportFactory, protocolFactory) - 3) (processor, serverTransport, - inputTransportFactory, outputTransportFactory, - inputProtocolFactory, outputProtocolFactory) - """ - def __init__(self, *args): - if (len(args) == 2): - self.__initArgs__(args[0], args[1], - TTransport.TTransportFactoryBase(), - TTransport.TTransportFactoryBase(), - TBinaryProtocol.TBinaryProtocolFactory(), - TBinaryProtocol.TBinaryProtocolFactory()) - elif (len(args) == 4): - self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3]) - elif (len(args) == 6): - self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5]) - - def __initArgs__(self, processor, serverTransport, - inputTransportFactory, outputTransportFactory, - inputProtocolFactory, outputProtocolFactory): - self.processor = processor - self.serverTransport = serverTransport - self.inputTransportFactory = inputTransportFactory - self.outputTransportFactory = outputTransportFactory - self.inputProtocolFactory = inputProtocolFactory - self.outputProtocolFactory = outputProtocolFactory - - def serve(self): - pass + """Base interface for a server, which must have a serve() method. + + Three constructors for all servers: + 1) (processor, serverTransport) + 2) (processor, serverTransport, transportFactory, protocolFactory) + 3) (processor, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory) + """ + def __init__(self, *args): + if (len(args) == 2): + self.__initArgs__(args[0], args[1], + TTransport.TTransportFactoryBase(), + TTransport.TTransportFactoryBase(), + TBinaryProtocol.TBinaryProtocolFactory(), + TBinaryProtocol.TBinaryProtocolFactory()) + elif (len(args) == 4): + self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3]) + elif (len(args) == 6): + self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5]) + + def __initArgs__(self, processor, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory): + self.processor = processor + self.serverTransport = serverTransport + self.inputTransportFactory = inputTransportFactory + self.outputTransportFactory = outputTransportFactory + self.inputProtocolFactory = inputProtocolFactory + self.outputProtocolFactory = outputProtocolFactory + + def serve(self): + pass class TSimpleServer(TServer): - """Simple single-threaded server that just pumps around one transport.""" - - def __init__(self, *args): - TServer.__init__(self, *args) - - def serve(self): - self.serverTransport.listen() - while True: - client = self.serverTransport.accept() - if not client: - continue - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = self.outputProtocolFactory.getProtocol(otrans) - try: + """Simple single-threaded server that just pumps around one transport.""" + + def __init__(self, *args): + TServer.__init__(self, *args) + + def serve(self): + self.serverTransport.listen() while True: - self.processor.process(iprot, oprot) - except TTransport.TTransportException as tx: - pass - except Exception as x: - logger.exception(x) + client = self.serverTransport.accept() + if not client: + continue + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException: + pass + except Exception as x: + logger.exception(x) - itrans.close() - otrans.close() + itrans.close() + otrans.close() class TThreadedServer(TServer): - """Threaded server that spawns a new thread per each connection.""" - - def __init__(self, *args, **kwargs): - TServer.__init__(self, *args) - self.daemon = kwargs.get("daemon", False) - - def serve(self): - self.serverTransport.listen() - while True: - try: - client = self.serverTransport.accept() - if not client: - continue - t = threading.Thread(target=self.handle, args=(client,)) - t.setDaemon(self.daemon) - t.start() - except KeyboardInterrupt: - raise - except Exception as x: - logger.exception(x) - - def handle(self, client): - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = self.outputProtocolFactory.getProtocol(otrans) - try: - while True: - self.processor.process(iprot, oprot) - except TTransport.TTransportException as tx: - pass - except Exception as x: - logger.exception(x) - - itrans.close() - otrans.close() + """Threaded server that spawns a new thread per each connection.""" + + def __init__(self, *args, **kwargs): + TServer.__init__(self, *args) + self.daemon = kwargs.get("daemon", False) + + def serve(self): + self.serverTransport.listen() + while True: + try: + client = self.serverTransport.accept() + if not client: + continue + t = threading.Thread(target=self.handle, args=(client,)) + t.setDaemon(self.daemon) + t.start() + except KeyboardInterrupt: + raise + except Exception as x: + logger.exception(x) + + def handle(self, client): + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException: + pass + except Exception as x: + logger.exception(x) + + itrans.close() + otrans.close() class TThreadPoolServer(TServer): - """Server with a fixed size pool of threads which service requests.""" - - def __init__(self, *args, **kwargs): - TServer.__init__(self, *args) - self.clients = queue.Queue() - self.threads = 10 - self.daemon = kwargs.get("daemon", False) - - def setNumThreads(self, num): - """Set the number of worker threads that should be created""" - self.threads = num - - def serveThread(self): - """Loop around getting clients from the shared queue and process them.""" - while True: - try: - client = self.clients.get() - self.serveClient(client) - except Exception as x: - logger.exception(x) - - def serveClient(self, client): - """Process input/output from a client for as long as possible""" - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = self.outputProtocolFactory.getProtocol(otrans) - try: - while True: - self.processor.process(iprot, oprot) - except TTransport.TTransportException as tx: - pass - except Exception as x: - logger.exception(x) - - itrans.close() - otrans.close() - - def serve(self): - """Start a fixed number of worker threads and put client into a queue""" - for i in range(self.threads): - try: - t = threading.Thread(target=self.serveThread) - t.setDaemon(self.daemon) - t.start() - except Exception as x: - logger.exception(x) - - # Pump the socket for clients - self.serverTransport.listen() - while True: - try: - client = self.serverTransport.accept() - if not client: - continue - self.clients.put(client) - except Exception as x: - logger.exception(x) + """Server with a fixed size pool of threads which service requests.""" + def __init__(self, *args, **kwargs): + TServer.__init__(self, *args) + self.clients = queue.Queue() + self.threads = 10 + self.daemon = kwargs.get("daemon", False) -class TForkingServer(TServer): - """A Thrift server that forks a new process for each request - - This is more scalable than the threaded server as it does not cause - GIL contention. - - Note that this has different semantics from the threading server. - Specifically, updates to shared variables will no longer be shared. - It will also not work on windows. - - This code is heavily inspired by SocketServer.ForkingMixIn in the - Python stdlib. - """ - def __init__(self, *args): - TServer.__init__(self, *args) - self.children = [] - - def serve(self): - def try_close(file): - try: - file.close() - except IOError as e: - logger.warning(e, exc_info=True) - - self.serverTransport.listen() - while True: - client = self.serverTransport.accept() - if not client: - continue - try: - pid = os.fork() - - if pid: # parent - # add before collect, otherwise you race w/ waitpid - self.children.append(pid) - self.collect_children() - - # Parent must close socket or the connection may not get - # closed promptly - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - try_close(itrans) - try_close(otrans) - else: - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = self.outputProtocolFactory.getProtocol(otrans) - - ecode = 0 - try: + def setNumThreads(self, num): + """Set the number of worker threads that should be created""" + self.threads = num + + def serveThread(self): + """Loop around getting clients from the shared queue and process them.""" + while True: try: - while True: + client = self.clients.get() + self.serveClient(client) + except Exception as x: + logger.exception(x) + + def serveClient(self, client): + """Process input/output from a client for as long as possible""" + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + try: + while True: self.processor.process(iprot, oprot) - except TTransport.TTransportException: - pass - except Exception as e: - logger.exception(e) - ecode = 1 - finally: - try_close(itrans) - try_close(otrans) + except TTransport.TTransportException: + pass + except Exception as x: + logger.exception(x) - os._exit(ecode) + itrans.close() + otrans.close() - except TTransport.TTransportException: - pass - except Exception as x: - logger.exception(x) - - def collect_children(self): - while self.children: - try: - pid, status = os.waitpid(0, os.WNOHANG) - except os.error: - pid = None - - if pid: - self.children.remove(pid) - else: - break + def serve(self): + """Start a fixed number of worker threads and put client into a queue""" + for i in range(self.threads): + try: + t = threading.Thread(target=self.serveThread) + t.setDaemon(self.daemon) + t.start() + except Exception as x: + logger.exception(x) + + # Pump the socket for clients + self.serverTransport.listen() + while True: + try: + client = self.serverTransport.accept() + if not client: + continue + self.clients.put(client) + except Exception as x: + logger.exception(x) + + +class TForkingServer(TServer): + """A Thrift server that forks a new process for each request + + This is more scalable than the threaded server as it does not cause + GIL contention. + + Note that this has different semantics from the threading server. + Specifically, updates to shared variables will no longer be shared. + It will also not work on windows. + + This code is heavily inspired by SocketServer.ForkingMixIn in the + Python stdlib. + """ + def __init__(self, *args): + TServer.__init__(self, *args) + self.children = [] + + def serve(self): + def try_close(file): + try: + file.close() + except IOError as e: + logger.warning(e, exc_info=True) + + self.serverTransport.listen() + while True: + client = self.serverTransport.accept() + if not client: + continue + try: + pid = os.fork() + + if pid: # parent + # add before collect, otherwise you race w/ waitpid + self.children.append(pid) + self.collect_children() + + # Parent must close socket or the connection may not get + # closed promptly + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + try_close(itrans) + try_close(otrans) + else: + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + + ecode = 0 + try: + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException: + pass + except Exception as e: + logger.exception(e) + ecode = 1 + finally: + try_close(itrans) + try_close(otrans) + + os._exit(ecode) + + except TTransport.TTransportException: + pass + except Exception as x: + logger.exception(x) + + def collect_children(self): + while self.children: + try: + pid, status = os.waitpid(0, os.WNOHANG) + except os.error: + pid = None + + if pid: + self.children.remove(pid) + else: + break diff --git a/lib/py/src/transport/THttpClient.py b/lib/py/src/transport/THttpClient.py index 5abd41c70..95f118cb4 100644 --- a/lib/py/src/transport/THttpClient.py +++ b/lib/py/src/transport/THttpClient.py @@ -26,130 +26,130 @@ import warnings from six.moves import urllib from six.moves import http_client -from .TTransport import * +from .TTransport import TTransportBase import six class THttpClient(TTransportBase): - """Http implementation of TTransport base.""" - - def __init__(self, uri_or_host, port=None, path=None): - """THttpClient supports two different types constructor parameters. - - THttpClient(host, port, path) - deprecated - THttpClient(uri) - - Only the second supports https. - """ - if port is not None: - warnings.warn( - "Please use the THttpClient('http://host:port/path') syntax", - DeprecationWarning, - stacklevel=2) - self.host = uri_or_host - self.port = port - assert path - self.path = path - self.scheme = 'http' - else: - parsed = urllib.parse.urlparse(uri_or_host) - self.scheme = parsed.scheme - assert self.scheme in ('http', 'https') - if self.scheme == 'http': - self.port = parsed.port or http_client.HTTP_PORT - elif self.scheme == 'https': - self.port = parsed.port or http_client.HTTPS_PORT - self.host = parsed.hostname - self.path = parsed.path - if parsed.query: - self.path += '?%s' % parsed.query - self.__wbuf = BytesIO() - self.__http = None - self.__http_response = None - self.__timeout = None - self.__custom_headers = None - - def open(self): - if self.scheme == 'http': - self.__http = http_client.HTTPConnection(self.host, self.port) - else: - self.__http = http_client.HTTPSConnection(self.host, self.port) - - def close(self): - self.__http.close() - self.__http = None - self.__http_response = None - - def isOpen(self): - return self.__http is not None - - def setTimeout(self, ms): - if not hasattr(socket, 'getdefaulttimeout'): - raise NotImplementedError - - if ms is None: - self.__timeout = None - else: - self.__timeout = ms / 1000.0 - - def setCustomHeaders(self, headers): - self.__custom_headers = headers - - def read(self, sz): - return self.__http_response.read(sz) - - def write(self, buf): - self.__wbuf.write(buf) - - def __withTimeout(f): - def _f(*args, **kwargs): - orig_timeout = socket.getdefaulttimeout() - socket.setdefaulttimeout(args[0].__timeout) - try: - result = f(*args, **kwargs) - finally: - socket.setdefaulttimeout(orig_timeout) - return result - return _f - - def flush(self): - if self.isOpen(): - self.close() - self.open() - - # Pull data out of buffer - data = self.__wbuf.getvalue() - self.__wbuf = BytesIO() - - # HTTP request - self.__http.putrequest('POST', self.path) - - # Write headers - self.__http.putheader('Content-Type', 'application/x-thrift') - self.__http.putheader('Content-Length', str(len(data))) - - if not self.__custom_headers or 'User-Agent' not in self.__custom_headers: - user_agent = 'Python/THttpClient' - script = os.path.basename(sys.argv[0]) - if script: - user_agent = '%s (%s)' % (user_agent, urllib.parse.quote(script)) - self.__http.putheader('User-Agent', user_agent) - - if self.__custom_headers: - for key, val in six.iteritems(self.__custom_headers): - self.__http.putheader(key, val) - - self.__http.endheaders() - - # Write payload - self.__http.send(data) - - # Get reply to flush the request - self.__http_response = self.__http.getresponse() - self.code = self.__http_response.status - self.message = self.__http_response.reason - self.headers = self.__http_response.msg - - # Decorate if we know how to timeout - if hasattr(socket, 'getdefaulttimeout'): - flush = __withTimeout(flush) + """Http implementation of TTransport base.""" + + def __init__(self, uri_or_host, port=None, path=None): + """THttpClient supports two different types constructor parameters. + + THttpClient(host, port, path) - deprecated + THttpClient(uri) + + Only the second supports https. + """ + if port is not None: + warnings.warn( + "Please use the THttpClient('http://host:port/path') syntax", + DeprecationWarning, + stacklevel=2) + self.host = uri_or_host + self.port = port + assert path + self.path = path + self.scheme = 'http' + else: + parsed = urllib.parse.urlparse(uri_or_host) + self.scheme = parsed.scheme + assert self.scheme in ('http', 'https') + if self.scheme == 'http': + self.port = parsed.port or http_client.HTTP_PORT + elif self.scheme == 'https': + self.port = parsed.port or http_client.HTTPS_PORT + self.host = parsed.hostname + self.path = parsed.path + if parsed.query: + self.path += '?%s' % parsed.query + self.__wbuf = BytesIO() + self.__http = None + self.__http_response = None + self.__timeout = None + self.__custom_headers = None + + def open(self): + if self.scheme == 'http': + self.__http = http_client.HTTPConnection(self.host, self.port) + else: + self.__http = http_client.HTTPSConnection(self.host, self.port) + + def close(self): + self.__http.close() + self.__http = None + self.__http_response = None + + def isOpen(self): + return self.__http is not None + + def setTimeout(self, ms): + if not hasattr(socket, 'getdefaulttimeout'): + raise NotImplementedError + + if ms is None: + self.__timeout = None + else: + self.__timeout = ms / 1000.0 + + def setCustomHeaders(self, headers): + self.__custom_headers = headers + + def read(self, sz): + return self.__http_response.read(sz) + + def write(self, buf): + self.__wbuf.write(buf) + + def __withTimeout(f): + def _f(*args, **kwargs): + orig_timeout = socket.getdefaulttimeout() + socket.setdefaulttimeout(args[0].__timeout) + try: + result = f(*args, **kwargs) + finally: + socket.setdefaulttimeout(orig_timeout) + return result + return _f + + def flush(self): + if self.isOpen(): + self.close() + self.open() + + # Pull data out of buffer + data = self.__wbuf.getvalue() + self.__wbuf = BytesIO() + + # HTTP request + self.__http.putrequest('POST', self.path) + + # Write headers + self.__http.putheader('Content-Type', 'application/x-thrift') + self.__http.putheader('Content-Length', str(len(data))) + + if not self.__custom_headers or 'User-Agent' not in self.__custom_headers: + user_agent = 'Python/THttpClient' + script = os.path.basename(sys.argv[0]) + if script: + user_agent = '%s (%s)' % (user_agent, urllib.parse.quote(script)) + self.__http.putheader('User-Agent', user_agent) + + if self.__custom_headers: + for key, val in six.iteritems(self.__custom_headers): + self.__http.putheader(key, val) + + self.__http.endheaders() + + # Write payload + self.__http.send(data) + + # Get reply to flush the request + self.__http_response = self.__http.getresponse() + self.code = self.__http_response.status + self.message = self.__http_response.reason + self.headers = self.__http_response.msg + + # Decorate if we know how to timeout + if hasattr(socket, 'getdefaulttimeout'): + flush = __withTimeout(flush) diff --git a/lib/py/src/transport/TSSLSocket.py b/lib/py/src/transport/TSSLSocket.py index 9be0912f9..3f1a909df 100644 --- a/lib/py/src/transport/TSSLSocket.py +++ b/lib/py/src/transport/TSSLSocket.py @@ -32,345 +32,345 @@ warnings.filterwarnings('default', category=DeprecationWarning, module=__name__) class TSSLBase(object): - # SSLContext is not available for Python < 2.7.9 - _has_ssl_context = sys.hexversion >= 0x020709F0 - - # ciphers argument is not available for Python < 2.7.0 - _has_ciphers = sys.hexversion >= 0x020700F0 - - # For pythoon >= 2.7.9, use latest TLS that both client and server supports. - # SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3. - # For pythoon < 2.7.9, use TLS 1.0 since TLSv1_X nare OP_NO_SSLvX are unavailable. - _default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else ssl.PROTOCOL_TLSv1 - - def _init_context(self, ssl_version): - if self._has_ssl_context: - self._context = ssl.SSLContext(ssl_version) - if self._context.protocol == ssl.PROTOCOL_SSLv23: - self._context.options |= ssl.OP_NO_SSLv2 - self._context.options |= ssl.OP_NO_SSLv3 - else: - self._context = None - self._ssl_version = ssl_version - - @property - def ssl_version(self): - if self._has_ssl_context: - return self.ssl_context.protocol - else: - return self._ssl_version - - @property - def ssl_context(self): - return self._context - - SSL_VERSION = _default_protocol - """ + # SSLContext is not available for Python < 2.7.9 + _has_ssl_context = sys.hexversion >= 0x020709F0 + + # ciphers argument is not available for Python < 2.7.0 + _has_ciphers = sys.hexversion >= 0x020700F0 + + # For pythoon >= 2.7.9, use latest TLS that both client and server supports. + # SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3. + # For pythoon < 2.7.9, use TLS 1.0 since TLSv1_X nare OP_NO_SSLvX are unavailable. + _default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else ssl.PROTOCOL_TLSv1 + + def _init_context(self, ssl_version): + if self._has_ssl_context: + self._context = ssl.SSLContext(ssl_version) + if self._context.protocol == ssl.PROTOCOL_SSLv23: + self._context.options |= ssl.OP_NO_SSLv2 + self._context.options |= ssl.OP_NO_SSLv3 + else: + self._context = None + self._ssl_version = ssl_version + + @property + def ssl_version(self): + if self._has_ssl_context: + return self.ssl_context.protocol + else: + return self._ssl_version + + @property + def ssl_context(self): + return self._context + + SSL_VERSION = _default_protocol + """ Default SSL version. For backword compatibility, it can be modified. Use __init__ keywoard argument "ssl_version" instead. """ - def _deprecated_arg(self, args, kwargs, pos, key): - if len(args) <= pos: - return - real_pos = pos + 3 - warnings.warn( - '%dth positional argument is deprecated. Use keyward argument insteand.' % real_pos, - DeprecationWarning) - if key in kwargs: - raise TypeError('Duplicate argument: %dth argument and %s keyward argument.', (real_pos, key)) - kwargs[key] = args[pos] - - def _unix_socket_arg(self, host, port, args, kwargs): - key = 'unix_socket' - if host is None and port is None and len(args) == 1 and key not in kwargs: - kwargs[key] = args[0] - return True - return False - - def __getattr__(self, key): - if key == 'SSL_VERSION': - warnings.warn('Use ssl_version attribute instead.', DeprecationWarning) - return self.ssl_version - - def __init__(self, server_side, host, ssl_opts): - self._server_side = server_side - if TSSLBase.SSL_VERSION != self._default_protocol: - warnings.warn('SSL_VERSION is deprecated. Use ssl_version keyward argument instead.', DeprecationWarning) - self._context = ssl_opts.pop('ssl_context', None) - self._server_hostname = None - if not self._server_side: - self._server_hostname = ssl_opts.pop('server_hostname', host) - if self._context: - self._custom_context = True - if ssl_opts: - raise ValueError('Incompatible arguments: ssl_context and %s' % ' '.join(ssl_opts.keys())) - if not self._has_ssl_context: - raise ValueError('ssl_context is not available for this version of Python') - else: - self._custom_context = False - ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION) - self._init_context(ssl_version) - self.cert_reqs = ssl_opts.pop('cert_reqs', ssl.CERT_REQUIRED) - self.ca_certs = ssl_opts.pop('ca_certs', None) - self.keyfile = ssl_opts.pop('keyfile', None) - self.certfile = ssl_opts.pop('certfile', None) - self.ciphers = ssl_opts.pop('ciphers', None) - - if ssl_opts: - raise ValueError('Unknown keyword arguments: ', ' '.join(ssl_opts.keys())) - - if self.cert_reqs != ssl.CERT_NONE: - if not self.ca_certs: - raise ValueError('ca_certs is needed when cert_reqs is not ssl.CERT_NONE') - if not os.access(self.ca_certs, os.R_OK): - raise IOError('Certificate Authority ca_certs file "%s" ' - 'is not readable, cannot validate SSL ' - 'certificates.' % (self.ca_certs)) - - @property - def certfile(self): - return self._certfile - - @certfile.setter - def certfile(self, certfile): - if self._server_side and not certfile: - raise ValueError('certfile is needed for server-side') - if certfile and not os.access(certfile, os.R_OK): - raise IOError('No such certfile found: %s' % (certfile)) - self._certfile = certfile - - def _wrap_socket(self, sock): - if self._has_ssl_context: - if not self._custom_context: - self.ssl_context.verify_mode = self.cert_reqs - if self.certfile: - self.ssl_context.load_cert_chain(self.certfile, self.keyfile) - if self.ciphers: - self.ssl_context.set_ciphers(self.ciphers) - if self.ca_certs: - self.ssl_context.load_verify_locations(self.ca_certs) - return self.ssl_context.wrap_socket(sock, server_side=self._server_side, - server_hostname=self._server_hostname) - else: - ssl_opts = { - 'ssl_version': self._ssl_version, - 'server_side': self._server_side, - 'ca_certs': self.ca_certs, - 'keyfile': self.keyfile, - 'certfile': self.certfile, - 'cert_reqs': self.cert_reqs, - } - if self.ciphers: - if self._has_ciphers: - ssl_opts['ciphers'] = self.ciphers + def _deprecated_arg(self, args, kwargs, pos, key): + if len(args) <= pos: + return + real_pos = pos + 3 + warnings.warn( + '%dth positional argument is deprecated. Use keyward argument insteand.' % real_pos, + DeprecationWarning) + if key in kwargs: + raise TypeError('Duplicate argument: %dth argument and %s keyward argument.', (real_pos, key)) + kwargs[key] = args[pos] + + def _unix_socket_arg(self, host, port, args, kwargs): + key = 'unix_socket' + if host is None and port is None and len(args) == 1 and key not in kwargs: + kwargs[key] = args[0] + return True + return False + + def __getattr__(self, key): + if key == 'SSL_VERSION': + warnings.warn('Use ssl_version attribute instead.', DeprecationWarning) + return self.ssl_version + + def __init__(self, server_side, host, ssl_opts): + self._server_side = server_side + if TSSLBase.SSL_VERSION != self._default_protocol: + warnings.warn('SSL_VERSION is deprecated. Use ssl_version keyward argument instead.', DeprecationWarning) + self._context = ssl_opts.pop('ssl_context', None) + self._server_hostname = None + if not self._server_side: + self._server_hostname = ssl_opts.pop('server_hostname', host) + if self._context: + self._custom_context = True + if ssl_opts: + raise ValueError('Incompatible arguments: ssl_context and %s' % ' '.join(ssl_opts.keys())) + if not self._has_ssl_context: + raise ValueError('ssl_context is not available for this version of Python') + else: + self._custom_context = False + ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION) + self._init_context(ssl_version) + self.cert_reqs = ssl_opts.pop('cert_reqs', ssl.CERT_REQUIRED) + self.ca_certs = ssl_opts.pop('ca_certs', None) + self.keyfile = ssl_opts.pop('keyfile', None) + self.certfile = ssl_opts.pop('certfile', None) + self.ciphers = ssl_opts.pop('ciphers', None) + + if ssl_opts: + raise ValueError('Unknown keyword arguments: ', ' '.join(ssl_opts.keys())) + + if self.cert_reqs != ssl.CERT_NONE: + if not self.ca_certs: + raise ValueError('ca_certs is needed when cert_reqs is not ssl.CERT_NONE') + if not os.access(self.ca_certs, os.R_OK): + raise IOError('Certificate Authority ca_certs file "%s" ' + 'is not readable, cannot validate SSL ' + 'certificates.' % (self.ca_certs)) + + @property + def certfile(self): + return self._certfile + + @certfile.setter + def certfile(self, certfile): + if self._server_side and not certfile: + raise ValueError('certfile is needed for server-side') + if certfile and not os.access(certfile, os.R_OK): + raise IOError('No such certfile found: %s' % (certfile)) + self._certfile = certfile + + def _wrap_socket(self, sock): + if self._has_ssl_context: + if not self._custom_context: + self.ssl_context.verify_mode = self.cert_reqs + if self.certfile: + self.ssl_context.load_cert_chain(self.certfile, self.keyfile) + if self.ciphers: + self.ssl_context.set_ciphers(self.ciphers) + if self.ca_certs: + self.ssl_context.load_verify_locations(self.ca_certs) + return self.ssl_context.wrap_socket(sock, server_side=self._server_side, + server_hostname=self._server_hostname) else: - logger.warning('ciphers is specified but ignored due to old Python version') - return ssl.wrap_socket(sock, **ssl_opts) + ssl_opts = { + 'ssl_version': self._ssl_version, + 'server_side': self._server_side, + 'ca_certs': self.ca_certs, + 'keyfile': self.keyfile, + 'certfile': self.certfile, + 'cert_reqs': self.cert_reqs, + } + if self.ciphers: + if self._has_ciphers: + ssl_opts['ciphers'] = self.ciphers + else: + logger.warning('ciphers is specified but ignored due to old Python version') + return ssl.wrap_socket(sock, **ssl_opts) class TSSLSocket(TSocket.TSocket, TSSLBase): - """ - SSL implementation of TSocket - - This class creates outbound sockets wrapped using the - python standard ssl module for encrypted connections. - """ + """ + SSL implementation of TSocket - # New signature - # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args): - # Deprecated signature - # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None): - def __init__(self, host='localhost', port=9090, *args, **kwargs): - """Positional arguments: ``host``, ``port``, ``unix_socket`` - - Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``, - ``ca_certs``, ``ciphers`` (Python 2.7.0 or later), - ``server_hostname`` (Python 2.7.9 or later) - Passed to ssl.wrap_socket. See ssl.wrap_socket documentation. - - Alternative keywoard arguments: (Python 2.7.9 or later) - ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket - ``server_hostname``: Passed to SSLContext.wrap_socket + This class creates outbound sockets wrapped using the + python standard ssl module for encrypted connections. """ - self.is_valid = False - self.peercert = None - - if args: - if len(args) > 6: - raise TypeError('Too many positional argument') - if not self._unix_socket_arg(host, port, args, kwargs): - self._deprecated_arg(args, kwargs, 0, 'validate') - self._deprecated_arg(args, kwargs, 1, 'ca_certs') - self._deprecated_arg(args, kwargs, 2, 'keyfile') - self._deprecated_arg(args, kwargs, 3, 'certfile') - self._deprecated_arg(args, kwargs, 4, 'unix_socket') - self._deprecated_arg(args, kwargs, 5, 'ciphers') - - validate = kwargs.pop('validate', None) - if validate is not None: - cert_reqs_name = 'CERT_REQUIRED' if validate else 'CERT_NONE' - warnings.warn( - 'validate is deprecated. Use cert_reqs=ssl.%s instead' % cert_reqs_name, - DeprecationWarning) - if 'cert_reqs' in kwargs: - raise TypeError('Cannot specify both validate and cert_reqs') - kwargs['cert_reqs'] = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE - - unix_socket = kwargs.pop('unix_socket', None) - TSSLBase.__init__(self, False, host, kwargs) - TSocket.TSocket.__init__(self, host, port, unix_socket) - - @property - def validate(self): - warnings.warn('Use cert_reqs instead', DeprecationWarning) - return self.cert_reqs != ssl.CERT_NONE - - @validate.setter - def validate(self, value): - warnings.warn('Use cert_reqs instead', DeprecationWarning) - self.cert_reqs = ssl.CERT_REQUIRED if value else ssl.CERT_NONE - - def open(self): - try: - res0 = self._resolveAddr() - for res in res0: - sock_family, sock_type = res[0:2] - ip_port = res[4] - plain_sock = socket.socket(sock_family, sock_type) - self.handle = self._wrap_socket(plain_sock) - self.handle.settimeout(self._timeout) + + # New signature + # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args): + # Deprecated signature + # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None): + def __init__(self, host='localhost', port=9090, *args, **kwargs): + """Positional arguments: ``host``, ``port``, ``unix_socket`` + + Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``, + ``ca_certs``, ``ciphers`` (Python 2.7.0 or later), + ``server_hostname`` (Python 2.7.9 or later) + Passed to ssl.wrap_socket. See ssl.wrap_socket documentation. + + Alternative keywoard arguments: (Python 2.7.9 or later) + ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket + ``server_hostname``: Passed to SSLContext.wrap_socket + """ + self.is_valid = False + self.peercert = None + + if args: + if len(args) > 6: + raise TypeError('Too many positional argument') + if not self._unix_socket_arg(host, port, args, kwargs): + self._deprecated_arg(args, kwargs, 0, 'validate') + self._deprecated_arg(args, kwargs, 1, 'ca_certs') + self._deprecated_arg(args, kwargs, 2, 'keyfile') + self._deprecated_arg(args, kwargs, 3, 'certfile') + self._deprecated_arg(args, kwargs, 4, 'unix_socket') + self._deprecated_arg(args, kwargs, 5, 'ciphers') + + validate = kwargs.pop('validate', None) + if validate is not None: + cert_reqs_name = 'CERT_REQUIRED' if validate else 'CERT_NONE' + warnings.warn( + 'validate is deprecated. Use cert_reqs=ssl.%s instead' % cert_reqs_name, + DeprecationWarning) + if 'cert_reqs' in kwargs: + raise TypeError('Cannot specify both validate and cert_reqs') + kwargs['cert_reqs'] = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE + + unix_socket = kwargs.pop('unix_socket', None) + TSSLBase.__init__(self, False, host, kwargs) + TSocket.TSocket.__init__(self, host, port, unix_socket) + + @property + def validate(self): + warnings.warn('Use cert_reqs instead', DeprecationWarning) + return self.cert_reqs != ssl.CERT_NONE + + @validate.setter + def validate(self, value): + warnings.warn('Use cert_reqs instead', DeprecationWarning) + self.cert_reqs = ssl.CERT_REQUIRED if value else ssl.CERT_NONE + + def open(self): try: - self.handle.connect(ip_port) + res0 = self._resolveAddr() + for res in res0: + sock_family, sock_type = res[0:2] + ip_port = res[4] + plain_sock = socket.socket(sock_family, sock_type) + self.handle = self._wrap_socket(plain_sock) + self.handle.settimeout(self._timeout) + try: + self.handle.connect(ip_port) + except socket.error as e: + if res is not res0[-1]: + logger.warning('Error while connecting with %s. Trying next one.', ip_port, exc_info=True) + continue + else: + raise + break except socket.error as e: - if res is not res0[-1]: - logger.warning('Error while connecting with %s. Trying next one.', ip_port, exc_info=True) - continue - else: - raise - break - except socket.error as e: - if self._unix_socket: - message = 'Could not connect to secure socket %s: %s' \ - % (self._unix_socket, e) - else: - message = 'Could not connect to %s:%d: %s' % (self.host, self.port, e) - logger.error('Error while connecting with %s.', ip_port, exc_info=True) - raise TTransportException(type=TTransportException.NOT_OPEN, - message=message) - if self.validate: - self._validate_cert() - - def _validate_cert(self): - """internal method to validate the peer's SSL certificate, and to check the - commonName of the certificate to ensure it matches the hostname we - used to make this connection. Does not support subjectAltName records - in certificates. - - raises TTransportException if the certificate fails validation. - """ - cert = self.handle.getpeercert() - self.peercert = cert - if 'subject' not in cert: - raise TTransportException( - type=TTransportException.NOT_OPEN, - message='No SSL certificate found from %s:%s' % (self.host, self.port)) - fields = cert['subject'] - for field in fields: - # ensure structure we get back is what we expect - if not isinstance(field, tuple): - continue - cert_pair = field[0] - if len(cert_pair) < 2: - continue - cert_key, cert_value = cert_pair[0:2] - if cert_key != 'commonName': - continue - certhost = cert_value - # this check should be performed by some sort of Access Manager - if certhost == self.host: - # success, cert commonName matches desired hostname - self.is_valid = True - return - else: + if self._unix_socket: + message = 'Could not connect to secure socket %s: %s' \ + % (self._unix_socket, e) + else: + message = 'Could not connect to %s:%d: %s' % (self.host, self.port, e) + logger.error('Error while connecting with %s.', ip_port, exc_info=True) + raise TTransportException(type=TTransportException.NOT_OPEN, + message=message) + if self.validate: + self._validate_cert() + + def _validate_cert(self): + """internal method to validate the peer's SSL certificate, and to check the + commonName of the certificate to ensure it matches the hostname we + used to make this connection. Does not support subjectAltName records + in certificates. + + raises TTransportException if the certificate fails validation. + """ + cert = self.handle.getpeercert() + self.peercert = cert + if 'subject' not in cert: + raise TTransportException( + type=TTransportException.NOT_OPEN, + message='No SSL certificate found from %s:%s' % (self.host, self.port)) + fields = cert['subject'] + for field in fields: + # ensure structure we get back is what we expect + if not isinstance(field, tuple): + continue + cert_pair = field[0] + if len(cert_pair) < 2: + continue + cert_key, cert_value = cert_pair[0:2] + if cert_key != 'commonName': + continue + certhost = cert_value + # this check should be performed by some sort of Access Manager + if certhost == self.host: + # success, cert commonName matches desired hostname + self.is_valid = True + return + else: + raise TTransportException( + type=TTransportException.UNKNOWN, + message='Hostname we connected to "%s" doesn\'t match certificate ' + 'provided commonName "%s"' % (self.host, certhost)) raise TTransportException( - type=TTransportException.UNKNOWN, - message='Hostname we connected to "%s" doesn\'t match certificate ' - 'provided commonName "%s"' % (self.host, certhost)) - raise TTransportException( - type=TTransportException.UNKNOWN, - message='Could not validate SSL certificate from ' - 'host "%s". Cert=%s' % (self.host, cert)) + type=TTransportException.UNKNOWN, + message='Could not validate SSL certificate from ' + 'host "%s". Cert=%s' % (self.host, cert)) class TSSLServerSocket(TSocket.TServerSocket, TSSLBase): - """SSL implementation of TServerSocket - - This uses the ssl module's wrap_socket() method to provide SSL - negotiated encryption. - """ + """SSL implementation of TServerSocket - # New signature - # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args): - # Deprecated signature - # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None): - def __init__(self, host=None, port=9090, *args, **kwargs): - """Positional arguments: ``host``, ``port``, ``unix_socket`` - - Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``, - ``ca_certs``, ``ciphers`` (Python 2.7.0 or later) - See ssl.wrap_socket documentation. - - Alternative keywoard arguments: (Python 2.7.9 or later) - ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket - ``server_hostname``: Passed to SSLContext.wrap_socket - """ - if args: - if len(args) > 3: - raise TypeError('Too many positional argument') - if not self._unix_socket_arg(host, port, args, kwargs): - self._deprecated_arg(args, kwargs, 0, 'certfile') - self._deprecated_arg(args, kwargs, 1, 'unix_socket') - self._deprecated_arg(args, kwargs, 2, 'ciphers') - - if 'ssl_context' not in kwargs: - # Preserve existing behaviors for default values - if 'cert_reqs' not in kwargs: - kwargs['cert_reqs'] = ssl.CERT_NONE - if'certfile' not in kwargs: - kwargs['certfile'] = 'cert.pem' - - unix_socket = kwargs.pop('unix_socket', None) - TSSLBase.__init__(self, True, None, kwargs) - TSocket.TServerSocket.__init__(self, host, port, unix_socket) - - def setCertfile(self, certfile): - """Set or change the server certificate file used to wrap new connections. - - @param certfile: The filename of the server certificate, - i.e. '/etc/certs/server.pem' - @type certfile: str - - Raises an IOError exception if the certfile is not present or unreadable. + This uses the ssl module's wrap_socket() method to provide SSL + negotiated encryption. """ - warnings.warn('Use certfile property instead.', DeprecationWarning) - self.certfile = certfile - - def accept(self): - plain_client, addr = self.handle.accept() - try: - client = self._wrap_socket(plain_client) - except ssl.SSLError: - logger.error('Error while accepting from %s', addr, exc_info=True) - # failed handshake/ssl wrap, close socket to client - plain_client.close() - # raise - # We can't raise the exception, because it kills most TServer derived - # serve() methods. - # Instead, return None, and let the TServer instance deal with it in - # other exception handling. (but TSimpleServer dies anyway) - return None - result = TSocket.TSocket() - result.setHandle(client) - return result + + # New signature + # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args): + # Deprecated signature + # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None): + def __init__(self, host=None, port=9090, *args, **kwargs): + """Positional arguments: ``host``, ``port``, ``unix_socket`` + + Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``, + ``ca_certs``, ``ciphers`` (Python 2.7.0 or later) + See ssl.wrap_socket documentation. + + Alternative keywoard arguments: (Python 2.7.9 or later) + ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket + ``server_hostname``: Passed to SSLContext.wrap_socket + """ + if args: + if len(args) > 3: + raise TypeError('Too many positional argument') + if not self._unix_socket_arg(host, port, args, kwargs): + self._deprecated_arg(args, kwargs, 0, 'certfile') + self._deprecated_arg(args, kwargs, 1, 'unix_socket') + self._deprecated_arg(args, kwargs, 2, 'ciphers') + + if 'ssl_context' not in kwargs: + # Preserve existing behaviors for default values + if 'cert_reqs' not in kwargs: + kwargs['cert_reqs'] = ssl.CERT_NONE + if'certfile' not in kwargs: + kwargs['certfile'] = 'cert.pem' + + unix_socket = kwargs.pop('unix_socket', None) + TSSLBase.__init__(self, True, None, kwargs) + TSocket.TServerSocket.__init__(self, host, port, unix_socket) + + def setCertfile(self, certfile): + """Set or change the server certificate file used to wrap new connections. + + @param certfile: The filename of the server certificate, + i.e. '/etc/certs/server.pem' + @type certfile: str + + Raises an IOError exception if the certfile is not present or unreadable. + """ + warnings.warn('Use certfile property instead.', DeprecationWarning) + self.certfile = certfile + + def accept(self): + plain_client, addr = self.handle.accept() + try: + client = self._wrap_socket(plain_client) + except ssl.SSLError: + logger.error('Error while accepting from %s', addr, exc_info=True) + # failed handshake/ssl wrap, close socket to client + plain_client.close() + # raise + # We can't raise the exception, because it kills most TServer derived + # serve() methods. + # Instead, return None, and let the TServer instance deal with it in + # other exception handling. (but TSimpleServer dies anyway) + return None + result = TSocket.TSocket() + result.setHandle(client) + return result diff --git a/lib/py/src/transport/TSocket.py b/lib/py/src/transport/TSocket.py index cb204a4a0..a8ed4b7dc 100644 --- a/lib/py/src/transport/TSocket.py +++ b/lib/py/src/transport/TSocket.py @@ -22,159 +22,159 @@ import os import socket import sys -from .TTransport import * +from .TTransport import TTransportBase, TTransportException, TServerTransportBase class TSocketBase(TTransportBase): - def _resolveAddr(self): - if self._unix_socket is not None: - return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, - self._unix_socket)] - else: - return socket.getaddrinfo(self.host, - self.port, - self._socket_family, - socket.SOCK_STREAM, - 0, - socket.AI_PASSIVE | socket.AI_ADDRCONFIG) - - def close(self): - if self.handle: - self.handle.close() - self.handle = None + def _resolveAddr(self): + if self._unix_socket is not None: + return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, + self._unix_socket)] + else: + return socket.getaddrinfo(self.host, + self.port, + self._socket_family, + socket.SOCK_STREAM, + 0, + socket.AI_PASSIVE | socket.AI_ADDRCONFIG) + + def close(self): + if self.handle: + self.handle.close() + self.handle = None class TSocket(TSocketBase): - """Socket implementation of TTransport base.""" - - def __init__(self, host='localhost', port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC): - """Initialize a TSocket - - @param host(str) The host to connect to. - @param port(int) The (TCP) port to connect to. - @param unix_socket(str) The filename of a unix socket to connect to. - (host and port will be ignored.) - @param socket_family(int) The socket family to use with this socket. - """ - self.host = host - self.port = port - self.handle = None - self._unix_socket = unix_socket - self._timeout = None - self._socket_family = socket_family - - def setHandle(self, h): - self.handle = h - - def isOpen(self): - return self.handle is not None - - def setTimeout(self, ms): - if ms is None: - self._timeout = None - else: - self._timeout = ms / 1000.0 - - if self.handle is not None: - self.handle.settimeout(self._timeout) - - def open(self): - try: - res0 = self._resolveAddr() - for res in res0: - self.handle = socket.socket(res[0], res[1]) - self.handle.settimeout(self._timeout) + """Socket implementation of TTransport base.""" + + def __init__(self, host='localhost', port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC): + """Initialize a TSocket + + @param host(str) The host to connect to. + @param port(int) The (TCP) port to connect to. + @param unix_socket(str) The filename of a unix socket to connect to. + (host and port will be ignored.) + @param socket_family(int) The socket family to use with this socket. + """ + self.host = host + self.port = port + self.handle = None + self._unix_socket = unix_socket + self._timeout = None + self._socket_family = socket_family + + def setHandle(self, h): + self.handle = h + + def isOpen(self): + return self.handle is not None + + def setTimeout(self, ms): + if ms is None: + self._timeout = None + else: + self._timeout = ms / 1000.0 + + if self.handle is not None: + self.handle.settimeout(self._timeout) + + def open(self): + try: + res0 = self._resolveAddr() + for res in res0: + self.handle = socket.socket(res[0], res[1]) + self.handle.settimeout(self._timeout) + try: + self.handle.connect(res[4]) + except socket.error as e: + if res is not res0[-1]: + continue + else: + raise e + break + except socket.error as e: + if self._unix_socket: + message = 'Could not connect to socket %s' % self._unix_socket + else: + message = 'Could not connect to %s:%d' % (self.host, self.port) + raise TTransportException(type=TTransportException.NOT_OPEN, + message=message) + + def read(self, sz): try: - self.handle.connect(res[4]) + buff = self.handle.recv(sz) except socket.error as e: - if res is not res0[-1]: - continue - else: - raise e - break - except socket.error as e: - if self._unix_socket: - message = 'Could not connect to socket %s' % self._unix_socket - else: - message = 'Could not connect to %s:%d' % (self.host, self.port) - raise TTransportException(type=TTransportException.NOT_OPEN, - message=message) - - def read(self, sz): - try: - buff = self.handle.recv(sz) - except socket.error as e: - if (e.args[0] == errno.ECONNRESET and - (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))): - # freebsd and Mach don't follow POSIX semantic of recv - # and fail with ECONNRESET if peer performed shutdown. - # See corresponding comment and code in TSocket::read() - # in lib/cpp/src/transport/TSocket.cpp. - self.close() - # Trigger the check to raise the END_OF_FILE exception below. - buff = '' - else: - raise - if len(buff) == 0: - raise TTransportException(type=TTransportException.END_OF_FILE, - message='TSocket read 0 bytes') - return buff - - def write(self, buff): - if not self.handle: - raise TTransportException(type=TTransportException.NOT_OPEN, - message='Transport not open') - sent = 0 - have = len(buff) - while sent < have: - plus = self.handle.send(buff) - if plus == 0: - raise TTransportException(type=TTransportException.END_OF_FILE, - message='TSocket sent 0 bytes') - sent += plus - buff = buff[plus:] - - def flush(self): - pass + if (e.args[0] == errno.ECONNRESET and + (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))): + # freebsd and Mach don't follow POSIX semantic of recv + # and fail with ECONNRESET if peer performed shutdown. + # See corresponding comment and code in TSocket::read() + # in lib/cpp/src/transport/TSocket.cpp. + self.close() + # Trigger the check to raise the END_OF_FILE exception below. + buff = '' + else: + raise + if len(buff) == 0: + raise TTransportException(type=TTransportException.END_OF_FILE, + message='TSocket read 0 bytes') + return buff + + def write(self, buff): + if not self.handle: + raise TTransportException(type=TTransportException.NOT_OPEN, + message='Transport not open') + sent = 0 + have = len(buff) + while sent < have: + plus = self.handle.send(buff) + if plus == 0: + raise TTransportException(type=TTransportException.END_OF_FILE, + message='TSocket sent 0 bytes') + sent += plus + buff = buff[plus:] + + def flush(self): + pass class TServerSocket(TSocketBase, TServerTransportBase): - """Socket implementation of TServerTransport base.""" - - def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC): - self.host = host - self.port = port - self._unix_socket = unix_socket - self._socket_family = socket_family - self.handle = None - - def listen(self): - res0 = self._resolveAddr() - socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family - for res in res0: - if res[0] is socket_family or res is res0[-1]: - break - - # We need remove the old unix socket if the file exists and - # nobody is listening on it. - if self._unix_socket: - tmp = socket.socket(res[0], res[1]) - try: - tmp.connect(res[4]) - except socket.error as err: - eno, message = err.args - if eno == errno.ECONNREFUSED: - os.unlink(res[4]) - - self.handle = socket.socket(res[0], res[1]) - self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - if hasattr(self.handle, 'settimeout'): - self.handle.settimeout(None) - self.handle.bind(res[4]) - self.handle.listen(128) - - def accept(self): - client, addr = self.handle.accept() - result = TSocket() - result.setHandle(client) - return result + """Socket implementation of TServerTransport base.""" + + def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC): + self.host = host + self.port = port + self._unix_socket = unix_socket + self._socket_family = socket_family + self.handle = None + + def listen(self): + res0 = self._resolveAddr() + socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family + for res in res0: + if res[0] is socket_family or res is res0[-1]: + break + + # We need remove the old unix socket if the file exists and + # nobody is listening on it. + if self._unix_socket: + tmp = socket.socket(res[0], res[1]) + try: + tmp.connect(res[4]) + except socket.error as err: + eno, message = err.args + if eno == errno.ECONNREFUSED: + os.unlink(res[4]) + + self.handle = socket.socket(res[0], res[1]) + self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if hasattr(self.handle, 'settimeout'): + self.handle.settimeout(None) + self.handle.bind(res[4]) + self.handle.listen(128) + + def accept(self): + client, addr = self.handle.accept() + result = TSocket() + result.setHandle(client) + return result diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py index f99b3b9ba..6669891cd 100644 --- a/lib/py/src/transport/TTransport.py +++ b/lib/py/src/transport/TTransport.py @@ -23,427 +23,426 @@ from ..compat import BufferIO class TTransportException(TException): - """Custom Transport Exception class""" + """Custom Transport Exception class""" - UNKNOWN = 0 - NOT_OPEN = 1 - ALREADY_OPEN = 2 - TIMED_OUT = 3 - END_OF_FILE = 4 - NEGATIVE_SIZE = 5 - SIZE_LIMIT = 6 + UNKNOWN = 0 + NOT_OPEN = 1 + ALREADY_OPEN = 2 + TIMED_OUT = 3 + END_OF_FILE = 4 + NEGATIVE_SIZE = 5 + SIZE_LIMIT = 6 - def __init__(self, type=UNKNOWN, message=None): - TException.__init__(self, message) - self.type = type + def __init__(self, type=UNKNOWN, message=None): + TException.__init__(self, message) + self.type = type class TTransportBase(object): - """Base class for Thrift transport layer.""" + """Base class for Thrift transport layer.""" - def isOpen(self): - pass + def isOpen(self): + pass - def open(self): - pass + def open(self): + pass - def close(self): - pass + def close(self): + pass - def read(self, sz): - pass + def read(self, sz): + pass - def readAll(self, sz): - buff = b'' - have = 0 - while (have < sz): - chunk = self.read(sz - have) - have += len(chunk) - buff += chunk + def readAll(self, sz): + buff = b'' + have = 0 + while (have < sz): + chunk = self.read(sz - have) + have += len(chunk) + buff += chunk - if len(chunk) == 0: - raise EOFError() + if len(chunk) == 0: + raise EOFError() - return buff + return buff - def write(self, buf): - pass + def write(self, buf): + pass - def flush(self): - pass + def flush(self): + pass # This class should be thought of as an interface. class CReadableTransport(object): - """base class for transports that are readable from C""" + """base class for transports that are readable from C""" - # TODO(dreiss): Think about changing this interface to allow us to use - # a (Python, not c) StringIO instead, because it allows - # you to write after reading. + # TODO(dreiss): Think about changing this interface to allow us to use + # a (Python, not c) StringIO instead, because it allows + # you to write after reading. - # NOTE: This is a classic class, so properties will NOT work - # correctly for setting. - @property - def cstringio_buf(self): - """A cStringIO buffer that contains the current chunk we are reading.""" - pass + # NOTE: This is a classic class, so properties will NOT work + # correctly for setting. + @property + def cstringio_buf(self): + """A cStringIO buffer that contains the current chunk we are reading.""" + pass - def cstringio_refill(self, partialread, reqlen): - """Refills cstringio_buf. + def cstringio_refill(self, partialread, reqlen): + """Refills cstringio_buf. - Returns the currently used buffer (which can but need not be the same as - the old cstringio_buf). partialread is what the C code has read from the - buffer, and should be inserted into the buffer before any more reads. The - return value must be a new, not borrowed reference. Something along the - lines of self._buf should be fine. + Returns the currently used buffer (which can but need not be the same as + the old cstringio_buf). partialread is what the C code has read from the + buffer, and should be inserted into the buffer before any more reads. The + return value must be a new, not borrowed reference. Something along the + lines of self._buf should be fine. - If reqlen bytes can't be read, throw EOFError. - """ - pass + If reqlen bytes can't be read, throw EOFError. + """ + pass class TServerTransportBase(object): - """Base class for Thrift server transports.""" + """Base class for Thrift server transports.""" - def listen(self): - pass + def listen(self): + pass - def accept(self): - pass + def accept(self): + pass - def close(self): - pass + def close(self): + pass class TTransportFactoryBase(object): - """Base class for a Transport Factory""" + """Base class for a Transport Factory""" - def getTransport(self, trans): - return trans + def getTransport(self, trans): + return trans class TBufferedTransportFactory(object): - """Factory transport that builds buffered transports""" + """Factory transport that builds buffered transports""" - def getTransport(self, trans): - buffered = TBufferedTransport(trans) - return buffered + def getTransport(self, trans): + buffered = TBufferedTransport(trans) + return buffered class TBufferedTransport(TTransportBase, CReadableTransport): - """Class that wraps another transport and buffers its I/O. - - The implementation uses a (configurable) fixed-size read buffer - but buffers all writes until a flush is performed. - """ - DEFAULT_BUFFER = 4096 - - def __init__(self, trans, rbuf_size=DEFAULT_BUFFER): - self.__trans = trans - self.__wbuf = BufferIO() - # Pass string argument to initialize read buffer as cStringIO.InputType - self.__rbuf = BufferIO(b'') - self.__rbuf_size = rbuf_size - - def isOpen(self): - return self.__trans.isOpen() - - def open(self): - return self.__trans.open() - - def close(self): - return self.__trans.close() - - def read(self, sz): - ret = self.__rbuf.read(sz) - if len(ret) != 0: - return ret - self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size))) - return self.__rbuf.read(sz) - - def write(self, buf): - try: - self.__wbuf.write(buf) - except Exception as e: - # on exception reset wbuf so it doesn't contain a partial function call - self.__wbuf = BufferIO() - raise e - self.__wbuf.getvalue() - - def flush(self): - out = self.__wbuf.getvalue() - # reset wbuf before write/flush to preserve state on underlying failure - self.__wbuf = BufferIO() - self.__trans.write(out) - self.__trans.flush() - - # Implement the CReadableTransport interface. - @property - def cstringio_buf(self): - return self.__rbuf - - def cstringio_refill(self, partialread, reqlen): - retstring = partialread - if reqlen < self.__rbuf_size: - # try to make a read of as much as we can. - retstring += self.__trans.read(self.__rbuf_size) - - # but make sure we do read reqlen bytes. - if len(retstring) < reqlen: - retstring += self.__trans.readAll(reqlen - len(retstring)) - - self.__rbuf = BufferIO(retstring) - return self.__rbuf + """Class that wraps another transport and buffers its I/O. + + The implementation uses a (configurable) fixed-size read buffer + but buffers all writes until a flush is performed. + """ + DEFAULT_BUFFER = 4096 + + def __init__(self, trans, rbuf_size=DEFAULT_BUFFER): + self.__trans = trans + self.__wbuf = BufferIO() + # Pass string argument to initialize read buffer as cStringIO.InputType + self.__rbuf = BufferIO(b'') + self.__rbuf_size = rbuf_size + + def isOpen(self): + return self.__trans.isOpen() + + def open(self): + return self.__trans.open() + + def close(self): + return self.__trans.close() + + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) != 0: + return ret + self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size))) + return self.__rbuf.read(sz) + + def write(self, buf): + try: + self.__wbuf.write(buf) + except Exception as e: + # on exception reset wbuf so it doesn't contain a partial function call + self.__wbuf = BufferIO() + raise e + self.__wbuf.getvalue() + + def flush(self): + out = self.__wbuf.getvalue() + # reset wbuf before write/flush to preserve state on underlying failure + self.__wbuf = BufferIO() + self.__trans.write(out) + self.__trans.flush() + + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self.__rbuf + + def cstringio_refill(self, partialread, reqlen): + retstring = partialread + if reqlen < self.__rbuf_size: + # try to make a read of as much as we can. + retstring += self.__trans.read(self.__rbuf_size) + + # but make sure we do read reqlen bytes. + if len(retstring) < reqlen: + retstring += self.__trans.readAll(reqlen - len(retstring)) + + self.__rbuf = BufferIO(retstring) + return self.__rbuf class TMemoryBuffer(TTransportBase, CReadableTransport): - """Wraps a cBytesIO object as a TTransport. + """Wraps a cBytesIO object as a TTransport. - NOTE: Unlike the C++ version of this class, you cannot write to it - then immediately read from it. If you want to read from a - TMemoryBuffer, you must either pass a string to the constructor. - TODO(dreiss): Make this work like the C++ version. - """ + NOTE: Unlike the C++ version of this class, you cannot write to it + then immediately read from it. If you want to read from a + TMemoryBuffer, you must either pass a string to the constructor. + TODO(dreiss): Make this work like the C++ version. + """ - def __init__(self, value=None): - """value -- a value to read from for stringio + def __init__(self, value=None): + """value -- a value to read from for stringio - If value is set, this will be a transport for reading, - otherwise, it is for writing""" - if value is not None: - self._buffer = BufferIO(value) - else: - self._buffer = BufferIO() + If value is set, this will be a transport for reading, + otherwise, it is for writing""" + if value is not None: + self._buffer = BufferIO(value) + else: + self._buffer = BufferIO() - def isOpen(self): - return not self._buffer.closed + def isOpen(self): + return not self._buffer.closed - def open(self): - pass + def open(self): + pass - def close(self): - self._buffer.close() + def close(self): + self._buffer.close() - def read(self, sz): - return self._buffer.read(sz) + def read(self, sz): + return self._buffer.read(sz) - def write(self, buf): - self._buffer.write(buf) + def write(self, buf): + self._buffer.write(buf) - def flush(self): - pass + def flush(self): + pass - def getvalue(self): - return self._buffer.getvalue() + def getvalue(self): + return self._buffer.getvalue() - # Implement the CReadableTransport interface. - @property - def cstringio_buf(self): - return self._buffer + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self._buffer - def cstringio_refill(self, partialread, reqlen): - # only one shot at reading... - raise EOFError() + def cstringio_refill(self, partialread, reqlen): + # only one shot at reading... + raise EOFError() class TFramedTransportFactory(object): - """Factory transport that builds framed transports""" + """Factory transport that builds framed transports""" - def getTransport(self, trans): - framed = TFramedTransport(trans) - return framed + def getTransport(self, trans): + framed = TFramedTransport(trans) + return framed class TFramedTransport(TTransportBase, CReadableTransport): - """Class that wraps another transport and frames its I/O when writing.""" - - def __init__(self, trans,): - self.__trans = trans - self.__rbuf = BufferIO(b'') - self.__wbuf = BufferIO() - - def isOpen(self): - return self.__trans.isOpen() - - def open(self): - return self.__trans.open() - - def close(self): - return self.__trans.close() - - def read(self, sz): - ret = self.__rbuf.read(sz) - if len(ret) != 0: - return ret - - self.readFrame() - return self.__rbuf.read(sz) - - def readFrame(self): - buff = self.__trans.readAll(4) - sz, = unpack('!i', buff) - self.__rbuf = BufferIO(self.__trans.readAll(sz)) - - def write(self, buf): - self.__wbuf.write(buf) - - def flush(self): - wout = self.__wbuf.getvalue() - wsz = len(wout) - # reset wbuf before write/flush to preserve state on underlying failure - self.__wbuf = BufferIO() - # N.B.: Doing this string concatenation is WAY cheaper than making - # two separate calls to the underlying socket object. Socket writes in - # Python turn out to be REALLY expensive, but it seems to do a pretty - # good job of managing string buffer operations without excessive copies - buf = pack("!i", wsz) + wout - self.__trans.write(buf) - self.__trans.flush() - - # Implement the CReadableTransport interface. - @property - def cstringio_buf(self): - return self.__rbuf - - def cstringio_refill(self, prefix, reqlen): - # self.__rbuf will already be empty here because fastbinary doesn't - # ask for a refill until the previous buffer is empty. Therefore, - # we can start reading new frames immediately. - while len(prefix) < reqlen: - self.readFrame() - prefix += self.__rbuf.getvalue() - self.__rbuf = BufferIO(prefix) - return self.__rbuf + """Class that wraps another transport and frames its I/O when writing.""" + + def __init__(self, trans,): + self.__trans = trans + self.__rbuf = BufferIO(b'') + self.__wbuf = BufferIO() + + def isOpen(self): + return self.__trans.isOpen() + + def open(self): + return self.__trans.open() + + def close(self): + return self.__trans.close() + + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) != 0: + return ret + + self.readFrame() + return self.__rbuf.read(sz) + + def readFrame(self): + buff = self.__trans.readAll(4) + sz, = unpack('!i', buff) + self.__rbuf = BufferIO(self.__trans.readAll(sz)) + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self): + wout = self.__wbuf.getvalue() + wsz = len(wout) + # reset wbuf before write/flush to preserve state on underlying failure + self.__wbuf = BufferIO() + # N.B.: Doing this string concatenation is WAY cheaper than making + # two separate calls to the underlying socket object. Socket writes in + # Python turn out to be REALLY expensive, but it seems to do a pretty + # good job of managing string buffer operations without excessive copies + buf = pack("!i", wsz) + wout + self.__trans.write(buf) + self.__trans.flush() + + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self.__rbuf + + def cstringio_refill(self, prefix, reqlen): + # self.__rbuf will already be empty here because fastbinary doesn't + # ask for a refill until the previous buffer is empty. Therefore, + # we can start reading new frames immediately. + while len(prefix) < reqlen: + self.readFrame() + prefix += self.__rbuf.getvalue() + self.__rbuf = BufferIO(prefix) + return self.__rbuf class TFileObjectTransport(TTransportBase): - """Wraps a file-like object to make it work as a Thrift transport.""" + """Wraps a file-like object to make it work as a Thrift transport.""" - def __init__(self, fileobj): - self.fileobj = fileobj + def __init__(self, fileobj): + self.fileobj = fileobj - def isOpen(self): - return True + def isOpen(self): + return True - def close(self): - self.fileobj.close() + def close(self): + self.fileobj.close() - def read(self, sz): - return self.fileobj.read(sz) + def read(self, sz): + return self.fileobj.read(sz) - def write(self, buf): - self.fileobj.write(buf) + def write(self, buf): + self.fileobj.write(buf) - def flush(self): - self.fileobj.flush() + def flush(self): + self.fileobj.flush() class TSaslClientTransport(TTransportBase, CReadableTransport): - """ - SASL transport - """ - - START = 1 - OK = 2 - BAD = 3 - ERROR = 4 - COMPLETE = 5 - - def __init__(self, transport, host, service, mechanism='GSSAPI', - **sasl_kwargs): """ - transport: an underlying transport to use, typically just a TSocket - host: the name of the server, from a SASL perspective - service: the name of the server's service, from a SASL perspective - mechanism: the name of the preferred mechanism to use - - All other kwargs will be passed to the puresasl.client.SASLClient - constructor. + SASL transport """ - from puresasl.client import SASLClient - - self.transport = transport - self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs) - - self.__wbuf = BufferIO() - self.__rbuf = BufferIO(b'') - - def open(self): - if not self.transport.isOpen(): - self.transport.open() - - self.send_sasl_msg(self.START, self.sasl.mechanism) - self.send_sasl_msg(self.OK, self.sasl.process()) - - while True: - status, challenge = self.recv_sasl_msg() - if status == self.OK: - self.send_sasl_msg(self.OK, self.sasl.process(challenge)) - elif status == self.COMPLETE: - if not self.sasl.complete: - raise TTransportException("The server erroneously indicated " - "that SASL negotiation was complete") + START = 1 + OK = 2 + BAD = 3 + ERROR = 4 + COMPLETE = 5 + + def __init__(self, transport, host, service, mechanism='GSSAPI', + **sasl_kwargs): + """ + transport: an underlying transport to use, typically just a TSocket + host: the name of the server, from a SASL perspective + service: the name of the server's service, from a SASL perspective + mechanism: the name of the preferred mechanism to use + + All other kwargs will be passed to the puresasl.client.SASLClient + constructor. + """ + + from puresasl.client import SASLClient + + self.transport = transport + self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs) + + self.__wbuf = BufferIO() + self.__rbuf = BufferIO(b'') + + def open(self): + if not self.transport.isOpen(): + self.transport.open() + + self.send_sasl_msg(self.START, self.sasl.mechanism) + self.send_sasl_msg(self.OK, self.sasl.process()) + + while True: + status, challenge = self.recv_sasl_msg() + if status == self.OK: + self.send_sasl_msg(self.OK, self.sasl.process(challenge)) + elif status == self.COMPLETE: + if not self.sasl.complete: + raise TTransportException("The server erroneously indicated " + "that SASL negotiation was complete") + else: + break + else: + raise TTransportException("Bad SASL negotiation status: %d (%s)" + % (status, challenge)) + + def send_sasl_msg(self, status, body): + header = pack(">BI", status, len(body)) + self.transport.write(header + body) + self.transport.flush() + + def recv_sasl_msg(self): + header = self.transport.readAll(5) + status, length = unpack(">BI", header) + if length > 0: + payload = self.transport.readAll(length) else: - break - else: - raise TTransportException("Bad SASL negotiation status: %d (%s)" - % (status, challenge)) - - def send_sasl_msg(self, status, body): - header = pack(">BI", status, len(body)) - self.transport.write(header + body) - self.transport.flush() - - def recv_sasl_msg(self): - header = self.transport.readAll(5) - status, length = unpack(">BI", header) - if length > 0: - payload = self.transport.readAll(length) - else: - payload = "" - return status, payload - - def write(self, data): - self.__wbuf.write(data) - - def flush(self): - data = self.__wbuf.getvalue() - encoded = self.sasl.wrap(data) - self.transport.write(''.join((pack("!i", len(encoded)), encoded))) - self.transport.flush() - self.__wbuf = BufferIO() - - def read(self, sz): - ret = self.__rbuf.read(sz) - if len(ret) != 0: - return ret - - self._read_frame() - return self.__rbuf.read(sz) - - def _read_frame(self): - header = self.transport.readAll(4) - length, = unpack('!i', header) - encoded = self.transport.readAll(length) - self.__rbuf = BufferIO(self.sasl.unwrap(encoded)) - - def close(self): - self.sasl.dispose() - self.transport.close() - - # based on TFramedTransport - @property - def cstringio_buf(self): - return self.__rbuf - - def cstringio_refill(self, prefix, reqlen): - # self.__rbuf will already be empty here because fastbinary doesn't - # ask for a refill until the previous buffer is empty. Therefore, - # we can start reading new frames immediately. - while len(prefix) < reqlen: - self._read_frame() - prefix += self.__rbuf.getvalue() - self.__rbuf = BufferIO(prefix) - return self.__rbuf - + payload = "" + return status, payload + + def write(self, data): + self.__wbuf.write(data) + + def flush(self): + data = self.__wbuf.getvalue() + encoded = self.sasl.wrap(data) + self.transport.write(''.join((pack("!i", len(encoded)), encoded))) + self.transport.flush() + self.__wbuf = BufferIO() + + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) != 0: + return ret + + self._read_frame() + return self.__rbuf.read(sz) + + def _read_frame(self): + header = self.transport.readAll(4) + length, = unpack('!i', header) + encoded = self.transport.readAll(length) + self.__rbuf = BufferIO(self.sasl.unwrap(encoded)) + + def close(self): + self.sasl.dispose() + self.transport.close() + + # based on TFramedTransport + @property + def cstringio_buf(self): + return self.__rbuf + + def cstringio_refill(self, prefix, reqlen): + # self.__rbuf will already be empty here because fastbinary doesn't + # ask for a refill until the previous buffer is empty. Therefore, + # we can start reading new frames immediately. + while len(prefix) < reqlen: + self._read_frame() + prefix += self.__rbuf.getvalue() + self.__rbuf = BufferIO(prefix) + return self.__rbuf diff --git a/lib/py/src/transport/TTwisted.py b/lib/py/src/transport/TTwisted.py index 6149a6c8e..5710b573d 100644 --- a/lib/py/src/transport/TTwisted.py +++ b/lib/py/src/transport/TTwisted.py @@ -120,7 +120,7 @@ class ThriftSASLClientProtocol(ThriftClientProtocol): MAX_LENGTH = 2 ** 31 - 1 def __init__(self, client_class, iprot_factory, oprot_factory=None, - host=None, service=None, mechanism='GSSAPI', **sasl_kwargs): + host=None, service=None, mechanism='GSSAPI', **sasl_kwargs): """ host: the name of the server, from a SASL perspective service: the name of the server's service, from a SASL perspective @@ -236,7 +236,7 @@ class ThriftServerProtocol(basic.Int32StringReceiver): d = self.factory.processor.process(iprot, oprot) d.addCallbacks(self.processOk, self.processError, - callbackArgs=(tmo,)) + callbackArgs=(tmo,)) class IThriftServerFactory(Interface): @@ -288,7 +288,7 @@ class ThriftClientFactory(ClientFactory): def buildProtocol(self, addr): p = self.protocol(self.client_class, self.iprot_factory, - self.oprot_factory) + self.oprot_factory) p.factory = self return p @@ -298,7 +298,7 @@ class ThriftResource(resource.Resource): allowedMethods = ('POST',) def __init__(self, processor, inputProtocolFactory, - outputProtocolFactory=None): + outputProtocolFactory=None): resource.Resource.__init__(self) self.inputProtocolFactory = inputProtocolFactory if outputProtocolFactory is None: diff --git a/lib/py/src/transport/TZlibTransport.py b/lib/py/src/transport/TZlibTransport.py index 7fe5853ee..e84857924 100644 --- a/lib/py/src/transport/TZlibTransport.py +++ b/lib/py/src/transport/TZlibTransport.py @@ -29,220 +29,220 @@ from ..compat import BufferIO class TZlibTransportFactory(object): - """Factory transport that builds zlib compressed transports. - - This factory caches the last single client/transport that it was passed - and returns the same TZlibTransport object that was created. - - This caching means the TServer class will get the _same_ transport - object for both input and output transports from this factory. - (For non-threaded scenarios only, since the cache only holds one object) - - The purpose of this caching is to allocate only one TZlibTransport where - only one is really needed (since it must have separate read/write buffers), - and makes the statistics from getCompSavings() and getCompRatio() - easier to understand. - """ - # class scoped cache of last transport given and zlibtransport returned - _last_trans = None - _last_z = None - - def getTransport(self, trans, compresslevel=9): - """Wrap a transport, trans, with the TZlibTransport - compressed transport class, returning a new - transport to the caller. - - @param compresslevel: The zlib compression level, ranging - from 0 (no compression) to 9 (best compression). Defaults to 9. - @type compresslevel: int - - This method returns a TZlibTransport which wraps the - passed C{trans} TTransport derived instance. - """ - if trans == self._last_trans: - return self._last_z - ztrans = TZlibTransport(trans, compresslevel) - self._last_trans = trans - self._last_z = ztrans - return ztrans + """Factory transport that builds zlib compressed transports. + This factory caches the last single client/transport that it was passed + and returns the same TZlibTransport object that was created. -class TZlibTransport(TTransportBase, CReadableTransport): - """Class that wraps a transport with zlib, compressing writes - and decompresses reads, using the python standard - library zlib module. - """ - # Read buffer size for the python fastbinary C extension, - # the TBinaryProtocolAccelerated class. - DEFAULT_BUFFSIZE = 4096 - - def __init__(self, trans, compresslevel=9): - """Create a new TZlibTransport, wrapping C{trans}, another - TTransport derived object. - - @param trans: A thrift transport object, i.e. a TSocket() object. - @type trans: TTransport - @param compresslevel: The zlib compression level, ranging - from 0 (no compression) to 9 (best compression). Default is 9. - @type compresslevel: int - """ - self.__trans = trans - self.compresslevel = compresslevel - self.__rbuf = BufferIO() - self.__wbuf = BufferIO() - self._init_zlib() - self._init_stats() - - def _reinit_buffers(self): - """Internal method to initialize/reset the internal StringIO objects - for read and write buffers. - """ - self.__rbuf = BufferIO() - self.__wbuf = BufferIO() + This caching means the TServer class will get the _same_ transport + object for both input and output transports from this factory. + (For non-threaded scenarios only, since the cache only holds one object) - def _init_stats(self): - """Internal method to reset the internal statistics counters - for compression ratios and bandwidth savings. + The purpose of this caching is to allocate only one TZlibTransport where + only one is really needed (since it must have separate read/write buffers), + and makes the statistics from getCompSavings() and getCompRatio() + easier to understand. """ - self.bytes_in = 0 - self.bytes_out = 0 - self.bytes_in_comp = 0 - self.bytes_out_comp = 0 - - def _init_zlib(self): - """Internal method for setting up the zlib compression and - decompression objects. - """ - self._zcomp_read = zlib.decompressobj() - self._zcomp_write = zlib.compressobj(self.compresslevel) - - def getCompRatio(self): - """Get the current measured compression ratios (in,out) from - this transport. - - Returns a tuple of: - (inbound_compression_ratio, outbound_compression_ratio) - - The compression ratios are computed as: - compressed / uncompressed + # class scoped cache of last transport given and zlibtransport returned + _last_trans = None + _last_z = None + + def getTransport(self, trans, compresslevel=9): + """Wrap a transport, trans, with the TZlibTransport + compressed transport class, returning a new + transport to the caller. + + @param compresslevel: The zlib compression level, ranging + from 0 (no compression) to 9 (best compression). Defaults to 9. + @type compresslevel: int + + This method returns a TZlibTransport which wraps the + passed C{trans} TTransport derived instance. + """ + if trans == self._last_trans: + return self._last_z + ztrans = TZlibTransport(trans, compresslevel) + self._last_trans = trans + self._last_z = ztrans + return ztrans - E.g., data that compresses by 10x will have a ratio of: 0.10 - and data that compresses to half of ts original size will - have a ratio of 0.5 - None is returned if no bytes have yet been processed in - a particular direction. - """ - r_percent, w_percent = (None, None) - if self.bytes_in > 0: - r_percent = self.bytes_in_comp / self.bytes_in - if self.bytes_out > 0: - w_percent = self.bytes_out_comp / self.bytes_out - return (r_percent, w_percent) - - def getCompSavings(self): - """Get the current count of saved bytes due to data - compression. - - Returns a tuple of: - (inbound_saved_bytes, outbound_saved_bytes) - - Note: if compression is actually expanding your - data (only likely with very tiny thrift objects), then - the values returned will be negative. - """ - r_saved = self.bytes_in - self.bytes_in_comp - w_saved = self.bytes_out - self.bytes_out_comp - return (r_saved, w_saved) - - def isOpen(self): - """Return the underlying transport's open status""" - return self.__trans.isOpen() - - def open(self): - """Open the underlying transport""" - self._init_stats() - return self.__trans.open() - - def listen(self): - """Invoke the underlying transport's listen() method""" - self.__trans.listen() - - def accept(self): - """Accept connections on the underlying transport""" - return self.__trans.accept() - - def close(self): - """Close the underlying transport,""" - self._reinit_buffers() - self._init_zlib() - return self.__trans.close() - - def read(self, sz): - """Read up to sz bytes from the decompressed bytes buffer, and - read from the underlying transport if the decompression - buffer is empty. - """ - ret = self.__rbuf.read(sz) - if len(ret) > 0: - return ret - # keep reading from transport until something comes back - while True: - if self.readComp(sz): - break - ret = self.__rbuf.read(sz) - return ret - - def readComp(self, sz): - """Read compressed data from the underlying transport, then - decompress it and append it to the internal StringIO read buffer - """ - zbuf = self.__trans.read(sz) - zbuf = self._zcomp_read.unconsumed_tail + zbuf - buf = self._zcomp_read.decompress(zbuf) - self.bytes_in += len(zbuf) - self.bytes_in_comp += len(buf) - old = self.__rbuf.read() - self.__rbuf = BufferIO(old + buf) - if len(old) + len(buf) == 0: - return False - return True - - def write(self, buf): - """Write some bytes, putting them into the internal write - buffer for eventual compression. - """ - self.__wbuf.write(buf) - - def flush(self): - """Flush any queued up data in the write buffer and ensure the - compression buffer is flushed out to the underlying transport +class TZlibTransport(TTransportBase, CReadableTransport): + """Class that wraps a transport with zlib, compressing writes + and decompresses reads, using the python standard + library zlib module. """ - wout = self.__wbuf.getvalue() - if len(wout) > 0: - zbuf = self._zcomp_write.compress(wout) - self.bytes_out += len(wout) - self.bytes_out_comp += len(zbuf) - else: - zbuf = '' - ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH) - self.bytes_out_comp += len(ztail) - if (len(zbuf) + len(ztail)) > 0: - self.__wbuf = BufferIO() - self.__trans.write(zbuf + ztail) - self.__trans.flush() - - @property - def cstringio_buf(self): - """Implement the CReadableTransport interface""" - return self.__rbuf - - def cstringio_refill(self, partialread, reqlen): - """Implement the CReadableTransport interface for refill""" - retstring = partialread - if reqlen < self.DEFAULT_BUFFSIZE: - retstring += self.read(self.DEFAULT_BUFFSIZE) - while len(retstring) < reqlen: - retstring += self.read(reqlen - len(retstring)) - self.__rbuf = BufferIO(retstring) - return self.__rbuf + # Read buffer size for the python fastbinary C extension, + # the TBinaryProtocolAccelerated class. + DEFAULT_BUFFSIZE = 4096 + + def __init__(self, trans, compresslevel=9): + """Create a new TZlibTransport, wrapping C{trans}, another + TTransport derived object. + + @param trans: A thrift transport object, i.e. a TSocket() object. + @type trans: TTransport + @param compresslevel: The zlib compression level, ranging + from 0 (no compression) to 9 (best compression). Default is 9. + @type compresslevel: int + """ + self.__trans = trans + self.compresslevel = compresslevel + self.__rbuf = BufferIO() + self.__wbuf = BufferIO() + self._init_zlib() + self._init_stats() + + def _reinit_buffers(self): + """Internal method to initialize/reset the internal StringIO objects + for read and write buffers. + """ + self.__rbuf = BufferIO() + self.__wbuf = BufferIO() + + def _init_stats(self): + """Internal method to reset the internal statistics counters + for compression ratios and bandwidth savings. + """ + self.bytes_in = 0 + self.bytes_out = 0 + self.bytes_in_comp = 0 + self.bytes_out_comp = 0 + + def _init_zlib(self): + """Internal method for setting up the zlib compression and + decompression objects. + """ + self._zcomp_read = zlib.decompressobj() + self._zcomp_write = zlib.compressobj(self.compresslevel) + + def getCompRatio(self): + """Get the current measured compression ratios (in,out) from + this transport. + + Returns a tuple of: + (inbound_compression_ratio, outbound_compression_ratio) + + The compression ratios are computed as: + compressed / uncompressed + + E.g., data that compresses by 10x will have a ratio of: 0.10 + and data that compresses to half of ts original size will + have a ratio of 0.5 + + None is returned if no bytes have yet been processed in + a particular direction. + """ + r_percent, w_percent = (None, None) + if self.bytes_in > 0: + r_percent = self.bytes_in_comp / self.bytes_in + if self.bytes_out > 0: + w_percent = self.bytes_out_comp / self.bytes_out + return (r_percent, w_percent) + + def getCompSavings(self): + """Get the current count of saved bytes due to data + compression. + + Returns a tuple of: + (inbound_saved_bytes, outbound_saved_bytes) + + Note: if compression is actually expanding your + data (only likely with very tiny thrift objects), then + the values returned will be negative. + """ + r_saved = self.bytes_in - self.bytes_in_comp + w_saved = self.bytes_out - self.bytes_out_comp + return (r_saved, w_saved) + + def isOpen(self): + """Return the underlying transport's open status""" + return self.__trans.isOpen() + + def open(self): + """Open the underlying transport""" + self._init_stats() + return self.__trans.open() + + def listen(self): + """Invoke the underlying transport's listen() method""" + self.__trans.listen() + + def accept(self): + """Accept connections on the underlying transport""" + return self.__trans.accept() + + def close(self): + """Close the underlying transport,""" + self._reinit_buffers() + self._init_zlib() + return self.__trans.close() + + def read(self, sz): + """Read up to sz bytes from the decompressed bytes buffer, and + read from the underlying transport if the decompression + buffer is empty. + """ + ret = self.__rbuf.read(sz) + if len(ret) > 0: + return ret + # keep reading from transport until something comes back + while True: + if self.readComp(sz): + break + ret = self.__rbuf.read(sz) + return ret + + def readComp(self, sz): + """Read compressed data from the underlying transport, then + decompress it and append it to the internal StringIO read buffer + """ + zbuf = self.__trans.read(sz) + zbuf = self._zcomp_read.unconsumed_tail + zbuf + buf = self._zcomp_read.decompress(zbuf) + self.bytes_in += len(zbuf) + self.bytes_in_comp += len(buf) + old = self.__rbuf.read() + self.__rbuf = BufferIO(old + buf) + if len(old) + len(buf) == 0: + return False + return True + + def write(self, buf): + """Write some bytes, putting them into the internal write + buffer for eventual compression. + """ + self.__wbuf.write(buf) + + def flush(self): + """Flush any queued up data in the write buffer and ensure the + compression buffer is flushed out to the underlying transport + """ + wout = self.__wbuf.getvalue() + if len(wout) > 0: + zbuf = self._zcomp_write.compress(wout) + self.bytes_out += len(wout) + self.bytes_out_comp += len(zbuf) + else: + zbuf = '' + ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH) + self.bytes_out_comp += len(ztail) + if (len(zbuf) + len(ztail)) > 0: + self.__wbuf = BufferIO() + self.__trans.write(zbuf + ztail) + self.__trans.flush() + + @property + def cstringio_buf(self): + """Implement the CReadableTransport interface""" + return self.__rbuf + + def cstringio_refill(self, partialread, reqlen): + """Implement the CReadableTransport interface for refill""" + retstring = partialread + if reqlen < self.DEFAULT_BUFFSIZE: + retstring += self.read(self.DEFAULT_BUFFSIZE) + while len(retstring) < reqlen: + retstring += self.read(reqlen - len(retstring)) + self.__rbuf = BufferIO(retstring) + return self.__rbuf diff --git a/lib/py/test/_import_local_thrift.py b/lib/py/test/_import_local_thrift.py index 30c1abcc0..174166969 100644 --- a/lib/py/test/_import_local_thrift.py +++ b/lib/py/test/_import_local_thrift.py @@ -6,8 +6,8 @@ SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__)) ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR))) if sys.version_info[0] == 2: - import glob - libdir = glob.glob(os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*'))[0] - sys.path.insert(0, libdir) + import glob + libdir = glob.glob(os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*'))[0] + sys.path.insert(0, libdir) else: - sys.path.insert(0, os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib')) + sys.path.insert(0, os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib')) diff --git a/lib/py/test/test_sslsocket.py b/lib/py/test/test_sslsocket.py index b7c3802fe..fa156a0fd 100644 --- a/lib/py/test/test_sslsocket.py +++ b/lib/py/test/test_sslsocket.py @@ -46,231 +46,231 @@ TEST_CIPHERS = 'DES-CBC3-SHA' class ServerAcceptor(threading.Thread): - def __init__(self, server): - super(ServerAcceptor, self).__init__() - self._server = server - self.client = None + def __init__(self, server): + super(ServerAcceptor, self).__init__() + self._server = server + self.client = None - def run(self): - self._server.listen() - self.client = self._server.accept() + def run(self): + self._server.listen() + self.client = self._server.accept() # Python 2.6 compat class AssertRaises(object): - def __init__(self, expected): - self._expected = expected + def __init__(self, expected): + self._expected = expected - def __enter__(self): - pass + def __enter__(self): + pass - def __exit__(self, exc_type, exc_value, traceback): - if not exc_type or not issubclass(exc_type, self._expected): - raise Exception('fail') - return True + def __exit__(self, exc_type, exc_value, traceback): + if not exc_type or not issubclass(exc_type, self._expected): + raise Exception('fail') + return True class TSSLSocketTest(unittest.TestCase): - def _assert_connection_failure(self, server, client): - try: - acc = ServerAcceptor(server) - acc.start() - time.sleep(CONNECT_DELAY) - client.setTimeout(CONNECT_TIMEOUT) - with self._assert_raises(Exception): - client.open() - select.select([], [client.handle], [], CONNECT_TIMEOUT) - # self.assertIsNone(acc.client) - self.assertTrue(acc.client is None) - finally: - server.close() - client.close() - - def _assert_raises(self, exc): - if sys.hexversion >= 0x020700F0: - return self.assertRaises(exc) - else: - return AssertRaises(exc) - - def _assert_connection_success(self, server, client): - try: - acc = ServerAcceptor(server) - acc.start() - time.sleep(0.15) - client.setTimeout(CONNECT_TIMEOUT) - client.open() - select.select([], [client.handle], [], CONNECT_TIMEOUT) - # self.assertIsNotNone(acc.client) - self.assertTrue(acc.client is not None) - finally: - server.close() - client.close() - - # deprecated feature - def test_deprecation(self): - with warnings.catch_warnings(record=True) as w: - warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*') - TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT) - self.assertEqual(len(w), 1) - - with warnings.catch_warnings(record=True) as w: - warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*') - # Deprecated signature - # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None): - client = TSSLSocket('localhost', TEST_PORT, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS) - self.assertEqual(len(w), 7) - - with warnings.catch_warnings(record=True) as w: - warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*') - # Deprecated signature - # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None): - server = TSSLServerSocket(None, TEST_PORT, SERVER_PEM, None, TEST_CIPHERS) - self.assertEqual(len(w), 3) - - self._assert_connection_success(server, client) - - # deprecated feature - def test_set_cert_reqs_by_validate(self): - c1 = TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT) - self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED) - - c1 = TSSLSocket('localhost', TEST_PORT, validate=False) - self.assertEqual(c1.cert_reqs, ssl.CERT_NONE) - - # deprecated feature - def test_set_validate_by_cert_reqs(self): - c1 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE) - self.assertFalse(c1.validate) - - c2 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT) - self.assertTrue(c2.validate) - - c3 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT) - self.assertTrue(c3.validate) - - def test_unix_domain_socket(self): - if platform.system() == 'Windows': - print('skipping test_unix_domain_socket') - return - server = TSSLServerSocket(unix_socket=TEST_ADDR, keyfile=SERVER_KEY, certfile=SERVER_CERT) - client = TSSLSocket(None, None, TEST_ADDR, cert_reqs=ssl.CERT_NONE) - self._assert_connection_success(server, client) - - def test_server_cert(self): - server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT) - client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT) - self._assert_connection_success(server, client) - - server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT) - # server cert on in ca_certs - client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT) - self._assert_connection_failure(server, client) - - server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT) - client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE) - self._assert_connection_success(server, client) - - def test_set_server_cert(self): - server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=CLIENT_CERT) - with self._assert_raises(Exception): - server.certfile = 'foo' - with self._assert_raises(Exception): - server.certfile = None - server.certfile = SERVER_CERT - client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT) - self._assert_connection_success(server, client) - - def test_client_cert(self): - server = TSSLServerSocket( - port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY, - certfile=SERVER_CERT, ca_certs=CLIENT_CERT) - client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY) - self._assert_connection_success(server, client) - - def test_ciphers(self): - server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS) - client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS) - self._assert_connection_success(server, client) - - if not TSSLSocket._has_ciphers: - # unittest.skip is not available for Python 2.6 - print('skipping test_ciphers') - return - server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT) - client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL') - self._assert_connection_failure(server, client) - - server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS) - client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL') - self._assert_connection_failure(server, client) - - def test_ssl2_and_ssl3_disabled(self): - if not hasattr(ssl, 'PROTOCOL_SSLv3'): - print('PROTOCOL_SSLv3 is not available') - else: - server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT) - client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3) - self._assert_connection_failure(server, client) - - server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3) - client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT) - self._assert_connection_failure(server, client) - - if not hasattr(ssl, 'PROTOCOL_SSLv2'): - print('PROTOCOL_SSLv2 is not available') - else: - server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT) - client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2) - self._assert_connection_failure(server, client) - - server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2) - client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT) - self._assert_connection_failure(server, client) - - def test_newer_tls(self): - if not TSSLSocket._has_ssl_context: - # unittest.skip is not available for Python 2.6 - print('skipping test_newer_tls') - return - if not hasattr(ssl, 'PROTOCOL_TLSv1_2'): - print('PROTOCOL_TLSv1_2 is not available') - else: - server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2) - client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2) - self._assert_connection_success(server, client) - - if not hasattr(ssl, 'PROTOCOL_TLSv1_1'): - print('PROTOCOL_TLSv1_1 is not available') - else: - server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1) - client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1) - self._assert_connection_success(server, client) - - if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'): - print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available') - else: - server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2) - client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1) - self._assert_connection_failure(server, client) - - def test_ssl_context(self): - if not TSSLSocket._has_ssl_context: - # unittest.skip is not available for Python 2.6 - print('skipping test_ssl_context') - return - server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - server_context.load_cert_chain(SERVER_CERT, SERVER_KEY) - server_context.load_verify_locations(CLIENT_CERT) - - client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) - client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY) - client_context.load_verify_locations(SERVER_CERT) - - server = TSSLServerSocket(port=TEST_PORT, ssl_context=server_context) - client = TSSLSocket('localhost', TEST_PORT, ssl_context=client_context) - self._assert_connection_success(server, client) + def _assert_connection_failure(self, server, client): + try: + acc = ServerAcceptor(server) + acc.start() + time.sleep(CONNECT_DELAY) + client.setTimeout(CONNECT_TIMEOUT) + with self._assert_raises(Exception): + client.open() + select.select([], [client.handle], [], CONNECT_TIMEOUT) + # self.assertIsNone(acc.client) + self.assertTrue(acc.client is None) + finally: + server.close() + client.close() + + def _assert_raises(self, exc): + if sys.hexversion >= 0x020700F0: + return self.assertRaises(exc) + else: + return AssertRaises(exc) + + def _assert_connection_success(self, server, client): + try: + acc = ServerAcceptor(server) + acc.start() + time.sleep(0.15) + client.setTimeout(CONNECT_TIMEOUT) + client.open() + select.select([], [client.handle], [], CONNECT_TIMEOUT) + # self.assertIsNotNone(acc.client) + self.assertTrue(acc.client is not None) + finally: + server.close() + client.close() + + # deprecated feature + def test_deprecation(self): + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*') + TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT) + self.assertEqual(len(w), 1) + + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*') + # Deprecated signature + # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None): + client = TSSLSocket('localhost', TEST_PORT, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS) + self.assertEqual(len(w), 7) + + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*') + # Deprecated signature + # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None): + server = TSSLServerSocket(None, TEST_PORT, SERVER_PEM, None, TEST_CIPHERS) + self.assertEqual(len(w), 3) + + self._assert_connection_success(server, client) + + # deprecated feature + def test_set_cert_reqs_by_validate(self): + c1 = TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT) + self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED) + + c1 = TSSLSocket('localhost', TEST_PORT, validate=False) + self.assertEqual(c1.cert_reqs, ssl.CERT_NONE) + + # deprecated feature + def test_set_validate_by_cert_reqs(self): + c1 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE) + self.assertFalse(c1.validate) + + c2 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT) + self.assertTrue(c2.validate) + + c3 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT) + self.assertTrue(c3.validate) + + def test_unix_domain_socket(self): + if platform.system() == 'Windows': + print('skipping test_unix_domain_socket') + return + server = TSSLServerSocket(unix_socket=TEST_ADDR, keyfile=SERVER_KEY, certfile=SERVER_CERT) + client = TSSLSocket(None, None, TEST_ADDR, cert_reqs=ssl.CERT_NONE) + self._assert_connection_success(server, client) + + def test_server_cert(self): + server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT) + client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT) + self._assert_connection_success(server, client) + + server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT) + # server cert on in ca_certs + client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT) + self._assert_connection_failure(server, client) + + server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT) + client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE) + self._assert_connection_success(server, client) + + def test_set_server_cert(self): + server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=CLIENT_CERT) + with self._assert_raises(Exception): + server.certfile = 'foo' + with self._assert_raises(Exception): + server.certfile = None + server.certfile = SERVER_CERT + client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT) + self._assert_connection_success(server, client) + + def test_client_cert(self): + server = TSSLServerSocket( + port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY, + certfile=SERVER_CERT, ca_certs=CLIENT_CERT) + client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY) + self._assert_connection_success(server, client) + + def test_ciphers(self): + server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS) + client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS) + self._assert_connection_success(server, client) + + if not TSSLSocket._has_ciphers: + # unittest.skip is not available for Python 2.6 + print('skipping test_ciphers') + return + server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT) + client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL') + self._assert_connection_failure(server, client) + + server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS) + client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL') + self._assert_connection_failure(server, client) + + def test_ssl2_and_ssl3_disabled(self): + if not hasattr(ssl, 'PROTOCOL_SSLv3'): + print('PROTOCOL_SSLv3 is not available') + else: + server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT) + client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3) + self._assert_connection_failure(server, client) + + server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3) + client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT) + self._assert_connection_failure(server, client) + + if not hasattr(ssl, 'PROTOCOL_SSLv2'): + print('PROTOCOL_SSLv2 is not available') + else: + server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT) + client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2) + self._assert_connection_failure(server, client) + + server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2) + client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT) + self._assert_connection_failure(server, client) + + def test_newer_tls(self): + if not TSSLSocket._has_ssl_context: + # unittest.skip is not available for Python 2.6 + print('skipping test_newer_tls') + return + if not hasattr(ssl, 'PROTOCOL_TLSv1_2'): + print('PROTOCOL_TLSv1_2 is not available') + else: + server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2) + client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2) + self._assert_connection_success(server, client) + + if not hasattr(ssl, 'PROTOCOL_TLSv1_1'): + print('PROTOCOL_TLSv1_1 is not available') + else: + server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1) + client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1) + self._assert_connection_success(server, client) + + if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'): + print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available') + else: + server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2) + client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1) + self._assert_connection_failure(server, client) + + def test_ssl_context(self): + if not TSSLSocket._has_ssl_context: + # unittest.skip is not available for Python 2.6 + print('skipping test_ssl_context') + return + server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + server_context.load_cert_chain(SERVER_CERT, SERVER_KEY) + server_context.load_verify_locations(CLIENT_CERT) + + client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY) + client_context.load_verify_locations(SERVER_CERT) + + server = TSSLServerSocket(port=TEST_PORT, ssl_context=server_context) + client = TSSLSocket('localhost', TEST_PORT, ssl_context=client_context) + self._assert_connection_success(server, client) if __name__ == '__main__': - # import logging - # logging.basicConfig(level=logging.DEBUG) - unittest.main() + # import logging + # logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/lib/py/test/thrift_json.py b/lib/py/test/thrift_json.py index e60aabacf..5ba7dd585 100644 --- a/lib/py/test/thrift_json.py +++ b/lib/py/test/thrift_json.py @@ -31,20 +31,20 @@ from thrift.transport import TTransport # mklink /D thrift ..\src # + class TestJSONString(unittest.TestCase): - def test_escaped_unicode_string(self): - unicode_json = b'"hello \\u0e01\\u0e02\\u0e03\\ud835\\udcab\\udb40\\udc70 unicode"' - unicode_text = u'hello \u0e01\u0e02\u0e03\U0001D4AB\U000E0070 unicode' + def test_escaped_unicode_string(self): + unicode_json = b'"hello \\u0e01\\u0e02\\u0e03\\ud835\\udcab\\udb40\\udc70 unicode"' + unicode_text = u'hello \u0e01\u0e02\u0e03\U0001D4AB\U000E0070 unicode' - buf = TTransport.TMemoryBuffer(unicode_json) - transport = TTransport.TBufferedTransportFactory().getTransport(buf) - protocol = TJSONProtocol(transport) + buf = TTransport.TMemoryBuffer(unicode_json) + transport = TTransport.TBufferedTransportFactory().getTransport(buf) + protocol = TJSONProtocol(transport) - if sys.version_info[0] == 2: - unicode_text = unicode_text.encode('utf8') - self.assertEqual(protocol.readString(), unicode_text) + if sys.version_info[0] == 2: + unicode_text = unicode_text.encode('utf8') + self.assertEqual(protocol.readString(), unicode_text) if __name__ == '__main__': - unittest.main() - + unittest.main() diff --git a/test/crossrunner/collect.py b/test/crossrunner/collect.py index f92b9e2d7..e91ac0b43 100644 --- a/test/crossrunner/collect.py +++ b/test/crossrunner/collect.py @@ -40,13 +40,13 @@ from .util import merge_dict # (e.g. "binary" is equivalent to "binary:binary" in tests.json) # VALID_JSON_KEYS = [ - 'name', # name of the library, typically a language name - 'workdir', # work directory where command is executed - 'command', # test command - 'extra_args', # args appended to command after other args are appended - 'remote_args', # args added to the other side of the program - 'join_args', # whether args should be passed as single concatenated string - 'env', # additional environmental variable + 'name', # name of the library, typically a language name + 'workdir', # work directory where command is executed + 'command', # test command + 'extra_args', # args appended to command after other args are appended + 'remote_args', # args added to the other side of the program + 'join_args', # whether args should be passed as single concatenated string + 'env', # additional environmental variable ] DEFAULT_DELAY = 1 @@ -54,102 +54,102 @@ DEFAULT_TIMEOUT = 5 def _collect_testlibs(config, server_match, client_match=[None]): - """Collects server/client configurations from library configurations""" - def expand_libs(config): - for lib in config: - sv = lib.pop('server', None) - cl = lib.pop('client', None) - yield lib, sv, cl - - def yield_testlibs(base_configs, configs, match): - for base, conf in zip(base_configs, configs): - if conf: - if not match or base['name'] in match: - platforms = conf.get('platforms') or base.get('platforms') - if not platforms or platform.system() in platforms: - yield merge_dict(base, conf) - - libs, svs, cls = zip(*expand_libs(config)) - servers = list(yield_testlibs(libs, svs, server_match)) - clients = list(yield_testlibs(libs, cls, client_match)) - return servers, clients + """Collects server/client configurations from library configurations""" + def expand_libs(config): + for lib in config: + sv = lib.pop('server', None) + cl = lib.pop('client', None) + yield lib, sv, cl + + def yield_testlibs(base_configs, configs, match): + for base, conf in zip(base_configs, configs): + if conf: + if not match or base['name'] in match: + platforms = conf.get('platforms') or base.get('platforms') + if not platforms or platform.system() in platforms: + yield merge_dict(base, conf) + + libs, svs, cls = zip(*expand_libs(config)) + servers = list(yield_testlibs(libs, svs, server_match)) + clients = list(yield_testlibs(libs, cls, client_match)) + return servers, clients def collect_features(config, match): - res = list(map(re.compile, match)) - return list(filter(lambda c: any(map(lambda r: r.search(c['name']), res)), config)) + res = list(map(re.compile, match)) + return list(filter(lambda c: any(map(lambda r: r.search(c['name']), res)), config)) def _do_collect_tests(servers, clients): - def intersection(key, o1, o2): - """intersection of two collections. - collections are replaced with sets the first time""" - def cached_set(o, key): - v = o[key] - if not isinstance(v, set): - v = set(v) - o[key] = v - return v - return cached_set(o1, key) & cached_set(o2, key) - - def intersect_with_spec(key, o1, o2): - # store as set of (spec, impl) tuple - def cached_set(o): - def to_spec_impl_tuples(values): - for v in values: - spec, _, impl = v.partition(':') - yield spec, impl or spec - v = o[key] - if not isinstance(v, set): - v = set(to_spec_impl_tuples(set(v))) - o[key] = v - return v - for spec1, impl1 in cached_set(o1): - for spec2, impl2 in cached_set(o2): - if spec1 == spec2: - name = impl1 if impl1 == impl2 else '%s-%s' % (impl1, impl2) - yield name, impl1, impl2 - - def maybe_max(key, o1, o2, default): - """maximum of two if present, otherwise defult value""" - v1 = o1.get(key) - v2 = o2.get(key) - return max(v1, v2) if v1 and v2 else v1 or v2 or default - - def filter_with_validkeys(o): - ret = {} - for key in VALID_JSON_KEYS: - if key in o: - ret[key] = o[key] - return ret - - def merge_metadata(o, **ret): - for key in VALID_JSON_KEYS: - if key in o: - ret[key] = o[key] - return ret - - for sv, cl in product(servers, clients): - for proto, proto1, proto2 in intersect_with_spec('protocols', sv, cl): - for trans, trans1, trans2 in intersect_with_spec('transports', sv, cl): - for sock in intersection('sockets', sv, cl): - yield { - 'server': merge_metadata(sv, **{'protocol': proto1, 'transport': trans1}), - 'client': merge_metadata(cl, **{'protocol': proto2, 'transport': trans2}), - 'delay': maybe_max('delay', sv, cl, DEFAULT_DELAY), - 'timeout': maybe_max('timeout', sv, cl, DEFAULT_TIMEOUT), - 'protocol': proto, - 'transport': trans, - 'socket': sock - } + def intersection(key, o1, o2): + """intersection of two collections. + collections are replaced with sets the first time""" + def cached_set(o, key): + v = o[key] + if not isinstance(v, set): + v = set(v) + o[key] = v + return v + return cached_set(o1, key) & cached_set(o2, key) + + def intersect_with_spec(key, o1, o2): + # store as set of (spec, impl) tuple + def cached_set(o): + def to_spec_impl_tuples(values): + for v in values: + spec, _, impl = v.partition(':') + yield spec, impl or spec + v = o[key] + if not isinstance(v, set): + v = set(to_spec_impl_tuples(set(v))) + o[key] = v + return v + for spec1, impl1 in cached_set(o1): + for spec2, impl2 in cached_set(o2): + if spec1 == spec2: + name = impl1 if impl1 == impl2 else '%s-%s' % (impl1, impl2) + yield name, impl1, impl2 + + def maybe_max(key, o1, o2, default): + """maximum of two if present, otherwise defult value""" + v1 = o1.get(key) + v2 = o2.get(key) + return max(v1, v2) if v1 and v2 else v1 or v2 or default + + def filter_with_validkeys(o): + ret = {} + for key in VALID_JSON_KEYS: + if key in o: + ret[key] = o[key] + return ret + + def merge_metadata(o, **ret): + for key in VALID_JSON_KEYS: + if key in o: + ret[key] = o[key] + return ret + + for sv, cl in product(servers, clients): + for proto, proto1, proto2 in intersect_with_spec('protocols', sv, cl): + for trans, trans1, trans2 in intersect_with_spec('transports', sv, cl): + for sock in intersection('sockets', sv, cl): + yield { + 'server': merge_metadata(sv, **{'protocol': proto1, 'transport': trans1}), + 'client': merge_metadata(cl, **{'protocol': proto2, 'transport': trans2}), + 'delay': maybe_max('delay', sv, cl, DEFAULT_DELAY), + 'timeout': maybe_max('timeout', sv, cl, DEFAULT_TIMEOUT), + 'protocol': proto, + 'transport': trans, + 'socket': sock + } def collect_cross_tests(tests_dict, server_match, client_match): - sv, cl = _collect_testlibs(tests_dict, server_match, client_match) - return list(_do_collect_tests(sv, cl)) + sv, cl = _collect_testlibs(tests_dict, server_match, client_match) + return list(_do_collect_tests(sv, cl)) def collect_feature_tests(tests_dict, features_dict, server_match, feature_match): - sv, _ = _collect_testlibs(tests_dict, server_match) - ft = collect_features(features_dict, feature_match) - return list(_do_collect_tests(sv, ft)) + sv, _ = _collect_testlibs(tests_dict, server_match) + ft = collect_features(features_dict, feature_match) + return list(_do_collect_tests(sv, ft)) diff --git a/test/crossrunner/compat.py b/test/crossrunner/compat.py index 6ab9d713b..f1ca91bb3 100644 --- a/test/crossrunner/compat.py +++ b/test/crossrunner/compat.py @@ -2,23 +2,23 @@ import os import sys if sys.version_info[0] == 2: - _ENCODE = sys.getfilesystemencoding() + _ENCODE = sys.getfilesystemencoding() - def path_join(*args): - bin_args = map(lambda a: a.decode(_ENCODE), args) - return os.path.join(*bin_args).encode(_ENCODE) + def path_join(*args): + bin_args = map(lambda a: a.decode(_ENCODE), args) + return os.path.join(*bin_args).encode(_ENCODE) - def str_join(s, l): - bin_args = map(lambda a: a.decode(_ENCODE), l) - b = s.decode(_ENCODE) - return b.join(bin_args).encode(_ENCODE) + def str_join(s, l): + bin_args = map(lambda a: a.decode(_ENCODE), l) + b = s.decode(_ENCODE) + return b.join(bin_args).encode(_ENCODE) - logfile_open = open + logfile_open = open else: - path_join = os.path.join - str_join = str.join + path_join = os.path.join + str_join = str.join - def logfile_open(*args): - return open(*args, errors='replace') + def logfile_open(*args): + return open(*args, errors='replace') diff --git a/test/crossrunner/report.py b/test/crossrunner/report.py index be7271cb1..cc5f26fe2 100644 --- a/test/crossrunner/report.py +++ b/test/crossrunner/report.py @@ -39,396 +39,396 @@ FAIL_JSON = 'known_failures_%s.json' def generate_known_failures(testdir, overwrite, save, out): - def collect_failures(results): - success_index = 5 - for r in results: - if not r[success_index]: - yield TestEntry.get_name(*r) - try: - with logfile_open(path_join(testdir, RESULT_JSON), 'r') as fp: - results = json.load(fp) - except IOError: - sys.stderr.write('Unable to load last result. Did you run tests ?\n') - return False - fails = collect_failures(results['results']) - if not overwrite: - known = load_known_failures(testdir) - known.extend(fails) - fails = known - fails_json = json.dumps(sorted(set(fails)), indent=2, separators=(',', ': ')) - if save: - with logfile_open(os.path.join(testdir, FAIL_JSON % platform.system()), 'w+') as fp: - fp.write(fails_json) - sys.stdout.write('Successfully updated known failures.\n') - if out: - sys.stdout.write(fails_json) - sys.stdout.write('\n') - return True + def collect_failures(results): + success_index = 5 + for r in results: + if not r[success_index]: + yield TestEntry.get_name(*r) + try: + with logfile_open(path_join(testdir, RESULT_JSON), 'r') as fp: + results = json.load(fp) + except IOError: + sys.stderr.write('Unable to load last result. Did you run tests ?\n') + return False + fails = collect_failures(results['results']) + if not overwrite: + known = load_known_failures(testdir) + known.extend(fails) + fails = known + fails_json = json.dumps(sorted(set(fails)), indent=2, separators=(',', ': ')) + if save: + with logfile_open(os.path.join(testdir, FAIL_JSON % platform.system()), 'w+') as fp: + fp.write(fails_json) + sys.stdout.write('Successfully updated known failures.\n') + if out: + sys.stdout.write(fails_json) + sys.stdout.write('\n') + return True def load_known_failures(testdir): - try: - with logfile_open(path_join(testdir, FAIL_JSON % platform.system()), 'r') as fp: - return json.load(fp) - except IOError: - return [] + try: + with logfile_open(path_join(testdir, FAIL_JSON % platform.system()), 'r') as fp: + return json.load(fp) + except IOError: + return [] class TestReporter(object): - # Unfortunately, standard library doesn't handle timezone well - # DATETIME_FORMAT = '%a %b %d %H:%M:%S %Z %Y' - DATETIME_FORMAT = '%a %b %d %H:%M:%S %Y' + # Unfortunately, standard library doesn't handle timezone well + # DATETIME_FORMAT = '%a %b %d %H:%M:%S %Z %Y' + DATETIME_FORMAT = '%a %b %d %H:%M:%S %Y' - def __init__(self): - self._log = multiprocessing.get_logger() - self._lock = multiprocessing.Lock() + def __init__(self): + self._log = multiprocessing.get_logger() + self._lock = multiprocessing.Lock() - @classmethod - def test_logfile(cls, test_name, prog_kind, dir=None): - relpath = path_join('log', '%s_%s.log' % (test_name, prog_kind)) - return relpath if not dir else os.path.realpath(path_join(dir, relpath)) + @classmethod + def test_logfile(cls, test_name, prog_kind, dir=None): + relpath = path_join('log', '%s_%s.log' % (test_name, prog_kind)) + return relpath if not dir else os.path.realpath(path_join(dir, relpath)) - def _start(self): - self._start_time = time.time() + def _start(self): + self._start_time = time.time() - @property - def _elapsed(self): - return time.time() - self._start_time + @property + def _elapsed(self): + return time.time() - self._start_time - @classmethod - def _format_date(cls): - return '%s' % datetime.datetime.now().strftime(cls.DATETIME_FORMAT) + @classmethod + def _format_date(cls): + return '%s' % datetime.datetime.now().strftime(cls.DATETIME_FORMAT) - def _print_date(self): - print(self._format_date(), file=self.out) + def _print_date(self): + print(self._format_date(), file=self.out) - def _print_bar(self, out=None): - print( - '==========================================================================', - file=(out or self.out)) + def _print_bar(self, out=None): + print( + '==========================================================================', + file=(out or self.out)) - def _print_exec_time(self): - print('Test execution took {:.1f} seconds.'.format(self._elapsed), file=self.out) + def _print_exec_time(self): + print('Test execution took {:.1f} seconds.'.format(self._elapsed), file=self.out) class ExecReporter(TestReporter): - def __init__(self, testdir, test, prog): - super(ExecReporter, self).__init__() - self._test = test - self._prog = prog - self.logpath = self.test_logfile(test.name, prog.kind, testdir) - self.out = None - - def begin(self): - self._start() - self._open() - if self.out and not self.out.closed: - self._print_header() - else: - self._log.debug('Output stream is not available.') - - def end(self, returncode): - self._lock.acquire() - try: - if self.out and not self.out.closed: - self._print_footer(returncode) - self._close() + def __init__(self, testdir, test, prog): + super(ExecReporter, self).__init__() + self._test = test + self._prog = prog + self.logpath = self.test_logfile(test.name, prog.kind, testdir) self.out = None - else: - self._log.debug('Output stream is not available.') - finally: - self._lock.release() - - def killed(self): - print(file=self.out) - print('Server process is successfully killed.', file=self.out) - self.end(None) - - def died(self): - print(file=self.out) - print('*** Server process has died unexpectedly ***', file=self.out) - self.end(None) - - _init_failure_exprs = { - 'server': list(map(re.compile, [ - '[Aa]ddress already in use', - 'Could not bind', - 'EADDRINUSE', - ])), - 'client': list(map(re.compile, [ - '[Cc]onnection refused', - 'Could not connect to localhost', - 'ECONNREFUSED', - 'No such file or directory', # domain socket - ])), - } - - def maybe_false_positive(self): - """Searches through log file for socket bind error. - Returns True if suspicious expression is found, otherwise False""" - try: - if self.out and not self.out.closed: + + def begin(self): + self._start() + self._open() + if self.out and not self.out.closed: + self._print_header() + else: + self._log.debug('Output stream is not available.') + + def end(self, returncode): + self._lock.acquire() + try: + if self.out and not self.out.closed: + self._print_footer(returncode) + self._close() + self.out = None + else: + self._log.debug('Output stream is not available.') + finally: + self._lock.release() + + def killed(self): + print(file=self.out) + print('Server process is successfully killed.', file=self.out) + self.end(None) + + def died(self): + print(file=self.out) + print('*** Server process has died unexpectedly ***', file=self.out) + self.end(None) + + _init_failure_exprs = { + 'server': list(map(re.compile, [ + '[Aa]ddress already in use', + 'Could not bind', + 'EADDRINUSE', + ])), + 'client': list(map(re.compile, [ + '[Cc]onnection refused', + 'Could not connect to localhost', + 'ECONNREFUSED', + 'No such file or directory', # domain socket + ])), + } + + def maybe_false_positive(self): + """Searches through log file for socket bind error. + Returns True if suspicious expression is found, otherwise False""" + try: + if self.out and not self.out.closed: + self.out.flush() + exprs = self._init_failure_exprs[self._prog.kind] + + def match(line): + for expr in exprs: + if expr.search(line): + return True + + with logfile_open(self.logpath, 'r') as fp: + if any(map(match, fp)): + return True + except (KeyboardInterrupt, SystemExit): + raise + except Exception as ex: + self._log.warn('[%s]: Error while detecting false positive: %s' % (self._test.name, str(ex))) + self._log.info(traceback.print_exc()) + return False + + def _open(self): + self.out = logfile_open(self.logpath, 'w+') + + def _close(self): + self.out.close() + + def _print_header(self): + self._print_date() + print('Executing: %s' % str_join(' ', self._prog.command), file=self.out) + print('Directory: %s' % self._prog.workdir, file=self.out) + print('config:delay: %s' % self._test.delay, file=self.out) + print('config:timeout: %s' % self._test.timeout, file=self.out) + self._print_bar() self.out.flush() - exprs = self._init_failure_exprs[self._prog.kind] - - def match(line): - for expr in exprs: - if expr.search(line): - return True - - with logfile_open(self.logpath, 'r') as fp: - if any(map(match, fp)): - return True - except (KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - self._log.warn('[%s]: Error while detecting false positive: %s' % (self._test.name, str(ex))) - self._log.info(traceback.print_exc()) - return False - - def _open(self): - self.out = logfile_open(self.logpath, 'w+') - - def _close(self): - self.out.close() - - def _print_header(self): - self._print_date() - print('Executing: %s' % str_join(' ', self._prog.command), file=self.out) - print('Directory: %s' % self._prog.workdir, file=self.out) - print('config:delay: %s' % self._test.delay, file=self.out) - print('config:timeout: %s' % self._test.timeout, file=self.out) - self._print_bar() - self.out.flush() - - def _print_footer(self, returncode=None): - self._print_bar() - if returncode is not None: - print('Return code: %d' % returncode, file=self.out) - else: - print('Process is killed.', file=self.out) - self._print_exec_time() - self._print_date() + def _print_footer(self, returncode=None): + self._print_bar() + if returncode is not None: + print('Return code: %d' % returncode, file=self.out) + else: + print('Process is killed.', file=self.out) + self._print_exec_time() + self._print_date() -class SummaryReporter(TestReporter): - def __init__(self, basedir, testdir_relative, concurrent=True): - super(SummaryReporter, self).__init__() - self._basedir = basedir - self._testdir_rel = testdir_relative - self.logdir = path_join(self.testdir, LOG_DIR) - self.out_path = path_join(self.testdir, RESULT_JSON) - self.concurrent = concurrent - self.out = sys.stdout - self._platform = platform.system() - self._revision = self._get_revision() - self._tests = [] - if not os.path.exists(self.logdir): - os.mkdir(self.logdir) - self._known_failures = load_known_failures(self.testdir) - self._unexpected_success = [] - self._flaky_success = [] - self._unexpected_failure = [] - self._expected_failure = [] - self._print_header() - - @property - def testdir(self): - return path_join(self._basedir, self._testdir_rel) - - def _result_string(self, test): - if test.success: - if test.retry_count == 0: - return 'success' - elif test.retry_count == 1: - return 'flaky(1 retry)' - else: - return 'flaky(%d retries)' % test.retry_count - elif test.expired: - return 'failure(timeout)' - else: - return 'failure(%d)' % test.returncode - - def _get_revision(self): - p = subprocess.Popen(['git', 'rev-parse', '--short', 'HEAD'], - cwd=self.testdir, stdout=subprocess.PIPE) - out, _ = p.communicate() - return out.strip() - - def _format_test(self, test, with_result=True): - name = '%s-%s' % (test.server.name, test.client.name) - trans = '%s-%s' % (test.transport, test.socket) - if not with_result: - return '{:24s}{:13s}{:25s}'.format(name[:23], test.protocol[:12], trans[:24]) - else: - return '{:24s}{:13s}{:25s}{:s}\n'.format(name[:23], test.protocol[:12], trans[:24], self._result_string(test)) - - def _print_test_header(self): - self._print_bar() - print( - '{:24s}{:13s}{:25s}{:s}'.format('server-client:', 'protocol:', 'transport:', 'result:'), - file=self.out) - - def _print_header(self): - self._start() - print('Apache Thrift - Integration Test Suite', file=self.out) - self._print_date() - self._print_test_header() - - def _print_unexpected_failure(self): - if len(self._unexpected_failure) > 0: - self.out.writelines([ - '*** Following %d failures were unexpected ***:\n' % len(self._unexpected_failure), - 'If it is introduced by you, please fix it before submitting the code.\n', - # 'If not, please report at https://issues.apache.org/jira/browse/THRIFT\n', - ]) - self._print_test_header() - for i in self._unexpected_failure: - self.out.write(self._format_test(self._tests[i])) - self._print_bar() - else: - print('No unexpected failures.', file=self.out) - - def _print_flaky_success(self): - if len(self._flaky_success) > 0: - print( - 'Following %d tests were expected to cleanly succeed but needed retry:' % len(self._flaky_success), - file=self.out) - self._print_test_header() - for i in self._flaky_success: - self.out.write(self._format_test(self._tests[i])) - self._print_bar() - - def _print_unexpected_success(self): - if len(self._unexpected_success) > 0: - print( - 'Following %d tests were known to fail but succeeded (maybe flaky):' % len(self._unexpected_success), - file=self.out) - self._print_test_header() - for i in self._unexpected_success: - self.out.write(self._format_test(self._tests[i])) - self._print_bar() - - def _http_server_command(self, port): - if sys.version_info[0] < 3: - return 'python -m SimpleHTTPServer %d' % port - else: - return 'python -m http.server %d' % port - - def _print_footer(self): - fail_count = len(self._expected_failure) + len(self._unexpected_failure) - self._print_bar() - self._print_unexpected_success() - self._print_flaky_success() - self._print_unexpected_failure() - self._write_html_data() - self._assemble_log('unexpected failures', self._unexpected_failure) - self._assemble_log('known failures', self._expected_failure) - self.out.writelines([ - 'You can browse results at:\n', - '\tfile://%s/%s\n' % (self.testdir, RESULT_HTML), - '# If you use Chrome, run:\n', - '# \tcd %s\n#\t%s\n' % (self._basedir, self._http_server_command(8001)), - '# then browse:\n', - '# \thttp://localhost:%d/%s/\n' % (8001, self._testdir_rel), - 'Full log for each test is here:\n', - '\ttest/log/client_server_protocol_transport_client.log\n', - '\ttest/log/client_server_protocol_transport_server.log\n', - '%d failed of %d tests in total.\n' % (fail_count, len(self._tests)), - ]) - self._print_exec_time() - self._print_date() - - def _render_result(self, test): - return [ - test.server.name, - test.client.name, - test.protocol, - test.transport, - test.socket, - test.success, - test.as_expected, - test.returncode, - { - 'server': self.test_logfile(test.name, test.server.kind), - 'client': self.test_logfile(test.name, test.client.kind), - }, - ] - - def _write_html_data(self): - """Writes JSON data to be read by result html""" - results = [self._render_result(r) for r in self._tests] - with logfile_open(self.out_path, 'w+') as fp: - fp.write(json.dumps({ - 'date': self._format_date(), - 'revision': str(self._revision), - 'platform': self._platform, - 'duration': '{:.1f}'.format(self._elapsed), - 'results': results, - }, indent=2)) - - def _assemble_log(self, title, indexes): - if len(indexes) > 0: - def add_prog_log(fp, test, prog_kind): - print('*************************** %s message ***************************' % prog_kind, - file=fp) - path = self.test_logfile(test.name, prog_kind, self.testdir) - if os.path.exists(path): - with logfile_open(path, 'r') as prog_fp: - print(prog_fp.read(), file=fp) - filename = title.replace(' ', '_') + '.log' - with logfile_open(os.path.join(self.logdir, filename), 'w+') as fp: - for test in map(self._tests.__getitem__, indexes): - fp.write('TEST: [%s]\n' % test.name) - add_prog_log(fp, test, test.server.kind) - add_prog_log(fp, test, test.client.kind) - fp.write('**********************************************************************\n\n') - print('%s are logged to %s/%s/%s' % (title.capitalize(), self._testdir_rel, LOG_DIR, filename)) - - def end(self): - self._print_footer() - return len(self._unexpected_failure) == 0 - - def add_test(self, test_dict): - test = TestEntry(self.testdir, **test_dict) - self._lock.acquire() - try: - if not self.concurrent: - self.out.write(self._format_test(test, False)) - self.out.flush() - self._tests.append(test) - return len(self._tests) - 1 - finally: - self._lock.release() - def add_result(self, index, returncode, expired, retry_count): - self._lock.acquire() - try: - failed = returncode is None or returncode != 0 - flaky = not failed and retry_count != 0 - test = self._tests[index] - known = test.name in self._known_failures - if failed: - if known: - self._log.debug('%s failed as expected' % test.name) - self._expected_failure.append(index) +class SummaryReporter(TestReporter): + def __init__(self, basedir, testdir_relative, concurrent=True): + super(SummaryReporter, self).__init__() + self._basedir = basedir + self._testdir_rel = testdir_relative + self.logdir = path_join(self.testdir, LOG_DIR) + self.out_path = path_join(self.testdir, RESULT_JSON) + self.concurrent = concurrent + self.out = sys.stdout + self._platform = platform.system() + self._revision = self._get_revision() + self._tests = [] + if not os.path.exists(self.logdir): + os.mkdir(self.logdir) + self._known_failures = load_known_failures(self.testdir) + self._unexpected_success = [] + self._flaky_success = [] + self._unexpected_failure = [] + self._expected_failure = [] + self._print_header() + + @property + def testdir(self): + return path_join(self._basedir, self._testdir_rel) + + def _result_string(self, test): + if test.success: + if test.retry_count == 0: + return 'success' + elif test.retry_count == 1: + return 'flaky(1 retry)' + else: + return 'flaky(%d retries)' % test.retry_count + elif test.expired: + return 'failure(timeout)' + else: + return 'failure(%d)' % test.returncode + + def _get_revision(self): + p = subprocess.Popen(['git', 'rev-parse', '--short', 'HEAD'], + cwd=self.testdir, stdout=subprocess.PIPE) + out, _ = p.communicate() + return out.strip() + + def _format_test(self, test, with_result=True): + name = '%s-%s' % (test.server.name, test.client.name) + trans = '%s-%s' % (test.transport, test.socket) + if not with_result: + return '{:24s}{:13s}{:25s}'.format(name[:23], test.protocol[:12], trans[:24]) + else: + return '{:24s}{:13s}{:25s}{:s}\n'.format(name[:23], test.protocol[:12], trans[:24], self._result_string(test)) + + def _print_test_header(self): + self._print_bar() + print( + '{:24s}{:13s}{:25s}{:s}'.format('server-client:', 'protocol:', 'transport:', 'result:'), + file=self.out) + + def _print_header(self): + self._start() + print('Apache Thrift - Integration Test Suite', file=self.out) + self._print_date() + self._print_test_header() + + def _print_unexpected_failure(self): + if len(self._unexpected_failure) > 0: + self.out.writelines([ + '*** Following %d failures were unexpected ***:\n' % len(self._unexpected_failure), + 'If it is introduced by you, please fix it before submitting the code.\n', + # 'If not, please report at https://issues.apache.org/jira/browse/THRIFT\n', + ]) + self._print_test_header() + for i in self._unexpected_failure: + self.out.write(self._format_test(self._tests[i])) + self._print_bar() + else: + print('No unexpected failures.', file=self.out) + + def _print_flaky_success(self): + if len(self._flaky_success) > 0: + print( + 'Following %d tests were expected to cleanly succeed but needed retry:' % len(self._flaky_success), + file=self.out) + self._print_test_header() + for i in self._flaky_success: + self.out.write(self._format_test(self._tests[i])) + self._print_bar() + + def _print_unexpected_success(self): + if len(self._unexpected_success) > 0: + print( + 'Following %d tests were known to fail but succeeded (maybe flaky):' % len(self._unexpected_success), + file=self.out) + self._print_test_header() + for i in self._unexpected_success: + self.out.write(self._format_test(self._tests[i])) + self._print_bar() + + def _http_server_command(self, port): + if sys.version_info[0] < 3: + return 'python -m SimpleHTTPServer %d' % port else: - self._log.info('unexpected failure: %s' % test.name) - self._unexpected_failure.append(index) - elif flaky and not known: - self._log.info('unexpected flaky success: %s' % test.name) - self._flaky_success.append(index) - elif not flaky and known: - self._log.info('unexpected success: %s' % test.name) - self._unexpected_success.append(index) - test.success = not failed - test.returncode = returncode - test.retry_count = retry_count - test.expired = expired - test.as_expected = known == failed - if not self.concurrent: - self.out.write(self._result_string(test) + '\n') - else: - self.out.write(self._format_test(test)) - finally: - self._lock.release() + return 'python -m http.server %d' % port + + def _print_footer(self): + fail_count = len(self._expected_failure) + len(self._unexpected_failure) + self._print_bar() + self._print_unexpected_success() + self._print_flaky_success() + self._print_unexpected_failure() + self._write_html_data() + self._assemble_log('unexpected failures', self._unexpected_failure) + self._assemble_log('known failures', self._expected_failure) + self.out.writelines([ + 'You can browse results at:\n', + '\tfile://%s/%s\n' % (self.testdir, RESULT_HTML), + '# If you use Chrome, run:\n', + '# \tcd %s\n#\t%s\n' % (self._basedir, self._http_server_command(8001)), + '# then browse:\n', + '# \thttp://localhost:%d/%s/\n' % (8001, self._testdir_rel), + 'Full log for each test is here:\n', + '\ttest/log/client_server_protocol_transport_client.log\n', + '\ttest/log/client_server_protocol_transport_server.log\n', + '%d failed of %d tests in total.\n' % (fail_count, len(self._tests)), + ]) + self._print_exec_time() + self._print_date() + + def _render_result(self, test): + return [ + test.server.name, + test.client.name, + test.protocol, + test.transport, + test.socket, + test.success, + test.as_expected, + test.returncode, + { + 'server': self.test_logfile(test.name, test.server.kind), + 'client': self.test_logfile(test.name, test.client.kind), + }, + ] + + def _write_html_data(self): + """Writes JSON data to be read by result html""" + results = [self._render_result(r) for r in self._tests] + with logfile_open(self.out_path, 'w+') as fp: + fp.write(json.dumps({ + 'date': self._format_date(), + 'revision': str(self._revision), + 'platform': self._platform, + 'duration': '{:.1f}'.format(self._elapsed), + 'results': results, + }, indent=2)) + + def _assemble_log(self, title, indexes): + if len(indexes) > 0: + def add_prog_log(fp, test, prog_kind): + print('*************************** %s message ***************************' % prog_kind, + file=fp) + path = self.test_logfile(test.name, prog_kind, self.testdir) + if os.path.exists(path): + with logfile_open(path, 'r') as prog_fp: + print(prog_fp.read(), file=fp) + filename = title.replace(' ', '_') + '.log' + with logfile_open(os.path.join(self.logdir, filename), 'w+') as fp: + for test in map(self._tests.__getitem__, indexes): + fp.write('TEST: [%s]\n' % test.name) + add_prog_log(fp, test, test.server.kind) + add_prog_log(fp, test, test.client.kind) + fp.write('**********************************************************************\n\n') + print('%s are logged to %s/%s/%s' % (title.capitalize(), self._testdir_rel, LOG_DIR, filename)) + + def end(self): + self._print_footer() + return len(self._unexpected_failure) == 0 + + def add_test(self, test_dict): + test = TestEntry(self.testdir, **test_dict) + self._lock.acquire() + try: + if not self.concurrent: + self.out.write(self._format_test(test, False)) + self.out.flush() + self._tests.append(test) + return len(self._tests) - 1 + finally: + self._lock.release() + + def add_result(self, index, returncode, expired, retry_count): + self._lock.acquire() + try: + failed = returncode is None or returncode != 0 + flaky = not failed and retry_count != 0 + test = self._tests[index] + known = test.name in self._known_failures + if failed: + if known: + self._log.debug('%s failed as expected' % test.name) + self._expected_failure.append(index) + else: + self._log.info('unexpected failure: %s' % test.name) + self._unexpected_failure.append(index) + elif flaky and not known: + self._log.info('unexpected flaky success: %s' % test.name) + self._flaky_success.append(index) + elif not flaky and known: + self._log.info('unexpected success: %s' % test.name) + self._unexpected_success.append(index) + test.success = not failed + test.returncode = returncode + test.retry_count = retry_count + test.expired = expired + test.as_expected = known == failed + if not self.concurrent: + self.out.write(self._result_string(test) + '\n') + else: + self.out.write(self._format_test(test)) + finally: + self._lock.release() diff --git a/test/crossrunner/run.py b/test/crossrunner/run.py index 68bd92869..18c162357 100644 --- a/test/crossrunner/run.py +++ b/test/crossrunner/run.py @@ -39,307 +39,307 @@ RESULT_ERROR = 64 class ExecutionContext(object): - def __init__(self, cmd, cwd, env, report): - self._log = multiprocessing.get_logger() - self.report = report - self.cmd = cmd - self.cwd = cwd - self.env = env - self.timer = None - self.expired = False - self.killed = False - - def _expire(self): - self._log.info('Timeout') - self.expired = True - self.kill() - - def kill(self): - self._log.debug('Killing process : %d' % self.proc.pid) - self.killed = True - if platform.system() != 'Windows': - try: - os.killpg(self.proc.pid, signal.SIGKILL) - except Exception: - self._log.info('Failed to kill process group', exc_info=sys.exc_info()) - try: - self.proc.kill() - except Exception: - self._log.info('Failed to kill process', exc_info=sys.exc_info()) - - def _popen_args(self): - args = { - 'cwd': self.cwd, - 'env': self.env, - 'stdout': self.report.out, - 'stderr': subprocess.STDOUT, - } - # make sure child processes doesn't remain after killing - if platform.system() == 'Windows': - DETACHED_PROCESS = 0x00000008 - args.update(creationflags=DETACHED_PROCESS | subprocess.CREATE_NEW_PROCESS_GROUP) - else: - args.update(preexec_fn=os.setsid) - return args - - def start(self, timeout=0): - joined = str_join(' ', self.cmd) - self._log.debug('COMMAND: %s', joined) - self._log.debug('WORKDIR: %s', self.cwd) - self._log.debug('LOGFILE: %s', self.report.logpath) - self.report.begin() - self.proc = subprocess.Popen(self.cmd, **self._popen_args()) - if timeout > 0: - self.timer = threading.Timer(timeout, self._expire) - self.timer.start() - return self._scoped() - - @contextlib.contextmanager - def _scoped(self): - yield self - self._log.debug('Killing scoped process') - if self.proc.poll() is None: - self.kill() - self.report.killed() - else: - self._log.debug('Process died unexpectedly') - self.report.died() - - def wait(self): - self.proc.communicate() - if self.timer: - self.timer.cancel() - self.report.end(self.returncode) - - @property - def returncode(self): - return self.proc.returncode if self.proc else None + def __init__(self, cmd, cwd, env, report): + self._log = multiprocessing.get_logger() + self.report = report + self.cmd = cmd + self.cwd = cwd + self.env = env + self.timer = None + self.expired = False + self.killed = False + + def _expire(self): + self._log.info('Timeout') + self.expired = True + self.kill() + + def kill(self): + self._log.debug('Killing process : %d' % self.proc.pid) + self.killed = True + if platform.system() != 'Windows': + try: + os.killpg(self.proc.pid, signal.SIGKILL) + except Exception: + self._log.info('Failed to kill process group', exc_info=sys.exc_info()) + try: + self.proc.kill() + except Exception: + self._log.info('Failed to kill process', exc_info=sys.exc_info()) + + def _popen_args(self): + args = { + 'cwd': self.cwd, + 'env': self.env, + 'stdout': self.report.out, + 'stderr': subprocess.STDOUT, + } + # make sure child processes doesn't remain after killing + if platform.system() == 'Windows': + DETACHED_PROCESS = 0x00000008 + args.update(creationflags=DETACHED_PROCESS | subprocess.CREATE_NEW_PROCESS_GROUP) + else: + args.update(preexec_fn=os.setsid) + return args + + def start(self, timeout=0): + joined = str_join(' ', self.cmd) + self._log.debug('COMMAND: %s', joined) + self._log.debug('WORKDIR: %s', self.cwd) + self._log.debug('LOGFILE: %s', self.report.logpath) + self.report.begin() + self.proc = subprocess.Popen(self.cmd, **self._popen_args()) + if timeout > 0: + self.timer = threading.Timer(timeout, self._expire) + self.timer.start() + return self._scoped() + + @contextlib.contextmanager + def _scoped(self): + yield self + self._log.debug('Killing scoped process') + if self.proc.poll() is None: + self.kill() + self.report.killed() + else: + self._log.debug('Process died unexpectedly') + self.report.died() + + def wait(self): + self.proc.communicate() + if self.timer: + self.timer.cancel() + self.report.end(self.returncode) + + @property + def returncode(self): + return self.proc.returncode if self.proc else None def exec_context(port, logdir, test, prog): - report = ExecReporter(logdir, test, prog) - prog.build_command(port) - return ExecutionContext(prog.command, prog.workdir, prog.env, report) + report = ExecReporter(logdir, test, prog) + prog.build_command(port) + return ExecutionContext(prog.command, prog.workdir, prog.env, report) def run_test(testdir, logdir, test_dict, max_retry, async=True): - try: - logger = multiprocessing.get_logger() - max_bind_retry = 3 - retry_count = 0 - bind_retry_count = 0 - test = TestEntry(testdir, **test_dict) - while True: - if stop.is_set(): - logger.debug('Skipping because shutting down') - return (retry_count, None) - logger.debug('Start') - with PortAllocator.alloc_port_scoped(ports, test.socket) as port: - logger.debug('Start with port %d' % port) - sv = exec_context(port, logdir, test, test.server) - cl = exec_context(port, logdir, test, test.client) - - logger.debug('Starting server') - with sv.start(): - if test.delay > 0: - logger.debug('Delaying client for %.2f seconds' % test.delay) - time.sleep(test.delay) - connect_retry_count = 0 - max_connect_retry = 10 - connect_retry_wait = 0.5 - while True: - logger.debug('Starting client') - cl.start(test.timeout) - logger.debug('Waiting client') - cl.wait() - if not cl.report.maybe_false_positive() or connect_retry_count >= max_connect_retry: - if connect_retry_count > 0 and connect_retry_count < max_connect_retry: - logger.warn('[%s]: Connected after %d retry (%.2f sec each)' % (test.server.name, connect_retry_count, connect_retry_wait)) - # Wait for 50ms to see if server does not die at the end. - time.sleep(0.05) - break - logger.debug('Server may not be ready, waiting %.2f second...' % connect_retry_wait) - time.sleep(connect_retry_wait) - connect_retry_count += 1 - - if sv.report.maybe_false_positive() and bind_retry_count < max_bind_retry: - logger.warn('[%s]: Detected socket bind failure, retrying...', test.server.name) - bind_retry_count += 1 - else: - if cl.expired: - result = RESULT_TIMEOUT - elif not sv.killed and cl.proc.returncode == 0: - # Server should be alive at the end. - result = RESULT_ERROR - else: - result = cl.proc.returncode - - if result == 0 or retry_count >= max_retry: - return (retry_count, result) - else: - logger.info('[%s-%s]: test failed, retrying...', test.server.name, test.client.name) - retry_count += 1 - except (KeyboardInterrupt, SystemExit): - logger.info('Interrupted execution') - if not async: - raise - stop.set() - return None - except: - if not async: - raise - logger.warn('Error executing [%s]', test.name, exc_info=sys.exc_info()) - return (retry_count, RESULT_ERROR) + try: + logger = multiprocessing.get_logger() + max_bind_retry = 3 + retry_count = 0 + bind_retry_count = 0 + test = TestEntry(testdir, **test_dict) + while True: + if stop.is_set(): + logger.debug('Skipping because shutting down') + return (retry_count, None) + logger.debug('Start') + with PortAllocator.alloc_port_scoped(ports, test.socket) as port: + logger.debug('Start with port %d' % port) + sv = exec_context(port, logdir, test, test.server) + cl = exec_context(port, logdir, test, test.client) + + logger.debug('Starting server') + with sv.start(): + if test.delay > 0: + logger.debug('Delaying client for %.2f seconds' % test.delay) + time.sleep(test.delay) + connect_retry_count = 0 + max_connect_retry = 10 + connect_retry_wait = 0.5 + while True: + logger.debug('Starting client') + cl.start(test.timeout) + logger.debug('Waiting client') + cl.wait() + if not cl.report.maybe_false_positive() or connect_retry_count >= max_connect_retry: + if connect_retry_count > 0 and connect_retry_count < max_connect_retry: + logger.warn('[%s]: Connected after %d retry (%.2f sec each)' % (test.server.name, connect_retry_count, connect_retry_wait)) + # Wait for 50ms to see if server does not die at the end. + time.sleep(0.05) + break + logger.debug('Server may not be ready, waiting %.2f second...' % connect_retry_wait) + time.sleep(connect_retry_wait) + connect_retry_count += 1 + + if sv.report.maybe_false_positive() and bind_retry_count < max_bind_retry: + logger.warn('[%s]: Detected socket bind failure, retrying...', test.server.name) + bind_retry_count += 1 + else: + if cl.expired: + result = RESULT_TIMEOUT + elif not sv.killed and cl.proc.returncode == 0: + # Server should be alive at the end. + result = RESULT_ERROR + else: + result = cl.proc.returncode + + if result == 0 or retry_count >= max_retry: + return (retry_count, result) + else: + logger.info('[%s-%s]: test failed, retrying...', test.server.name, test.client.name) + retry_count += 1 + except (KeyboardInterrupt, SystemExit): + logger.info('Interrupted execution') + if not async: + raise + stop.set() + return None + except: + if not async: + raise + logger.warn('Error executing [%s]', test.name, exc_info=sys.exc_info()) + return (retry_count, RESULT_ERROR) class PortAllocator(object): - def __init__(self): - self._log = multiprocessing.get_logger() - self._lock = multiprocessing.Lock() - self._ports = set() - self._dom_ports = set() - self._last_alloc = 0 - - def _get_tcp_port(self): - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind(('127.0.0.1', 0)) - port = sock.getsockname()[1] - self._lock.acquire() - try: - ok = port not in self._ports - if ok: - self._ports.add(port) - self._last_alloc = time.time() - finally: - self._lock.release() - sock.close() - return port if ok else self._get_tcp_port() - - def _get_domain_port(self): - port = random.randint(1024, 65536) - self._lock.acquire() - try: - ok = port not in self._dom_ports - if ok: - self._dom_ports.add(port) - finally: - self._lock.release() - return port if ok else self._get_domain_port() - - def alloc_port(self, socket_type): - if socket_type in ('domain', 'abstract'): - return self._get_domain_port() - else: - return self._get_tcp_port() - - # static method for inter-process invokation - @staticmethod - @contextlib.contextmanager - def alloc_port_scoped(allocator, socket_type): - port = allocator.alloc_port(socket_type) - yield port - allocator.free_port(socket_type, port) - - def free_port(self, socket_type, port): - self._log.debug('free_port') - self._lock.acquire() - try: - if socket_type == 'domain': - self._dom_ports.remove(port) - path = domain_socket_path(port) - if os.path.exists(path): - os.remove(path) - elif socket_type == 'abstract': - self._dom_ports.remove(port) - else: - self._ports.remove(port) - except IOError: - self._log.info('Error while freeing port', exc_info=sys.exc_info()) - finally: - self._lock.release() + def __init__(self): + self._log = multiprocessing.get_logger() + self._lock = multiprocessing.Lock() + self._ports = set() + self._dom_ports = set() + self._last_alloc = 0 + + def _get_tcp_port(self): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(('127.0.0.1', 0)) + port = sock.getsockname()[1] + self._lock.acquire() + try: + ok = port not in self._ports + if ok: + self._ports.add(port) + self._last_alloc = time.time() + finally: + self._lock.release() + sock.close() + return port if ok else self._get_tcp_port() + + def _get_domain_port(self): + port = random.randint(1024, 65536) + self._lock.acquire() + try: + ok = port not in self._dom_ports + if ok: + self._dom_ports.add(port) + finally: + self._lock.release() + return port if ok else self._get_domain_port() + + def alloc_port(self, socket_type): + if socket_type in ('domain', 'abstract'): + return self._get_domain_port() + else: + return self._get_tcp_port() + + # static method for inter-process invokation + @staticmethod + @contextlib.contextmanager + def alloc_port_scoped(allocator, socket_type): + port = allocator.alloc_port(socket_type) + yield port + allocator.free_port(socket_type, port) + + def free_port(self, socket_type, port): + self._log.debug('free_port') + self._lock.acquire() + try: + if socket_type == 'domain': + self._dom_ports.remove(port) + path = domain_socket_path(port) + if os.path.exists(path): + os.remove(path) + elif socket_type == 'abstract': + self._dom_ports.remove(port) + else: + self._ports.remove(port) + except IOError: + self._log.info('Error while freeing port', exc_info=sys.exc_info()) + finally: + self._lock.release() class NonAsyncResult(object): - def __init__(self, value): - self._value = value + def __init__(self, value): + self._value = value - def get(self, timeout=None): - return self._value + def get(self, timeout=None): + return self._value - def wait(self, timeout=None): - pass + def wait(self, timeout=None): + pass - def ready(self): - return True + def ready(self): + return True - def successful(self): - return self._value == 0 + def successful(self): + return self._value == 0 class TestDispatcher(object): - def __init__(self, testdir, basedir, logdir_rel, concurrency): - self._log = multiprocessing.get_logger() - self.testdir = testdir - self._report = SummaryReporter(basedir, logdir_rel, concurrency > 1) - self.logdir = self._report.testdir - # seems needed for python 2.x to handle keyboard interrupt - self._stop = multiprocessing.Event() - self._async = concurrency > 1 - if not self._async: - self._pool = None - global stop - global ports - stop = self._stop - ports = PortAllocator() - else: - self._m = multiprocessing.managers.BaseManager() - self._m.register('ports', PortAllocator) - self._m.start() - self._pool = multiprocessing.Pool(concurrency, self._pool_init, (self._m.address,)) - self._log.debug( - 'TestDispatcher started with %d concurrent jobs' % concurrency) - - def _pool_init(self, address): - global stop - global m - global ports - stop = self._stop - m = multiprocessing.managers.BaseManager(address) - m.connect() - ports = m.ports() - - def _dispatch_sync(self, test, cont, max_retry): - r = run_test(self.testdir, self.logdir, test, max_retry, False) - cont(r) - return NonAsyncResult(r) - - def _dispatch_async(self, test, cont, max_retry): - self._log.debug('_dispatch_async') - return self._pool.apply_async(func=run_test, args=(self.testdir, self.logdir, test, max_retry), callback=cont) - - def dispatch(self, test, max_retry): - index = self._report.add_test(test) - - def cont(result): - if not self._stop.is_set(): - retry_count, returncode = result - self._log.debug('freeing port') - self._log.debug('adding result') - self._report.add_result(index, returncode, returncode == RESULT_TIMEOUT, retry_count) - self._log.debug('finish continuation') - fn = self._dispatch_async if self._async else self._dispatch_sync - return fn(test, cont, max_retry) - - def wait(self): - if self._async: - self._pool.close() - self._pool.join() - self._m.shutdown() - return self._report.end() - - def terminate(self): - self._stop.set() - if self._async: - self._pool.terminate() - self._pool.join() - self._m.shutdown() + def __init__(self, testdir, basedir, logdir_rel, concurrency): + self._log = multiprocessing.get_logger() + self.testdir = testdir + self._report = SummaryReporter(basedir, logdir_rel, concurrency > 1) + self.logdir = self._report.testdir + # seems needed for python 2.x to handle keyboard interrupt + self._stop = multiprocessing.Event() + self._async = concurrency > 1 + if not self._async: + self._pool = None + global stop + global ports + stop = self._stop + ports = PortAllocator() + else: + self._m = multiprocessing.managers.BaseManager() + self._m.register('ports', PortAllocator) + self._m.start() + self._pool = multiprocessing.Pool(concurrency, self._pool_init, (self._m.address,)) + self._log.debug( + 'TestDispatcher started with %d concurrent jobs' % concurrency) + + def _pool_init(self, address): + global stop + global m + global ports + stop = self._stop + m = multiprocessing.managers.BaseManager(address) + m.connect() + ports = m.ports() + + def _dispatch_sync(self, test, cont, max_retry): + r = run_test(self.testdir, self.logdir, test, max_retry, False) + cont(r) + return NonAsyncResult(r) + + def _dispatch_async(self, test, cont, max_retry): + self._log.debug('_dispatch_async') + return self._pool.apply_async(func=run_test, args=(self.testdir, self.logdir, test, max_retry), callback=cont) + + def dispatch(self, test, max_retry): + index = self._report.add_test(test) + + def cont(result): + if not self._stop.is_set(): + retry_count, returncode = result + self._log.debug('freeing port') + self._log.debug('adding result') + self._report.add_result(index, returncode, returncode == RESULT_TIMEOUT, retry_count) + self._log.debug('finish continuation') + fn = self._dispatch_async if self._async else self._dispatch_sync + return fn(test, cont, max_retry) + + def wait(self): + if self._async: + self._pool.close() + self._pool.join() + self._m.shutdown() + return self._report.end() + + def terminate(self): + self._stop.set() + if self._async: + self._pool.terminate() + self._pool.join() + self._m.shutdown() diff --git a/test/crossrunner/test.py b/test/crossrunner/test.py index fc90f7f30..dcc8a9416 100644 --- a/test/crossrunner/test.py +++ b/test/crossrunner/test.py @@ -26,118 +26,118 @@ from .util import merge_dict def domain_socket_path(port): - return '/tmp/ThriftTest.thrift.%d' % port + return '/tmp/ThriftTest.thrift.%d' % port class TestProgram(object): - def __init__(self, kind, name, protocol, transport, socket, workdir, command, env=None, - extra_args=[], extra_args2=[], join_args=False, **kwargs): - self.kind = kind - self.name = name - self.protocol = protocol - self.transport = transport - self.socket = socket - self.workdir = workdir - self.command = None - self._base_command = self._fix_cmd_path(command) - if env: - self.env = copy.copy(os.environ) - self.env.update(env) - else: - self.env = os.environ - self._extra_args = extra_args - self._extra_args2 = extra_args2 - self._join_args = join_args - - def _fix_cmd_path(self, cmd): - # if the arg is a file in the current directory, make it path - def abs_if_exists(arg): - p = path_join(self.workdir, arg) - return p if os.path.exists(p) else arg - - if cmd[0] == 'python': - cmd[0] = sys.executable - else: - cmd[0] = abs_if_exists(cmd[0]) - return cmd - - def _socket_args(self, socket, port): - return { - 'ip-ssl': ['--ssl'], - 'domain': ['--domain-socket=%s' % domain_socket_path(port)], - 'abstract': ['--abstract-namespace', '--domain-socket=%s' % domain_socket_path(port)], - }.get(socket, None) - - def build_command(self, port): - cmd = copy.copy(self._base_command) - args = copy.copy(self._extra_args2) - args.append('--protocol=' + self.protocol) - args.append('--transport=' + self.transport) - socket_args = self._socket_args(self.socket, port) - if socket_args: - args += socket_args - args.append('--port=%d' % port) - if self._join_args: - cmd.append('%s' % " ".join(args)) - else: - cmd.extend(args) - if self._extra_args: - cmd.extend(self._extra_args) - self.command = cmd - return self.command + def __init__(self, kind, name, protocol, transport, socket, workdir, command, env=None, + extra_args=[], extra_args2=[], join_args=False, **kwargs): + self.kind = kind + self.name = name + self.protocol = protocol + self.transport = transport + self.socket = socket + self.workdir = workdir + self.command = None + self._base_command = self._fix_cmd_path(command) + if env: + self.env = copy.copy(os.environ) + self.env.update(env) + else: + self.env = os.environ + self._extra_args = extra_args + self._extra_args2 = extra_args2 + self._join_args = join_args + + def _fix_cmd_path(self, cmd): + # if the arg is a file in the current directory, make it path + def abs_if_exists(arg): + p = path_join(self.workdir, arg) + return p if os.path.exists(p) else arg + + if cmd[0] == 'python': + cmd[0] = sys.executable + else: + cmd[0] = abs_if_exists(cmd[0]) + return cmd + + def _socket_args(self, socket, port): + return { + 'ip-ssl': ['--ssl'], + 'domain': ['--domain-socket=%s' % domain_socket_path(port)], + 'abstract': ['--abstract-namespace', '--domain-socket=%s' % domain_socket_path(port)], + }.get(socket, None) + + def build_command(self, port): + cmd = copy.copy(self._base_command) + args = copy.copy(self._extra_args2) + args.append('--protocol=' + self.protocol) + args.append('--transport=' + self.transport) + socket_args = self._socket_args(self.socket, port) + if socket_args: + args += socket_args + args.append('--port=%d' % port) + if self._join_args: + cmd.append('%s' % " ".join(args)) + else: + cmd.extend(args) + if self._extra_args: + cmd.extend(self._extra_args) + self.command = cmd + return self.command class TestEntry(object): - def __init__(self, testdir, server, client, delay, timeout, **kwargs): - self.testdir = testdir - self._log = multiprocessing.get_logger() - self._config = kwargs - self.protocol = kwargs['protocol'] - self.transport = kwargs['transport'] - self.socket = kwargs['socket'] - srv_dict = self._fix_workdir(merge_dict(self._config, server)) - cli_dict = self._fix_workdir(merge_dict(self._config, client)) - cli_dict['extra_args2'] = srv_dict.pop('remote_args', []) - srv_dict['extra_args2'] = cli_dict.pop('remote_args', []) - self.server = TestProgram('server', **srv_dict) - self.client = TestProgram('client', **cli_dict) - self.delay = delay - self.timeout = timeout - self._name = None - # results - self.success = None - self.as_expected = None - self.returncode = None - self.expired = False - self.retry_count = 0 - - def _fix_workdir(self, config): - key = 'workdir' - path = config.get(key, None) - if not path: - path = self.testdir - if os.path.isabs(path): - path = os.path.realpath(path) - else: - path = os.path.realpath(path_join(self.testdir, path)) - config.update({key: path}) - return config - - @classmethod - def get_name(cls, server, client, proto, trans, sock, *args): - return '%s-%s_%s_%s-%s' % (server, client, proto, trans, sock) - - @property - def name(self): - if not self._name: - self._name = self.get_name( - self.server.name, self.client.name, self.protocol, self.transport, self.socket) - return self._name - - @property - def transport_name(self): - return '%s-%s' % (self.transport, self.socket) + def __init__(self, testdir, server, client, delay, timeout, **kwargs): + self.testdir = testdir + self._log = multiprocessing.get_logger() + self._config = kwargs + self.protocol = kwargs['protocol'] + self.transport = kwargs['transport'] + self.socket = kwargs['socket'] + srv_dict = self._fix_workdir(merge_dict(self._config, server)) + cli_dict = self._fix_workdir(merge_dict(self._config, client)) + cli_dict['extra_args2'] = srv_dict.pop('remote_args', []) + srv_dict['extra_args2'] = cli_dict.pop('remote_args', []) + self.server = TestProgram('server', **srv_dict) + self.client = TestProgram('client', **cli_dict) + self.delay = delay + self.timeout = timeout + self._name = None + # results + self.success = None + self.as_expected = None + self.returncode = None + self.expired = False + self.retry_count = 0 + + def _fix_workdir(self, config): + key = 'workdir' + path = config.get(key, None) + if not path: + path = self.testdir + if os.path.isabs(path): + path = os.path.realpath(path) + else: + path = os.path.realpath(path_join(self.testdir, path)) + config.update({key: path}) + return config + + @classmethod + def get_name(cls, server, client, proto, trans, sock, *args): + return '%s-%s_%s_%s-%s' % (server, client, proto, trans, sock) + + @property + def name(self): + if not self._name: + self._name = self.get_name( + self.server.name, self.client.name, self.protocol, self.transport, self.socket) + return self._name + + @property + def transport_name(self): + return '%s-%s' % (self.transport, self.socket) def test_name(server, client, protocol, transport, socket, **kwargs): - return TestEntry.get_name(server['name'], client['name'], protocol, transport, socket) + return TestEntry.get_name(server['name'], client['name'], protocol, transport, socket) diff --git a/test/crossrunner/util.py b/test/crossrunner/util.py index 750ed475e..e2d195a22 100644 --- a/test/crossrunner/util.py +++ b/test/crossrunner/util.py @@ -21,11 +21,11 @@ import copy def merge_dict(base, update): - """Update dict concatenating list values""" - res = copy.deepcopy(base) - for k, v in list(update.items()): - if k in list(res.keys()) and isinstance(v, list): - res[k].extend(v) - else: - res[k] = v - return res + """Update dict concatenating list values""" + res = copy.deepcopy(base) + for k, v in list(update.items()): + if k in list(res.keys()) and isinstance(v, list): + res[k].extend(v) + else: + res[k] = v + return res diff --git a/test/features/container_limit.py b/test/features/container_limit.py index 4a7da6065..beed0c5ec 100644 --- a/test/features/container_limit.py +++ b/test/features/container_limit.py @@ -10,63 +10,63 @@ from thrift.Thrift import TMessageType, TType # TODO: generate from ThriftTest.thrift def test_list(proto, value): - method_name = 'testList' - ttype = TType.LIST - etype = TType.I32 - proto.writeMessageBegin(method_name, TMessageType.CALL, 3) - proto.writeStructBegin(method_name + '_args') - proto.writeFieldBegin('thing', ttype, 1) - proto.writeListBegin(etype, len(value)) - for e in value: - proto.writeI32(e) - proto.writeListEnd() - proto.writeFieldEnd() - proto.writeFieldStop() - proto.writeStructEnd() - proto.writeMessageEnd() - proto.trans.flush() + method_name = 'testList' + ttype = TType.LIST + etype = TType.I32 + proto.writeMessageBegin(method_name, TMessageType.CALL, 3) + proto.writeStructBegin(method_name + '_args') + proto.writeFieldBegin('thing', ttype, 1) + proto.writeListBegin(etype, len(value)) + for e in value: + proto.writeI32(e) + proto.writeListEnd() + proto.writeFieldEnd() + proto.writeFieldStop() + proto.writeStructEnd() + proto.writeMessageEnd() + proto.trans.flush() - _, mtype, _ = proto.readMessageBegin() - assert mtype == TMessageType.REPLY - proto.readStructBegin() - _, ftype, fid = proto.readFieldBegin() - assert fid == 0 - assert ftype == ttype - etype2, len2 = proto.readListBegin() - assert etype == etype2 - assert len2 == len(value) - for i in range(len2): - v = proto.readI32() - assert v == value[i] - proto.readListEnd() - proto.readFieldEnd() - _, ftype, _ = proto.readFieldBegin() - assert ftype == TType.STOP - proto.readStructEnd() - proto.readMessageEnd() + _, mtype, _ = proto.readMessageBegin() + assert mtype == TMessageType.REPLY + proto.readStructBegin() + _, ftype, fid = proto.readFieldBegin() + assert fid == 0 + assert ftype == ttype + etype2, len2 = proto.readListBegin() + assert etype == etype2 + assert len2 == len(value) + for i in range(len2): + v = proto.readI32() + assert v == value[i] + proto.readListEnd() + proto.readFieldEnd() + _, ftype, _ = proto.readFieldBegin() + assert ftype == TType.STOP + proto.readStructEnd() + proto.readMessageEnd() def main(argv): - p = argparse.ArgumentParser() - add_common_args(p) - p.add_argument('--limit', type=int) - args = p.parse_args() - proto = init_protocol(args) - # TODO: test set and map - test_list(proto, list(range(args.limit - 1))) - test_list(proto, list(range(args.limit - 1))) - print('[OK]: limit - 1') - test_list(proto, list(range(args.limit))) - test_list(proto, list(range(args.limit))) - print('[OK]: just limit') - try: - test_list(proto, list(range(args.limit + 1))) - except: - print('[OK]: limit + 1') - else: - print('[ERROR]: limit + 1') - assert False + p = argparse.ArgumentParser() + add_common_args(p) + p.add_argument('--limit', type=int) + args = p.parse_args() + proto = init_protocol(args) + # TODO: test set and map + test_list(proto, list(range(args.limit - 1))) + test_list(proto, list(range(args.limit - 1))) + print('[OK]: limit - 1') + test_list(proto, list(range(args.limit))) + test_list(proto, list(range(args.limit))) + print('[OK]: just limit') + try: + test_list(proto, list(range(args.limit + 1))) + except: + print('[OK]: limit + 1') + else: + print('[ERROR]: limit + 1') + assert False if __name__ == '__main__': - sys.exit(main(sys.argv[1:])) + sys.exit(main(sys.argv[1:])) diff --git a/test/features/local_thrift/__init__.py b/test/features/local_thrift/__init__.py index 383ee5f40..0a0bb0b66 100644 --- a/test/features/local_thrift/__init__.py +++ b/test/features/local_thrift/__init__.py @@ -5,10 +5,10 @@ SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__)) ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR))) if sys.version_info[0] == 2: - import glob - libdir = glob.glob(os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*'))[0] - sys.path.insert(0, libdir) - thrift = __import__('thrift') + import glob + libdir = glob.glob(os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*'))[0] + sys.path.insert(0, libdir) + thrift = __import__('thrift') else: - sys.path.insert(0, os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib')) - thrift = __import__('thrift') + sys.path.insert(0, os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib')) + thrift = __import__('thrift') diff --git a/test/features/string_limit.py b/test/features/string_limit.py index b4d48acdb..3c68b3ea3 100644 --- a/test/features/string_limit.py +++ b/test/features/string_limit.py @@ -10,52 +10,52 @@ from thrift.Thrift import TMessageType, TType # TODO: generate from ThriftTest.thrift def test_string(proto, value): - method_name = 'testString' - ttype = TType.STRING - proto.writeMessageBegin(method_name, TMessageType.CALL, 3) - proto.writeStructBegin(method_name + '_args') - proto.writeFieldBegin('thing', ttype, 1) - proto.writeString(value) - proto.writeFieldEnd() - proto.writeFieldStop() - proto.writeStructEnd() - proto.writeMessageEnd() - proto.trans.flush() - - _, mtype, _ = proto.readMessageBegin() - assert mtype == TMessageType.REPLY - proto.readStructBegin() - _, ftype, fid = proto.readFieldBegin() - assert fid == 0 - assert ftype == ttype - result = proto.readString() - proto.readFieldEnd() - _, ftype, _ = proto.readFieldBegin() - assert ftype == TType.STOP - proto.readStructEnd() - proto.readMessageEnd() - assert value == result + method_name = 'testString' + ttype = TType.STRING + proto.writeMessageBegin(method_name, TMessageType.CALL, 3) + proto.writeStructBegin(method_name + '_args') + proto.writeFieldBegin('thing', ttype, 1) + proto.writeString(value) + proto.writeFieldEnd() + proto.writeFieldStop() + proto.writeStructEnd() + proto.writeMessageEnd() + proto.trans.flush() + + _, mtype, _ = proto.readMessageBegin() + assert mtype == TMessageType.REPLY + proto.readStructBegin() + _, ftype, fid = proto.readFieldBegin() + assert fid == 0 + assert ftype == ttype + result = proto.readString() + proto.readFieldEnd() + _, ftype, _ = proto.readFieldBegin() + assert ftype == TType.STOP + proto.readStructEnd() + proto.readMessageEnd() + assert value == result def main(argv): - p = argparse.ArgumentParser() - add_common_args(p) - p.add_argument('--limit', type=int) - args = p.parse_args() - proto = init_protocol(args) - test_string(proto, 'a' * (args.limit - 1)) - test_string(proto, 'a' * (args.limit - 1)) - print('[OK]: limit - 1') - test_string(proto, 'a' * args.limit) - test_string(proto, 'a' * args.limit) - print('[OK]: just limit') - try: - test_string(proto, 'a' * (args.limit + 1)) - except: - print('[OK]: limit + 1') - else: - print('[ERROR]: limit + 1') - assert False + p = argparse.ArgumentParser() + add_common_args(p) + p.add_argument('--limit', type=int) + args = p.parse_args() + proto = init_protocol(args) + test_string(proto, 'a' * (args.limit - 1)) + test_string(proto, 'a' * (args.limit - 1)) + print('[OK]: limit - 1') + test_string(proto, 'a' * args.limit) + test_string(proto, 'a' * args.limit) + print('[OK]: just limit') + try: + test_string(proto, 'a' * (args.limit + 1)) + except: + print('[OK]: limit + 1') + else: + print('[ERROR]: limit + 1') + assert False if __name__ == '__main__': - main(sys.argv[1:]) + main(sys.argv[1:]) diff --git a/test/features/theader_binary.py b/test/features/theader_binary.py index 62a26715d..02e010b8b 100644 --- a/test/features/theader_binary.py +++ b/test/features/theader_binary.py @@ -14,57 +14,57 @@ from thrift.protocol.TCompactProtocol import TCompactProtocol def test_void(proto): - proto.writeMessageBegin('testVoid', TMessageType.CALL, 3) - proto.writeStructBegin('testVoid_args') - proto.writeFieldStop() - proto.writeStructEnd() - proto.writeMessageEnd() - proto.trans.flush() + proto.writeMessageBegin('testVoid', TMessageType.CALL, 3) + proto.writeStructBegin('testVoid_args') + proto.writeFieldStop() + proto.writeStructEnd() + proto.writeMessageEnd() + proto.trans.flush() - _, mtype, _ = proto.readMessageBegin() - assert mtype == TMessageType.REPLY - proto.readStructBegin() - _, ftype, _ = proto.readFieldBegin() - assert ftype == TType.STOP - proto.readStructEnd() - proto.readMessageEnd() + _, mtype, _ = proto.readMessageBegin() + assert mtype == TMessageType.REPLY + proto.readStructBegin() + _, ftype, _ = proto.readFieldBegin() + assert ftype == TType.STOP + proto.readStructEnd() + proto.readMessageEnd() # THeader stack should accept binary protocol with optionally framed transport def main(argv): - p = argparse.ArgumentParser() - add_common_args(p) - # Since THeaderTransport acts as framed transport when detected frame, we - # cannot use --transport=framed as it would result in 2 layered frames. - p.add_argument('--override-transport') - p.add_argument('--override-protocol') - args = p.parse_args() - assert args.protocol == 'header' - assert args.transport == 'buffered' - assert not args.ssl + p = argparse.ArgumentParser() + add_common_args(p) + # Since THeaderTransport acts as framed transport when detected frame, we + # cannot use --transport=framed as it would result in 2 layered frames. + p.add_argument('--override-transport') + p.add_argument('--override-protocol') + args = p.parse_args() + assert args.protocol == 'header' + assert args.transport == 'buffered' + assert not args.ssl - sock = TSocket(args.host, args.port, socket_family=socket.AF_INET) - if not args.override_transport or args.override_transport == 'buffered': - trans = TBufferedTransport(sock) - elif args.override_transport == 'framed': - print('TFRAMED') - trans = TFramedTransport(sock) - else: - raise ValueError('invalid transport') - trans.open() + sock = TSocket(args.host, args.port, socket_family=socket.AF_INET) + if not args.override_transport or args.override_transport == 'buffered': + trans = TBufferedTransport(sock) + elif args.override_transport == 'framed': + print('TFRAMED') + trans = TFramedTransport(sock) + else: + raise ValueError('invalid transport') + trans.open() - if not args.override_protocol or args.override_protocol == 'binary': - proto = TBinaryProtocol(trans) - elif args.override_protocol == 'compact': - proto = TCompactProtocol(trans) - else: - raise ValueError('invalid transport') + if not args.override_protocol or args.override_protocol == 'binary': + proto = TBinaryProtocol(trans) + elif args.override_protocol == 'compact': + proto = TCompactProtocol(trans) + else: + raise ValueError('invalid transport') - test_void(proto) - test_void(proto) + test_void(proto) + test_void(proto) - trans.close() + trans.close() if __name__ == '__main__': - sys.exit(main(sys.argv[1:])) + sys.exit(main(sys.argv[1:])) diff --git a/test/features/util.py b/test/features/util.py index e36413629..e4997d0b7 100644 --- a/test/features/util.py +++ b/test/features/util.py @@ -11,30 +11,30 @@ from thrift.protocol.TJSONProtocol import TJSONProtocol def add_common_args(p): - p.add_argument('--host', default='localhost') - p.add_argument('--port', type=int, default=9090) - p.add_argument('--protocol', default='binary') - p.add_argument('--transport', default='buffered') - p.add_argument('--ssl', action='store_true') + p.add_argument('--host', default='localhost') + p.add_argument('--port', type=int, default=9090) + p.add_argument('--protocol', default='binary') + p.add_argument('--transport', default='buffered') + p.add_argument('--ssl', action='store_true') def parse_common_args(argv): - p = argparse.ArgumentParser() - add_common_args(p) - return p.parse_args(argv) + p = argparse.ArgumentParser() + add_common_args(p) + return p.parse_args(argv) def init_protocol(args): - sock = TSocket(args.host, args.port, socket_family=socket.AF_INET) - sock.setTimeout(500) - trans = { - 'buffered': TBufferedTransport, - 'framed': TFramedTransport, - 'http': THttpClient, - }[args.transport](sock) - trans.open() - return { - 'binary': TBinaryProtocol, - 'compact': TCompactProtocol, - 'json': TJSONProtocol, - }[args.protocol](trans) + sock = TSocket(args.host, args.port, socket_family=socket.AF_INET) + sock.setTimeout(500) + trans = { + 'buffered': TBufferedTransport, + 'framed': TFramedTransport, + 'http': THttpClient, + }[args.transport](sock) + trans.open() + return { + 'binary': TBinaryProtocol, + 'compact': TCompactProtocol, + 'json': TJSONProtocol, + }[args.protocol](trans) diff --git a/test/py.tornado/test_suite.py b/test/py.tornado/test_suite.py index e0bf91356..b9ce78181 100755 --- a/test/py.tornado/test_suite.py +++ b/test/py.tornado/test_suite.py @@ -27,7 +27,7 @@ import time import unittest basepath = os.path.abspath(os.path.dirname(__file__)) -sys.path.insert(0, basepath+'/gen-py.tornado') +sys.path.insert(0, basepath + '/gen-py.tornado') sys.path.insert(0, glob.glob(os.path.join(basepath, '../../lib/py/build/lib*'))[0]) try: diff --git a/test/py.twisted/test_suite.py b/test/py.twisted/test_suite.py index 2c07baaf8..3a59bb1f1 100755 --- a/test/py.twisted/test_suite.py +++ b/test/py.twisted/test_suite.py @@ -19,7 +19,10 @@ # under the License. # -import sys, os, glob, time +import sys +import os +import glob +import time basepath = os.path.abspath(os.path.dirname(__file__)) sys.path.insert(0, os.path.join(basepath, 'gen-py.twisted')) sys.path.insert(0, glob.glob(os.path.join(basepath, '../../lib/py/build/lib.*'))[0]) @@ -35,6 +38,7 @@ from twisted.internet.protocol import ClientCreator from zope.interface import implements + class TestHandler: implements(ThriftTest.Iface) @@ -100,6 +104,7 @@ class TestHandler: def testTypedef(self, thing): return thing + class ThriftTestCase(unittest.TestCase): @defer.inlineCallbacks @@ -109,15 +114,15 @@ class ThriftTestCase(unittest.TestCase): self.pfactory = TBinaryProtocol.TBinaryProtocolFactory() self.server = reactor.listenTCP(0, - TTwisted.ThriftServerFactory(self.processor, - self.pfactory), interface="127.0.0.1") + TTwisted.ThriftServerFactory(self.processor, + self.pfactory), interface="127.0.0.1") self.portNo = self.server.getHost().port self.txclient = yield ClientCreator(reactor, - TTwisted.ThriftClientProtocol, - ThriftTest.Client, - self.pfactory).connectTCP("127.0.0.1", self.portNo) + TTwisted.ThriftClientProtocol, + ThriftTest.Client, + self.pfactory).connectTCP("127.0.0.1", self.portNo) self.client = self.txclient.client @defer.inlineCallbacks @@ -179,7 +184,7 @@ class ThriftTestCase(unittest.TestCase): try: yield self.client.testException("throw_undeclared") self.fail("should have thrown exception") - except Exception: # type is undefined + except Exception: # type is undefined pass @defer.inlineCallbacks diff --git a/test/py/FastbinaryTest.py b/test/py/FastbinaryTest.py index 9d258fdbf..a8718dce1 100755 --- a/test/py/FastbinaryTest.py +++ b/test/py/FastbinaryTest.py @@ -41,11 +41,11 @@ from DebugProtoTest.ttypes import Backwards, Bonk, Empty, HolyMoley, OneOfEach, class TDevNullTransport(TTransport.TTransportBase): - def __init__(self): - pass + def __init__(self): + pass - def isOpen(self): - return True + def isOpen(self): + return True ooe1 = OneOfEach() ooe1.im_true = True @@ -71,8 +71,8 @@ ooe2.zomg_unicode = u"\xd3\x80\xe2\x85\xae\xce\x9d\x20"\ u"\xc7\x83\xe2\x80\xbc" if sys.version_info[0] == 2 and os.environ.get('THRIFT_TEST_PY_NO_UTF8STRINGS'): - ooe1.zomg_unicode = ooe1.zomg_unicode.encode('utf8') - ooe2.zomg_unicode = ooe2.zomg_unicode.encode('utf8') + ooe1.zomg_unicode = ooe1.zomg_unicode.encode('utf8') + ooe2.zomg_unicode = ooe2.zomg_unicode.encode('utf8') hm = HolyMoley(**{"big": [], "contain": set(), "bonks": {}}) hm.big.append(ooe1) @@ -86,13 +86,13 @@ hm.contain.add(()) hm.bonks["nothing"] = [] hm.bonks["something"] = [ - Bonk(**{"type": 1, "message": "Wait."}), - Bonk(**{"type": 2, "message": "What?"}), + Bonk(**{"type": 1, "message": "Wait."}), + Bonk(**{"type": 2, "message": "What?"}), ] hm.bonks["poe"] = [ - Bonk(**{"type": 3, "message": "quoth"}), - Bonk(**{"type": 4, "message": "the raven"}), - Bonk(**{"type": 5, "message": "nevermore"}), + Bonk(**{"type": 3, "message": "quoth"}), + Bonk(**{"type": 4, "message": "the raven"}), + Bonk(**{"type": 5, "message": "nevermore"}), ] rs = RandomStuff() @@ -112,110 +112,110 @@ my_zero = Srv.Janky_result(**{"success": 5}) def check_write(o): - trans_fast = TTransport.TMemoryBuffer() - trans_slow = TTransport.TMemoryBuffer() - prot_fast = TBinaryProtocol.TBinaryProtocolAccelerated(trans_fast) - prot_slow = TBinaryProtocol.TBinaryProtocol(trans_slow) + trans_fast = TTransport.TMemoryBuffer() + trans_slow = TTransport.TMemoryBuffer() + prot_fast = TBinaryProtocol.TBinaryProtocolAccelerated(trans_fast) + prot_slow = TBinaryProtocol.TBinaryProtocol(trans_slow) - o.write(prot_fast) - o.write(prot_slow) - ORIG = trans_slow.getvalue() - MINE = trans_fast.getvalue() - if ORIG != MINE: - print("mine: %s\norig: %s" % (repr(MINE), repr(ORIG))) + o.write(prot_fast) + o.write(prot_slow) + ORIG = trans_slow.getvalue() + MINE = trans_fast.getvalue() + if ORIG != MINE: + print("mine: %s\norig: %s" % (repr(MINE), repr(ORIG))) def check_read(o): - prot = TBinaryProtocol.TBinaryProtocol(TTransport.TMemoryBuffer()) - o.write(prot) - - slow_version_binary = prot.trans.getvalue() - - prot = TBinaryProtocol.TBinaryProtocolAccelerated( - TTransport.TMemoryBuffer(slow_version_binary)) - c = o.__class__() - c.read(prot) - if c != o: - print("copy: ") - pprint(eval(repr(c))) - print("orig: ") - pprint(eval(repr(o))) - - prot = TBinaryProtocol.TBinaryProtocolAccelerated( - TTransport.TBufferedTransport( - TTransport.TMemoryBuffer(slow_version_binary))) - c = o.__class__() - c.read(prot) - if c != o: - print("copy: ") - pprint(eval(repr(c))) - print("orig: ") - pprint(eval(repr(o))) + prot = TBinaryProtocol.TBinaryProtocol(TTransport.TMemoryBuffer()) + o.write(prot) + + slow_version_binary = prot.trans.getvalue() + + prot = TBinaryProtocol.TBinaryProtocolAccelerated( + TTransport.TMemoryBuffer(slow_version_binary)) + c = o.__class__() + c.read(prot) + if c != o: + print("copy: ") + pprint(eval(repr(c))) + print("orig: ") + pprint(eval(repr(o))) + + prot = TBinaryProtocol.TBinaryProtocolAccelerated( + TTransport.TBufferedTransport( + TTransport.TMemoryBuffer(slow_version_binary))) + c = o.__class__() + c.read(prot) + if c != o: + print("copy: ") + pprint(eval(repr(c))) + print("orig: ") + pprint(eval(repr(o))) def do_test(): - check_write(hm) - check_read(HolyMoley()) - no_set = deepcopy(hm) - no_set.contain = set() - check_read(no_set) - check_write(rs) - check_read(rs) - check_write(rshuge) - check_read(rshuge) - check_write(my_zero) - check_read(my_zero) - check_read(Backwards(**{"first_tag2": 4, "second_tag1": 2})) - - # One case where the serialized form changes, but only superficially. - o = Backwards(**{"first_tag2": 4, "second_tag1": 2}) - trans_fast = TTransport.TMemoryBuffer() - trans_slow = TTransport.TMemoryBuffer() - prot_fast = TBinaryProtocol.TBinaryProtocolAccelerated(trans_fast) - prot_slow = TBinaryProtocol.TBinaryProtocol(trans_slow) - - o.write(prot_fast) - o.write(prot_slow) - ORIG = trans_slow.getvalue() - MINE = trans_fast.getvalue() - assert id(ORIG) != id(MINE) - - prot = TBinaryProtocol.TBinaryProtocolAccelerated(TTransport.TMemoryBuffer()) - o.write(prot) - prot = TBinaryProtocol.TBinaryProtocol( - TTransport.TMemoryBuffer(prot.trans.getvalue())) - c = o.__class__() - c.read(prot) - if c != o: - print("copy: ") - pprint(eval(repr(c))) - print("orig: ") - pprint(eval(repr(o))) + check_write(hm) + check_read(HolyMoley()) + no_set = deepcopy(hm) + no_set.contain = set() + check_read(no_set) + check_write(rs) + check_read(rs) + check_write(rshuge) + check_read(rshuge) + check_write(my_zero) + check_read(my_zero) + check_read(Backwards(**{"first_tag2": 4, "second_tag1": 2})) + + # One case where the serialized form changes, but only superficially. + o = Backwards(**{"first_tag2": 4, "second_tag1": 2}) + trans_fast = TTransport.TMemoryBuffer() + trans_slow = TTransport.TMemoryBuffer() + prot_fast = TBinaryProtocol.TBinaryProtocolAccelerated(trans_fast) + prot_slow = TBinaryProtocol.TBinaryProtocol(trans_slow) + + o.write(prot_fast) + o.write(prot_slow) + ORIG = trans_slow.getvalue() + MINE = trans_fast.getvalue() + assert id(ORIG) != id(MINE) + + prot = TBinaryProtocol.TBinaryProtocolAccelerated(TTransport.TMemoryBuffer()) + o.write(prot) + prot = TBinaryProtocol.TBinaryProtocol( + TTransport.TMemoryBuffer(prot.trans.getvalue())) + c = o.__class__() + c.read(prot) + if c != o: + print("copy: ") + pprint(eval(repr(c))) + print("orig: ") + pprint(eval(repr(o))) def do_benchmark(iters=5000): - setup = """ + setup = """ from __main__ import hm, rs, TDevNullTransport from thrift.protocol import TBinaryProtocol trans = TDevNullTransport() prot = TBinaryProtocol.TBinaryProtocol%s(trans) """ - setup_fast = setup % "Accelerated" - setup_slow = setup % "" + setup_fast = setup % "Accelerated" + setup_slow = setup % "" - print("Starting Benchmarks") + print("Starting Benchmarks") - print("HolyMoley Standard = %f" % - timeit.Timer('hm.write(prot)', setup_slow).timeit(number=iters)) - print("HolyMoley Acceler. = %f" % - timeit.Timer('hm.write(prot)', setup_fast).timeit(number=iters)) + print("HolyMoley Standard = %f" % + timeit.Timer('hm.write(prot)', setup_slow).timeit(number=iters)) + print("HolyMoley Acceler. = %f" % + timeit.Timer('hm.write(prot)', setup_fast).timeit(number=iters)) - print("FastStruct Standard = %f" % - timeit.Timer('rs.write(prot)', setup_slow).timeit(number=iters)) - print("FastStruct Acceler. = %f" % - timeit.Timer('rs.write(prot)', setup_fast).timeit(number=iters)) + print("FastStruct Standard = %f" % + timeit.Timer('rs.write(prot)', setup_slow).timeit(number=iters)) + print("FastStruct Acceler. = %f" % + timeit.Timer('rs.write(prot)', setup_fast).timeit(number=iters)) if __name__ == '__main__': - do_test() - do_benchmark() + do_test() + do_benchmark() diff --git a/test/py/RunClientServer.py b/test/py/RunClientServer.py index d5ebd6a6d..98ead431d 100755 --- a/test/py/RunClientServer.py +++ b/test/py/RunClientServer.py @@ -37,13 +37,13 @@ DEFAULT_LIBDIR_GLOB = os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*') DEFAULT_LIBDIR_PY3 = os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib') SCRIPTS = [ - 'FastbinaryTest.py', - 'TestFrozen.py', - 'TSimpleJSONProtocolTest.py', - 'SerializationTest.py', - 'TestEof.py', - 'TestSyntax.py', - 'TestSocket.py', + 'FastbinaryTest.py', + 'TestFrozen.py', + 'TSimpleJSONProtocolTest.py', + 'SerializationTest.py', + 'TestEof.py', + 'TestSyntax.py', + 'TestSocket.py', ] FRAMED = ["TNonblockingServer"] SKIP_ZLIB = ['TNonblockingServer', 'THttpServer'] @@ -51,20 +51,20 @@ SKIP_SSL = ['TNonblockingServer', 'THttpServer'] EXTRA_DELAY = dict(TProcessPoolServer=5.5) PROTOS = [ - 'accel', - 'binary', - 'compact', - 'json', + 'accel', + 'binary', + 'compact', + 'json', ] SERVERS = [ - "TSimpleServer", - "TThreadedServer", - "TThreadPoolServer", - "TProcessPoolServer", - "TForkingServer", - "TNonblockingServer", - "THttpServer", + "TSimpleServer", + "TThreadedServer", + "TThreadPoolServer", + "TProcessPoolServer", + "TForkingServer", + "TNonblockingServer", + "THttpServer", ] @@ -73,246 +73,246 @@ def relfile(fname): def setup_pypath(libdir, gendir): - dirs = [libdir, gendir] - env = copy.deepcopy(os.environ) - pypath = env.get('PYTHONPATH', None) - if pypath: - dirs.append(pypath) - env['PYTHONPATH'] = ':'.join(dirs) - if gendir.endswith('gen-py-no_utf8strings'): - env['THRIFT_TEST_PY_NO_UTF8STRINGS'] = '1' - return env + dirs = [libdir, gendir] + env = copy.deepcopy(os.environ) + pypath = env.get('PYTHONPATH', None) + if pypath: + dirs.append(pypath) + env['PYTHONPATH'] = ':'.join(dirs) + if gendir.endswith('gen-py-no_utf8strings'): + env['THRIFT_TEST_PY_NO_UTF8STRINGS'] = '1' + return env def runScriptTest(libdir, genbase, genpydir, script): - env = setup_pypath(libdir, os.path.join(genbase, genpydir)) - script_args = [sys.executable, relfile(script)] - print('\nTesting script: %s\n----' % (' '.join(script_args))) - ret = subprocess.call(script_args, env=env) - if ret != 0: - print('*** FAILED ***', file=sys.stderr) - print('LIBDIR: %s' % libdir, file=sys.stderr) - print('PY_GEN: %s' % genpydir, file=sys.stderr) - print('SCRIPT: %s' % script, file=sys.stderr) - raise Exception("Script subprocess failed, retcode=%d, args: %s" % (ret, ' '.join(script_args))) + env = setup_pypath(libdir, os.path.join(genbase, genpydir)) + script_args = [sys.executable, relfile(script)] + print('\nTesting script: %s\n----' % (' '.join(script_args))) + ret = subprocess.call(script_args, env=env) + if ret != 0: + print('*** FAILED ***', file=sys.stderr) + print('LIBDIR: %s' % libdir, file=sys.stderr) + print('PY_GEN: %s' % genpydir, file=sys.stderr) + print('SCRIPT: %s' % script, file=sys.stderr) + raise Exception("Script subprocess failed, retcode=%d, args: %s" % (ret, ' '.join(script_args))) def runServiceTest(libdir, genbase, genpydir, server_class, proto, port, use_zlib, use_ssl, verbose): - env = setup_pypath(libdir, os.path.join(genbase, genpydir)) - # Build command line arguments - server_args = [sys.executable, relfile('TestServer.py')] - cli_args = [sys.executable, relfile('TestClient.py')] - for which in (server_args, cli_args): - which.append('--protocol=%s' % proto) # accel, binary, compact or json - which.append('--port=%d' % port) # default to 9090 - if use_zlib: - which.append('--zlib') - if use_ssl: - which.append('--ssl') - if verbose == 0: - which.append('-q') - if verbose == 2: - which.append('-v') - # server-specific option to select server class - server_args.append(server_class) - # client-specific cmdline options - if server_class in FRAMED: - cli_args.append('--transport=framed') - else: - cli_args.append('--transport=buffered') - if server_class == 'THttpServer': - cli_args.append('--http=/') - if verbose > 0: - print('Testing server %s: %s' % (server_class, ' '.join(server_args))) - serverproc = subprocess.Popen(server_args, env=env) - - def ensureServerAlive(): - if serverproc.poll() is not None: - print(('FAIL: Server process (%s) failed with retcode %d') - % (' '.join(server_args), serverproc.returncode)) - raise Exception('Server subprocess %s died, args: %s' - % (server_class, ' '.join(server_args))) - - # Wait for the server to start accepting connections on the given port. - sock = socket.socket() - sleep_time = 0.1 # Seconds - max_attempts = 100 - try: - attempt = 0 - while sock.connect_ex(('127.0.0.1', port)) != 0: - attempt += 1 - if attempt >= max_attempts: - raise Exception("TestServer not ready on port %d after %.2f seconds" - % (port, sleep_time * attempt)) - ensureServerAlive() - time.sleep(sleep_time) - finally: - sock.close() - - try: + env = setup_pypath(libdir, os.path.join(genbase, genpydir)) + # Build command line arguments + server_args = [sys.executable, relfile('TestServer.py')] + cli_args = [sys.executable, relfile('TestClient.py')] + for which in (server_args, cli_args): + which.append('--protocol=%s' % proto) # accel, binary, compact or json + which.append('--port=%d' % port) # default to 9090 + if use_zlib: + which.append('--zlib') + if use_ssl: + which.append('--ssl') + if verbose == 0: + which.append('-q') + if verbose == 2: + which.append('-v') + # server-specific option to select server class + server_args.append(server_class) + # client-specific cmdline options + if server_class in FRAMED: + cli_args.append('--transport=framed') + else: + cli_args.append('--transport=buffered') + if server_class == 'THttpServer': + cli_args.append('--http=/') if verbose > 0: - print('Testing client: %s' % (' '.join(cli_args))) - ret = subprocess.call(cli_args, env=env) - if ret != 0: - print('*** FAILED ***', file=sys.stderr) - print('LIBDIR: %s' % libdir, file=sys.stderr) - print('PY_GEN: %s' % genpydir, file=sys.stderr) - raise Exception("Client subprocess failed, retcode=%d, args: %s" % (ret, ' '.join(cli_args))) - finally: - # check that server didn't die - ensureServerAlive() - extra_sleep = EXTRA_DELAY.get(server_class, 0) - if extra_sleep > 0 and verbose > 0: - print('Giving %s (proto=%s,zlib=%s,ssl=%s) an extra %d seconds for child' - 'processes to terminate via alarm' - % (server_class, proto, use_zlib, use_ssl, extra_sleep)) - time.sleep(extra_sleep) - os.kill(serverproc.pid, signal.SIGKILL) - serverproc.wait() + print('Testing server %s: %s' % (server_class, ' '.join(server_args))) + serverproc = subprocess.Popen(server_args, env=env) + + def ensureServerAlive(): + if serverproc.poll() is not None: + print(('FAIL: Server process (%s) failed with retcode %d') + % (' '.join(server_args), serverproc.returncode)) + raise Exception('Server subprocess %s died, args: %s' + % (server_class, ' '.join(server_args))) + + # Wait for the server to start accepting connections on the given port. + sock = socket.socket() + sleep_time = 0.1 # Seconds + max_attempts = 100 + try: + attempt = 0 + while sock.connect_ex(('127.0.0.1', port)) != 0: + attempt += 1 + if attempt >= max_attempts: + raise Exception("TestServer not ready on port %d after %.2f seconds" + % (port, sleep_time * attempt)) + ensureServerAlive() + time.sleep(sleep_time) + finally: + sock.close() + + try: + if verbose > 0: + print('Testing client: %s' % (' '.join(cli_args))) + ret = subprocess.call(cli_args, env=env) + if ret != 0: + print('*** FAILED ***', file=sys.stderr) + print('LIBDIR: %s' % libdir, file=sys.stderr) + print('PY_GEN: %s' % genpydir, file=sys.stderr) + raise Exception("Client subprocess failed, retcode=%d, args: %s" % (ret, ' '.join(cli_args))) + finally: + # check that server didn't die + ensureServerAlive() + extra_sleep = EXTRA_DELAY.get(server_class, 0) + if extra_sleep > 0 and verbose > 0: + print('Giving %s (proto=%s,zlib=%s,ssl=%s) an extra %d seconds for child' + 'processes to terminate via alarm' + % (server_class, proto, use_zlib, use_ssl, extra_sleep)) + time.sleep(extra_sleep) + os.kill(serverproc.pid, signal.SIGKILL) + serverproc.wait() class TestCases(object): - def __init__(self, genbase, libdir, port, gendirs, servers, verbose): - self.genbase = genbase - self.libdir = libdir - self.port = port - self.verbose = verbose - self.gendirs = gendirs - self.servers = servers - - def default_conf(self): - return { - 'gendir': self.gendirs[0], - 'server': self.servers[0], - 'proto': PROTOS[0], - 'zlib': False, - 'ssl': False, - } - - def run(self, conf, test_count): - with_zlib = conf['zlib'] - with_ssl = conf['ssl'] - try_server = conf['server'] - try_proto = conf['proto'] - genpydir = conf['gendir'] - # skip any servers that don't work with the Zlib transport - if with_zlib and try_server in SKIP_ZLIB: - return False - # skip any servers that don't work with SSL - if with_ssl and try_server in SKIP_SSL: - return False - if self.verbose > 0: - print('\nTest run #%d: (includes %s) Server=%s, Proto=%s, zlib=%s, SSL=%s' - % (test_count, genpydir, try_server, try_proto, with_zlib, with_ssl)) - runServiceTest(self.libdir, self.genbase, genpydir, try_server, try_proto, self.port, with_zlib, with_ssl, self.verbose) - if self.verbose > 0: - print('OK: Finished (includes %s) %s / %s proto / zlib=%s / SSL=%s. %d combinations tested.' - % (genpydir, try_server, try_proto, with_zlib, with_ssl, test_count)) - return True - - def test_feature(self, name, values): - test_count = 0 - conf = self.default_conf() - for try_server in values: - conf[name] = try_server - if self.run(conf, test_count): - test_count += 1 - return test_count - - def run_all_tests(self): - test_count = 0 - for try_server in self.servers: - for genpydir in self.gendirs: - for try_proto in PROTOS: - for with_zlib in (False, True): - # skip any servers that don't work with the Zlib transport - if with_zlib and try_server in SKIP_ZLIB: - continue - for with_ssl in (False, True): - # skip any servers that don't work with SSL - if with_ssl and try_server in SKIP_SSL: - continue - test_count += 1 - if self.verbose > 0: - print('\nTest run #%d: (includes %s) Server=%s, Proto=%s, zlib=%s, SSL=%s' - % (test_count, genpydir, try_server, try_proto, with_zlib, with_ssl)) - runServiceTest(self.libdir, self.genbase, genpydir, try_server, try_proto, self.port, with_zlib, with_ssl) - if self.verbose > 0: - print('OK: Finished (includes %s) %s / %s proto / zlib=%s / SSL=%s. %d combinations tested.' - % (genpydir, try_server, try_proto, with_zlib, with_ssl, test_count)) - return test_count + def __init__(self, genbase, libdir, port, gendirs, servers, verbose): + self.genbase = genbase + self.libdir = libdir + self.port = port + self.verbose = verbose + self.gendirs = gendirs + self.servers = servers + + def default_conf(self): + return { + 'gendir': self.gendirs[0], + 'server': self.servers[0], + 'proto': PROTOS[0], + 'zlib': False, + 'ssl': False, + } + + def run(self, conf, test_count): + with_zlib = conf['zlib'] + with_ssl = conf['ssl'] + try_server = conf['server'] + try_proto = conf['proto'] + genpydir = conf['gendir'] + # skip any servers that don't work with the Zlib transport + if with_zlib and try_server in SKIP_ZLIB: + return False + # skip any servers that don't work with SSL + if with_ssl and try_server in SKIP_SSL: + return False + if self.verbose > 0: + print('\nTest run #%d: (includes %s) Server=%s, Proto=%s, zlib=%s, SSL=%s' + % (test_count, genpydir, try_server, try_proto, with_zlib, with_ssl)) + runServiceTest(self.libdir, self.genbase, genpydir, try_server, try_proto, self.port, with_zlib, with_ssl, self.verbose) + if self.verbose > 0: + print('OK: Finished (includes %s) %s / %s proto / zlib=%s / SSL=%s. %d combinations tested.' + % (genpydir, try_server, try_proto, with_zlib, with_ssl, test_count)) + return True + + def test_feature(self, name, values): + test_count = 0 + conf = self.default_conf() + for try_server in values: + conf[name] = try_server + if self.run(conf, test_count): + test_count += 1 + return test_count + + def run_all_tests(self): + test_count = 0 + for try_server in self.servers: + for genpydir in self.gendirs: + for try_proto in PROTOS: + for with_zlib in (False, True): + # skip any servers that don't work with the Zlib transport + if with_zlib and try_server in SKIP_ZLIB: + continue + for with_ssl in (False, True): + # skip any servers that don't work with SSL + if with_ssl and try_server in SKIP_SSL: + continue + test_count += 1 + if self.verbose > 0: + print('\nTest run #%d: (includes %s) Server=%s, Proto=%s, zlib=%s, SSL=%s' + % (test_count, genpydir, try_server, try_proto, with_zlib, with_ssl)) + runServiceTest(self.libdir, self.genbase, genpydir, try_server, try_proto, self.port, with_zlib, with_ssl) + if self.verbose > 0: + print('OK: Finished (includes %s) %s / %s proto / zlib=%s / SSL=%s. %d combinations tested.' + % (genpydir, try_server, try_proto, with_zlib, with_ssl, test_count)) + return test_count def default_libdir(): - if sys.version_info[0] == 2: - return glob.glob(DEFAULT_LIBDIR_GLOB)[0] - else: - return DEFAULT_LIBDIR_PY3 + if sys.version_info[0] == 2: + return glob.glob(DEFAULT_LIBDIR_GLOB)[0] + else: + return DEFAULT_LIBDIR_PY3 def main(): - parser = OptionParser() - parser.add_option('--all', action="store_true", dest='all') - parser.add_option('--genpydirs', type='string', dest='genpydirs', - default='default,slots,oldstyle,no_utf8strings,dynamic,dynamicslots', - help='directory extensions for generated code, used as suffixes for \"gen-py-*\" added sys.path for individual tests') - parser.add_option("--port", type="int", dest="port", default=9090, - help="port number for server to listen on") - parser.add_option('-v', '--verbose', action="store_const", - dest="verbose", const=2, - help="verbose output") - parser.add_option('-q', '--quiet', action="store_const", - dest="verbose", const=0, - help="minimal output") - parser.add_option('-L', '--libdir', dest="libdir", default=default_libdir(), - help="directory path that contains Thrift Python library") - parser.add_option('--gen-base', dest="gen_base", default=SCRIPT_DIR, - help="directory path that contains Thrift Python library") - parser.set_defaults(verbose=1) - options, args = parser.parse_args() - - generated_dirs = [] - for gp_dir in options.genpydirs.split(','): - generated_dirs.append('gen-py-%s' % (gp_dir)) - - # commandline permits a single class name to be specified to override SERVERS=[...] - servers = SERVERS - if len(args) == 1: - if args[0] in SERVERS: - servers = args + parser = OptionParser() + parser.add_option('--all', action="store_true", dest='all') + parser.add_option('--genpydirs', type='string', dest='genpydirs', + default='default,slots,oldstyle,no_utf8strings,dynamic,dynamicslots', + help='directory extensions for generated code, used as suffixes for \"gen-py-*\" added sys.path for individual tests') + parser.add_option("--port", type="int", dest="port", default=9090, + help="port number for server to listen on") + parser.add_option('-v', '--verbose', action="store_const", + dest="verbose", const=2, + help="verbose output") + parser.add_option('-q', '--quiet', action="store_const", + dest="verbose", const=0, + help="minimal output") + parser.add_option('-L', '--libdir', dest="libdir", default=default_libdir(), + help="directory path that contains Thrift Python library") + parser.add_option('--gen-base', dest="gen_base", default=SCRIPT_DIR, + help="directory path that contains Thrift Python library") + parser.set_defaults(verbose=1) + options, args = parser.parse_args() + + generated_dirs = [] + for gp_dir in options.genpydirs.split(','): + generated_dirs.append('gen-py-%s' % (gp_dir)) + + # commandline permits a single class name to be specified to override SERVERS=[...] + servers = SERVERS + if len(args) == 1: + if args[0] in SERVERS: + servers = args + else: + print('Unavailable server type "%s", please choose one of: %s' % (args[0], servers)) + sys.exit(0) + + tests = TestCases(options.gen_base, options.libdir, options.port, generated_dirs, servers, options.verbose) + + # run tests without a client/server first + print('----------------') + print(' Executing individual test scripts with various generated code directories') + print(' Directories to be tested: ' + ', '.join(generated_dirs)) + print(' Scripts to be tested: ' + ', '.join(SCRIPTS)) + print('----------------') + for genpydir in generated_dirs: + for script in SCRIPTS: + runScriptTest(options.libdir, options.gen_base, genpydir, script) + + print('----------------') + print(' Executing Client/Server tests with various generated code directories') + print(' Servers to be tested: ' + ', '.join(servers)) + print(' Directories to be tested: ' + ', '.join(generated_dirs)) + print(' Protocols to be tested: ' + ', '.join(PROTOS)) + print(' Options to be tested: ZLIB(yes/no), SSL(yes/no)') + print('----------------') + + if options.all: + tests.run_all_tests() else: - print('Unavailable server type "%s", please choose one of: %s' % (args[0], servers)) - sys.exit(0) - - tests = TestCases(options.gen_base, options.libdir, options.port, generated_dirs, servers, options.verbose) - - # run tests without a client/server first - print('----------------') - print(' Executing individual test scripts with various generated code directories') - print(' Directories to be tested: ' + ', '.join(generated_dirs)) - print(' Scripts to be tested: ' + ', '.join(SCRIPTS)) - print('----------------') - for genpydir in generated_dirs: - for script in SCRIPTS: - runScriptTest(options.libdir, options.gen_base, genpydir, script) - - print('----------------') - print(' Executing Client/Server tests with various generated code directories') - print(' Servers to be tested: ' + ', '.join(servers)) - print(' Directories to be tested: ' + ', '.join(generated_dirs)) - print(' Protocols to be tested: ' + ', '.join(PROTOS)) - print(' Options to be tested: ZLIB(yes/no), SSL(yes/no)') - print('----------------') - - if options.all: - tests.run_all_tests() - else: - tests.test_feature('gendir', generated_dirs) - tests.test_feature('server', servers) - tests.test_feature('proto', PROTOS) - tests.test_feature('zlib', [False, True]) - tests.test_feature('ssl', [False, True]) + tests.test_feature('gendir', generated_dirs) + tests.test_feature('server', servers) + tests.test_feature('proto', PROTOS) + tests.test_feature('zlib', [False, True]) + tests.test_feature('ssl', [False, True]) if __name__ == '__main__': - sys.exit(main()) + sys.exit(main()) diff --git a/test/py/SerializationTest.py b/test/py/SerializationTest.py index d4755cf2a..65a149599 100755 --- a/test/py/SerializationTest.py +++ b/test/py/SerializationTest.py @@ -30,341 +30,342 @@ import unittest class AbstractTest(unittest.TestCase): - def setUp(self): - self.v1obj = VersioningTestV1( - begin_in_both=12345, - old_string='aaa', - end_in_both=54321, - ) - - self.v2obj = VersioningTestV2( - begin_in_both=12345, - newint=1, - newbyte=2, - newshort=3, - newlong=4, - newdouble=5.0, - newstruct=Bonk(message="Hello!", type=123), - newlist=[7,8,9], - newset=set([42,1,8]), - newmap={1:2,2:3}, - newstring="Hola!", - end_in_both=54321, - ) - - self.bools = Bools(im_true=True, im_false=False) - self.bools_flipped = Bools(im_true=False, im_false=True) - - self.large_deltas = LargeDeltas ( - b1=self.bools, - b10=self.bools_flipped, - b100=self.bools, - check_true=True, - b1000=self.bools_flipped, - check_false=False, - vertwo2000=VersioningTestV2(newstruct=Bonk(message='World!', type=314)), - a_set2500=set(['lazy', 'brown', 'cow']), - vertwo3000=VersioningTestV2(newset=set([2, 3, 5, 7, 11])), - big_numbers=[2**8, 2**16, 2**31-1, -(2**31-1)] - ) - - self.compact_struct = CompactProtoTestStruct( - a_byte = 127, - a_i16=32000, - a_i32=1000000000, - a_i64=0xffffffffff, - a_double=5.6789, - a_string="my string", - true_field=True, - false_field=False, - empty_struct_field=Empty(), - byte_list=[-127, -1, 0, 1, 127], - i16_list=[-1, 0, 1, 0x7fff], - i32_list= [-1, 0, 0xff, 0xffff, 0xffffff, 0x7fffffff], - i64_list=[-1, 0, 0xff, 0xffff, 0xffffff, 0xffffffff, 0xffffffffff, 0xffffffffffff, 0xffffffffffffff, 0x7fffffffffffffff], - double_list=[0.1, 0.2, 0.3], - string_list=["first", "second", "third"], - boolean_list=[True, True, True, False, False, False], - struct_list=[Empty(), Empty()], - byte_set=set([-127, -1, 0, 1, 127]), - i16_set=set([-1, 0, 1, 0x7fff]), - i32_set=set([1, 2, 3]), - i64_set=set([-1, 0, 0xff, 0xffff, 0xffffff, 0xffffffff, 0xffffffffff, 0xffffffffffff, 0xffffffffffffff, 0x7fffffffffffffff]), - double_set=set([0.1, 0.2, 0.3]), - string_set=set(["first", "second", "third"]), - boolean_set=set([True, False]), - #struct_set=set([Empty()]), # unhashable instance - byte_byte_map={1 : 2}, - i16_byte_map={1 : 1, -1 : 1, 0x7fff : 1}, - i32_byte_map={1 : 1, -1 : 1, 0x7fffffff : 1}, - i64_byte_map={0 : 1, 1 : 1, -1 : 1, 0x7fffffffffffffff : 1}, - double_byte_map={-1.1 : 1, 1.1 : 1}, - string_byte_map={"first" : 1, "second" : 2, "third" : 3, "" : 0}, - boolean_byte_map={True : 1, False: 0}, - byte_i16_map={1 : 1, 2 : -1, 3 : 0x7fff}, - byte_i32_map={1 : 1, 2 : -1, 3 : 0x7fffffff}, - byte_i64_map={1 : 1, 2 : -1, 3 : 0x7fffffffffffffff}, - byte_double_map={1 : 0.1, 2 : -0.1, 3 : 1000000.1}, - byte_string_map={1 : "", 2 : "blah", 3 : "loooooooooooooong string"}, - byte_boolean_map={1 : True, 2 : False}, - #list_byte_map # unhashable - #set_byte_map={set([1, 2, 3]) : 1, set([0, 1]) : 2, set([]) : 0}, # unhashable - #map_byte_map # unhashable - byte_map_map={0 : {}, 1 : {1 : 1}, 2 : {1 : 1, 2 : 2}}, - byte_set_map={0 : set([]), 1 : set([1]), 2 : set([1, 2])}, - byte_list_map={0 : [], 1 : [1], 2 : [1, 2]}, - ) - - self.nested_lists_i32x2 = NestedListsI32x2( - [ - [ 1, 1, 2 ], - [ 2, 7, 9 ], - [ 3, 5, 8 ] - ] - ) - - self.nested_lists_i32x3 = NestedListsI32x3( - [ - [ - [ 2, 7, 9 ], - [ 3, 5, 8 ] - ], - [ - [ 1, 1, 2 ], - [ 1, 4, 9 ] - ] - ] - ) - - self.nested_mixedx2 = NestedMixedx2( int_set_list=[ - set([1,2,3]), - set([1,4,9]), - set([1,2,3,5,8,13,21]), - set([-1, 0, 1]) - ], - # note, the sets below are sets of chars, since the strings are iterated - map_int_strset={ 10:set('abc'), 20:set('def'), 30:set('GHI') }, - map_int_strset_list=[ - { 10:set('abc'), 20:set('def'), 30:set('GHI') }, - { 100:set('lmn'), 200:set('opq'), 300:set('RST') }, - { 1000:set('uvw'), 2000:set('wxy'), 3000:set('XYZ') } - ] - ) - - self.nested_lists_bonk = NestedListsBonk( - [ - [ - [ - Bonk(message='inner A first', type=1), - Bonk(message='inner A second', type=1) - ], - [ - Bonk(message='inner B first', type=2), - Bonk(message='inner B second', type=2) - ] - ] - ] - ) - - self.list_bonks = ListBonks( - [ - Bonk(message='inner A', type=1), - Bonk(message='inner B', type=2), - Bonk(message='inner C', type=0) - ] - ) - - def _serialize(self, obj): - trans = TTransport.TMemoryBuffer() - prot = self.protocol_factory.getProtocol(trans) - obj.write(prot) - return trans.getvalue() - - def _deserialize(self, objtype, data): - prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data)) - ret = objtype() - ret.read(prot) - return ret - - def testForwards(self): - obj = self._deserialize(VersioningTestV2, self._serialize(self.v1obj)) - self.assertEquals(obj.begin_in_both, self.v1obj.begin_in_both) - self.assertEquals(obj.end_in_both, self.v1obj.end_in_both) - - def testBackwards(self): - obj = self._deserialize(VersioningTestV1, self._serialize(self.v2obj)) - self.assertEquals(obj.begin_in_both, self.v2obj.begin_in_both) - self.assertEquals(obj.end_in_both, self.v2obj.end_in_both) - - def testSerializeV1(self): - obj = self._deserialize(VersioningTestV1, self._serialize(self.v1obj)) - self.assertEquals(obj, self.v1obj) - - def testSerializeV2(self): - obj = self._deserialize(VersioningTestV2, self._serialize(self.v2obj)) - self.assertEquals(obj, self.v2obj) - - def testBools(self): - self.assertNotEquals(self.bools, self.bools_flipped) - self.assertNotEquals(self.bools, self.v1obj) - obj = self._deserialize(Bools, self._serialize(self.bools)) - self.assertEquals(obj, self.bools) - obj = self._deserialize(Bools, self._serialize(self.bools_flipped)) - self.assertEquals(obj, self.bools_flipped) - rep = repr(self.bools) - self.assertTrue(len(rep) > 0) - - def testLargeDeltas(self): - # test large field deltas (meaningful in CompactProto only) - obj = self._deserialize(LargeDeltas, self._serialize(self.large_deltas)) - self.assertEquals(obj, self.large_deltas) - rep = repr(self.large_deltas) - self.assertTrue(len(rep) > 0) - - def testNestedListsI32x2(self): - obj = self._deserialize(NestedListsI32x2, self._serialize(self.nested_lists_i32x2)) - self.assertEquals(obj, self.nested_lists_i32x2) - rep = repr(self.nested_lists_i32x2) - self.assertTrue(len(rep) > 0) - - def testNestedListsI32x3(self): - obj = self._deserialize(NestedListsI32x3, self._serialize(self.nested_lists_i32x3)) - self.assertEquals(obj, self.nested_lists_i32x3) - rep = repr(self.nested_lists_i32x3) - self.assertTrue(len(rep) > 0) - - def testNestedMixedx2(self): - obj = self._deserialize(NestedMixedx2, self._serialize(self.nested_mixedx2)) - self.assertEquals(obj, self.nested_mixedx2) - rep = repr(self.nested_mixedx2) - self.assertTrue(len(rep) > 0) - - def testNestedListsBonk(self): - obj = self._deserialize(NestedListsBonk, self._serialize(self.nested_lists_bonk)) - self.assertEquals(obj, self.nested_lists_bonk) - rep = repr(self.nested_lists_bonk) - self.assertTrue(len(rep) > 0) - - def testListBonks(self): - obj = self._deserialize(ListBonks, self._serialize(self.list_bonks)) - self.assertEquals(obj, self.list_bonks) - rep = repr(self.list_bonks) - self.assertTrue(len(rep) > 0) - - def testCompactStruct(self): - # test large field deltas (meaningful in CompactProto only) - obj = self._deserialize(CompactProtoTestStruct, self._serialize(self.compact_struct)) - self.assertEquals(obj, self.compact_struct) - rep = repr(self.compact_struct) - self.assertTrue(len(rep) > 0) - - def testIntegerLimits(self): - if (sys.version_info[0] == 2 and sys.version_info[1] <= 6): - print('Skipping testIntegerLimits for Python 2.6') - return - bad_values = [CompactProtoTestStruct(a_byte=128), CompactProtoTestStruct(a_byte=-129), - CompactProtoTestStruct(a_i16=32768), CompactProtoTestStruct(a_i16=-32769), - CompactProtoTestStruct(a_i32=2147483648), CompactProtoTestStruct(a_i32=-2147483649), - CompactProtoTestStruct(a_i64=9223372036854775808), CompactProtoTestStruct(a_i64=-9223372036854775809) + def setUp(self): + self.v1obj = VersioningTestV1( + begin_in_both=12345, + old_string='aaa', + end_in_both=54321, + ) + + self.v2obj = VersioningTestV2( + begin_in_both=12345, + newint=1, + newbyte=2, + newshort=3, + newlong=4, + newdouble=5.0, + newstruct=Bonk(message="Hello!", type=123), + newlist=[7, 8, 9], + newset=set([42, 1, 8]), + newmap={1: 2, 2: 3}, + newstring="Hola!", + end_in_both=54321, + ) + + self.bools = Bools(im_true=True, im_false=False) + self.bools_flipped = Bools(im_true=False, im_false=True) + + self.large_deltas = LargeDeltas( + b1=self.bools, + b10=self.bools_flipped, + b100=self.bools, + check_true=True, + b1000=self.bools_flipped, + check_false=False, + vertwo2000=VersioningTestV2(newstruct=Bonk(message='World!', type=314)), + a_set2500=set(['lazy', 'brown', 'cow']), + vertwo3000=VersioningTestV2(newset=set([2, 3, 5, 7, 11])), + big_numbers=[2**8, 2**16, 2**31 - 1, -(2**31 - 1)] + ) + + self.compact_struct = CompactProtoTestStruct( + a_byte=127, + a_i16=32000, + a_i32=1000000000, + a_i64=0xffffffffff, + a_double=5.6789, + a_string="my string", + true_field=True, + false_field=False, + empty_struct_field=Empty(), + byte_list=[-127, -1, 0, 1, 127], + i16_list=[-1, 0, 1, 0x7fff], + i32_list=[-1, 0, 0xff, 0xffff, 0xffffff, 0x7fffffff], + i64_list=[-1, 0, 0xff, 0xffff, 0xffffff, 0xffffffff, 0xffffffffff, 0xffffffffffff, 0xffffffffffffff, 0x7fffffffffffffff], + double_list=[0.1, 0.2, 0.3], + string_list=["first", "second", "third"], + boolean_list=[True, True, True, False, False, False], + struct_list=[Empty(), Empty()], + byte_set=set([-127, -1, 0, 1, 127]), + i16_set=set([-1, 0, 1, 0x7fff]), + i32_set=set([1, 2, 3]), + i64_set=set([-1, 0, 0xff, 0xffff, 0xffffff, 0xffffffff, 0xffffffffff, 0xffffffffffff, 0xffffffffffffff, 0x7fffffffffffffff]), + double_set=set([0.1, 0.2, 0.3]), + string_set=set(["first", "second", "third"]), + boolean_set=set([True, False]), + # struct_set=set([Empty()]), # unhashable instance + byte_byte_map={1: 2}, + i16_byte_map={1: 1, -1: 1, 0x7fff: 1}, + i32_byte_map={1: 1, -1: 1, 0x7fffffff: 1}, + i64_byte_map={0: 1, 1: 1, -1: 1, 0x7fffffffffffffff: 1}, + double_byte_map={-1.1: 1, 1.1: 1}, + string_byte_map={"first": 1, "second": 2, "third": 3, "": 0}, + boolean_byte_map={True: 1, False: 0}, + byte_i16_map={1: 1, 2: -1, 3: 0x7fff}, + byte_i32_map={1: 1, 2: -1, 3: 0x7fffffff}, + byte_i64_map={1: 1, 2: -1, 3: 0x7fffffffffffffff}, + byte_double_map={1: 0.1, 2: -0.1, 3: 1000000.1}, + byte_string_map={1: "", 2: "blah", 3: "loooooooooooooong string"}, + byte_boolean_map={1: True, 2: False}, + # list_byte_map # unhashable + # set_byte_map={set([1, 2, 3]) : 1, set([0, 1]) : 2, set([]) : 0}, # unhashable + # map_byte_map # unhashable + byte_map_map={0: {}, 1: {1: 1}, 2: {1: 1, 2: 2}}, + byte_set_map={0: set([]), 1: set([1]), 2: set([1, 2])}, + byte_list_map={0: [], 1: [1], 2: [1, 2]}, + ) + + self.nested_lists_i32x2 = NestedListsI32x2( + [ + [1, 1, 2], + [2, 7, 9], + [3, 5, 8] + ] + ) + + self.nested_lists_i32x3 = NestedListsI32x3( + [ + [ + [2, 7, 9], + [3, 5, 8] + ], + [ + [1, 1, 2], + [1, 4, 9] ] - - for value in bad_values: - self.assertRaises(Exception, self._serialize, value) + ] + ) + + self.nested_mixedx2 = NestedMixedx2(int_set_list=[ + set([1, 2, 3]), + set([1, 4, 9]), + set([1, 2, 3, 5, 8, 13, 21]), + set([-1, 0, 1]) + ], + # note, the sets below are sets of chars, since the strings are iterated + map_int_strset={10: set('abc'), 20: set('def'), 30: set('GHI')}, + map_int_strset_list=[ + {10: set('abc'), 20: set('def'), 30: set('GHI')}, + {100: set('lmn'), 200: set('opq'), 300: set('RST')}, + {1000: set('uvw'), 2000: set('wxy'), 3000: set('XYZ')} + ] + ) + + self.nested_lists_bonk = NestedListsBonk( + [ + [ + [ + Bonk(message='inner A first', type=1), + Bonk(message='inner A second', type=1) + ], + [ + Bonk(message='inner B first', type=2), + Bonk(message='inner B second', type=2) + ] + ] + ] + ) + + self.list_bonks = ListBonks( + [ + Bonk(message='inner A', type=1), + Bonk(message='inner B', type=2), + Bonk(message='inner C', type=0) + ] + ) + + def _serialize(self, obj): + trans = TTransport.TMemoryBuffer() + prot = self.protocol_factory.getProtocol(trans) + obj.write(prot) + return trans.getvalue() + + def _deserialize(self, objtype, data): + prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data)) + ret = objtype() + ret.read(prot) + return ret + + def testForwards(self): + obj = self._deserialize(VersioningTestV2, self._serialize(self.v1obj)) + self.assertEquals(obj.begin_in_both, self.v1obj.begin_in_both) + self.assertEquals(obj.end_in_both, self.v1obj.end_in_both) + + def testBackwards(self): + obj = self._deserialize(VersioningTestV1, self._serialize(self.v2obj)) + self.assertEquals(obj.begin_in_both, self.v2obj.begin_in_both) + self.assertEquals(obj.end_in_both, self.v2obj.end_in_both) + + def testSerializeV1(self): + obj = self._deserialize(VersioningTestV1, self._serialize(self.v1obj)) + self.assertEquals(obj, self.v1obj) + + def testSerializeV2(self): + obj = self._deserialize(VersioningTestV2, self._serialize(self.v2obj)) + self.assertEquals(obj, self.v2obj) + + def testBools(self): + self.assertNotEquals(self.bools, self.bools_flipped) + self.assertNotEquals(self.bools, self.v1obj) + obj = self._deserialize(Bools, self._serialize(self.bools)) + self.assertEquals(obj, self.bools) + obj = self._deserialize(Bools, self._serialize(self.bools_flipped)) + self.assertEquals(obj, self.bools_flipped) + rep = repr(self.bools) + self.assertTrue(len(rep) > 0) + + def testLargeDeltas(self): + # test large field deltas (meaningful in CompactProto only) + obj = self._deserialize(LargeDeltas, self._serialize(self.large_deltas)) + self.assertEquals(obj, self.large_deltas) + rep = repr(self.large_deltas) + self.assertTrue(len(rep) > 0) + + def testNestedListsI32x2(self): + obj = self._deserialize(NestedListsI32x2, self._serialize(self.nested_lists_i32x2)) + self.assertEquals(obj, self.nested_lists_i32x2) + rep = repr(self.nested_lists_i32x2) + self.assertTrue(len(rep) > 0) + + def testNestedListsI32x3(self): + obj = self._deserialize(NestedListsI32x3, self._serialize(self.nested_lists_i32x3)) + self.assertEquals(obj, self.nested_lists_i32x3) + rep = repr(self.nested_lists_i32x3) + self.assertTrue(len(rep) > 0) + + def testNestedMixedx2(self): + obj = self._deserialize(NestedMixedx2, self._serialize(self.nested_mixedx2)) + self.assertEquals(obj, self.nested_mixedx2) + rep = repr(self.nested_mixedx2) + self.assertTrue(len(rep) > 0) + + def testNestedListsBonk(self): + obj = self._deserialize(NestedListsBonk, self._serialize(self.nested_lists_bonk)) + self.assertEquals(obj, self.nested_lists_bonk) + rep = repr(self.nested_lists_bonk) + self.assertTrue(len(rep) > 0) + + def testListBonks(self): + obj = self._deserialize(ListBonks, self._serialize(self.list_bonks)) + self.assertEquals(obj, self.list_bonks) + rep = repr(self.list_bonks) + self.assertTrue(len(rep) > 0) + + def testCompactStruct(self): + # test large field deltas (meaningful in CompactProto only) + obj = self._deserialize(CompactProtoTestStruct, self._serialize(self.compact_struct)) + self.assertEquals(obj, self.compact_struct) + rep = repr(self.compact_struct) + self.assertTrue(len(rep) > 0) + + def testIntegerLimits(self): + if (sys.version_info[0] == 2 and sys.version_info[1] <= 6): + print('Skipping testIntegerLimits for Python 2.6') + return + bad_values = [CompactProtoTestStruct(a_byte=128), CompactProtoTestStruct(a_byte=-129), + CompactProtoTestStruct(a_i16=32768), CompactProtoTestStruct(a_i16=-32769), + CompactProtoTestStruct(a_i32=2147483648), CompactProtoTestStruct(a_i32=-2147483649), + CompactProtoTestStruct(a_i64=9223372036854775808), CompactProtoTestStruct(a_i64=-9223372036854775809) + ] + + for value in bad_values: + self.assertRaises(Exception, self._serialize, value) class NormalBinaryTest(AbstractTest): - protocol_factory = TBinaryProtocol.TBinaryProtocolFactory() + protocol_factory = TBinaryProtocol.TBinaryProtocolFactory() class AcceleratedBinaryTest(AbstractTest): - protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory() + protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory() class CompactProtocolTest(AbstractTest): - protocol_factory = TCompactProtocol.TCompactProtocolFactory() + protocol_factory = TCompactProtocol.TCompactProtocolFactory() class JSONProtocolTest(AbstractTest): - protocol_factory = TJSONProtocol.TJSONProtocolFactory() + protocol_factory = TJSONProtocol.TJSONProtocolFactory() class AcceleratedFramedTest(unittest.TestCase): - def testSplit(self): - """Test FramedTransport and BinaryProtocolAccelerated + def testSplit(self): + """Test FramedTransport and BinaryProtocolAccelerated + + Tests that TBinaryProtocolAccelerated and TFramedTransport + play nicely together when a read spans a frame""" + + protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory() + bigstring = "".join(chr(byte) for byte in range(ord("a"), ord("z") + 1)) + + databuf = TTransport.TMemoryBuffer() + prot = protocol_factory.getProtocol(databuf) + prot.writeI32(42) + prot.writeString(bigstring) + prot.writeI16(24) + data = databuf.getvalue() + cutpoint = len(data) // 2 + parts = [data[:cutpoint], data[cutpoint:]] + + framed_buffer = TTransport.TMemoryBuffer() + framed_writer = TTransport.TFramedTransport(framed_buffer) + for part in parts: + framed_writer.write(part) + framed_writer.flush() + self.assertEquals(len(framed_buffer.getvalue()), len(data) + 8) + + # Recreate framed_buffer so we can read from it. + framed_buffer = TTransport.TMemoryBuffer(framed_buffer.getvalue()) + framed_reader = TTransport.TFramedTransport(framed_buffer) + prot = protocol_factory.getProtocol(framed_reader) + self.assertEqual(prot.readI32(), 42) + self.assertEqual(prot.readString(), bigstring) + self.assertEqual(prot.readI16(), 24) - Tests that TBinaryProtocolAccelerated and TFramedTransport - play nicely together when a read spans a frame""" - - protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory() - bigstring = "".join(chr(byte) for byte in range(ord("a"), ord("z")+1)) - - databuf = TTransport.TMemoryBuffer() - prot = protocol_factory.getProtocol(databuf) - prot.writeI32(42) - prot.writeString(bigstring) - prot.writeI16(24) - data = databuf.getvalue() - cutpoint = len(data) // 2 - parts = [ data[:cutpoint], data[cutpoint:] ] - - framed_buffer = TTransport.TMemoryBuffer() - framed_writer = TTransport.TFramedTransport(framed_buffer) - for part in parts: - framed_writer.write(part) - framed_writer.flush() - self.assertEquals(len(framed_buffer.getvalue()), len(data) + 8) - - # Recreate framed_buffer so we can read from it. - framed_buffer = TTransport.TMemoryBuffer(framed_buffer.getvalue()) - framed_reader = TTransport.TFramedTransport(framed_buffer) - prot = protocol_factory.getProtocol(framed_reader) - self.assertEqual(prot.readI32(), 42) - self.assertEqual(prot.readString(), bigstring) - self.assertEqual(prot.readI16(), 24) class SerializersTest(unittest.TestCase): - def testSerializeThenDeserialize(self): - obj = Xtruct2(i32_thing=1, - struct_thing=Xtruct(string_thing="foo")) + def testSerializeThenDeserialize(self): + obj = Xtruct2(i32_thing=1, + struct_thing=Xtruct(string_thing="foo")) - s1 = serialize(obj) - for i in range(10): - self.assertEquals(s1, serialize(obj)) - objcopy = Xtruct2() - deserialize(objcopy, serialize(obj)) - self.assertEquals(obj, objcopy) + s1 = serialize(obj) + for i in range(10): + self.assertEquals(s1, serialize(obj)) + objcopy = Xtruct2() + deserialize(objcopy, serialize(obj)) + self.assertEquals(obj, objcopy) - obj = Xtruct(string_thing="bar") - objcopy = Xtruct() - deserialize(objcopy, serialize(obj)) - self.assertEquals(obj, objcopy) + obj = Xtruct(string_thing="bar") + objcopy = Xtruct() + deserialize(objcopy, serialize(obj)) + self.assertEquals(obj, objcopy) - # test booleans - obj = Bools(im_true=True, im_false=False) - objcopy = Bools() - deserialize(objcopy, serialize(obj)) - self.assertEquals(obj, objcopy) + # test booleans + obj = Bools(im_true=True, im_false=False) + objcopy = Bools() + deserialize(objcopy, serialize(obj)) + self.assertEquals(obj, objcopy) - # test enums - for num, name in Numberz._VALUES_TO_NAMES.items(): - obj = Bonk(message='enum Numberz value %d is string %s' % (num, name), type=num) - objcopy = Bonk() - deserialize(objcopy, serialize(obj)) - self.assertEquals(obj, objcopy) + # test enums + for num, name in Numberz._VALUES_TO_NAMES.items(): + obj = Bonk(message='enum Numberz value %d is string %s' % (num, name), type=num) + objcopy = Bonk() + deserialize(objcopy, serialize(obj)) + self.assertEquals(obj, objcopy) def suite(): - suite = unittest.TestSuite() - loader = unittest.TestLoader() + suite = unittest.TestSuite() + loader = unittest.TestLoader() - suite.addTest(loader.loadTestsFromTestCase(NormalBinaryTest)) - suite.addTest(loader.loadTestsFromTestCase(AcceleratedBinaryTest)) - suite.addTest(loader.loadTestsFromTestCase(CompactProtocolTest)) - suite.addTest(loader.loadTestsFromTestCase(JSONProtocolTest)) - suite.addTest(loader.loadTestsFromTestCase(AcceleratedFramedTest)) - suite.addTest(loader.loadTestsFromTestCase(SerializersTest)) - return suite + suite.addTest(loader.loadTestsFromTestCase(NormalBinaryTest)) + suite.addTest(loader.loadTestsFromTestCase(AcceleratedBinaryTest)) + suite.addTest(loader.loadTestsFromTestCase(CompactProtocolTest)) + suite.addTest(loader.loadTestsFromTestCase(JSONProtocolTest)) + suite.addTest(loader.loadTestsFromTestCase(AcceleratedFramedTest)) + suite.addTest(loader.loadTestsFromTestCase(SerializersTest)) + return suite if __name__ == "__main__": - unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2)) + unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2)) diff --git a/test/py/TSimpleJSONProtocolTest.py b/test/py/TSimpleJSONProtocolTest.py index 1ed8c1574..72987602b 100644 --- a/test/py/TSimpleJSONProtocolTest.py +++ b/test/py/TSimpleJSONProtocolTest.py @@ -28,81 +28,81 @@ import unittest class SimpleJSONProtocolTest(unittest.TestCase): - protocol_factory = TJSONProtocol.TSimpleJSONProtocolFactory() - - def _assertDictEqual(self, a, b, msg=None): - if hasattr(self, 'assertDictEqual'): - # assertDictEqual only in Python 2.7. Depends on your machine. - self.assertDictEqual(a, b, msg) - return - - # Substitute implementation not as good as unittest library's - self.assertEquals(len(a), len(b), msg) - for k, v in a.iteritems(): - self.assertTrue(k in b, msg) - self.assertEquals(b.get(k), v, msg) - - def _serialize(self, obj): - trans = TTransport.TMemoryBuffer() - prot = self.protocol_factory.getProtocol(trans) - obj.write(prot) - return trans.getvalue() - - def _deserialize(self, objtype, data): - prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data)) - ret = objtype() - ret.read(prot) - return ret - - def testWriteOnly(self): - self.assertRaises(NotImplementedError, - self._deserialize, VersioningTestV1, b'{}') - - def testSimpleMessage(self): - v1obj = VersioningTestV1( - begin_in_both=12345, - old_string='aaa', - end_in_both=54321) - expected = dict(begin_in_both=v1obj.begin_in_both, - old_string=v1obj.old_string, - end_in_both=v1obj.end_in_both) - actual = json.loads(self._serialize(v1obj).decode('ascii')) - - self._assertDictEqual(expected, actual) - - def testComplicated(self): - v2obj = VersioningTestV2( - begin_in_both=12345, - newint=1, - newbyte=2, - newshort=3, - newlong=4, - newdouble=5.0, - newstruct=Bonk(message="Hello!", type=123), - newlist=[7, 8, 9], - newset=set([42, 1, 8]), - newmap={1: 2, 2: 3}, - newstring="Hola!", - end_in_both=54321) - expected = dict(begin_in_both=v2obj.begin_in_both, - newint=v2obj.newint, - newbyte=v2obj.newbyte, - newshort=v2obj.newshort, - newlong=v2obj.newlong, - newdouble=v2obj.newdouble, - newstruct=dict(message=v2obj.newstruct.message, - type=v2obj.newstruct.type), - newlist=v2obj.newlist, - newset=list(v2obj.newset), - newmap=v2obj.newmap, - newstring=v2obj.newstring, - end_in_both=v2obj.end_in_both) - - # Need to load/dump because map keys get escaped. - expected = json.loads(json.dumps(expected)) - actual = json.loads(self._serialize(v2obj).decode('ascii')) - self._assertDictEqual(expected, actual) + protocol_factory = TJSONProtocol.TSimpleJSONProtocolFactory() + + def _assertDictEqual(self, a, b, msg=None): + if hasattr(self, 'assertDictEqual'): + # assertDictEqual only in Python 2.7. Depends on your machine. + self.assertDictEqual(a, b, msg) + return + + # Substitute implementation not as good as unittest library's + self.assertEquals(len(a), len(b), msg) + for k, v in a.iteritems(): + self.assertTrue(k in b, msg) + self.assertEquals(b.get(k), v, msg) + + def _serialize(self, obj): + trans = TTransport.TMemoryBuffer() + prot = self.protocol_factory.getProtocol(trans) + obj.write(prot) + return trans.getvalue() + + def _deserialize(self, objtype, data): + prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data)) + ret = objtype() + ret.read(prot) + return ret + + def testWriteOnly(self): + self.assertRaises(NotImplementedError, + self._deserialize, VersioningTestV1, b'{}') + + def testSimpleMessage(self): + v1obj = VersioningTestV1( + begin_in_both=12345, + old_string='aaa', + end_in_both=54321) + expected = dict(begin_in_both=v1obj.begin_in_both, + old_string=v1obj.old_string, + end_in_both=v1obj.end_in_both) + actual = json.loads(self._serialize(v1obj).decode('ascii')) + + self._assertDictEqual(expected, actual) + + def testComplicated(self): + v2obj = VersioningTestV2( + begin_in_both=12345, + newint=1, + newbyte=2, + newshort=3, + newlong=4, + newdouble=5.0, + newstruct=Bonk(message="Hello!", type=123), + newlist=[7, 8, 9], + newset=set([42, 1, 8]), + newmap={1: 2, 2: 3}, + newstring="Hola!", + end_in_both=54321) + expected = dict(begin_in_both=v2obj.begin_in_both, + newint=v2obj.newint, + newbyte=v2obj.newbyte, + newshort=v2obj.newshort, + newlong=v2obj.newlong, + newdouble=v2obj.newdouble, + newstruct=dict(message=v2obj.newstruct.message, + type=v2obj.newstruct.type), + newlist=v2obj.newlist, + newset=list(v2obj.newset), + newmap=v2obj.newmap, + newstring=v2obj.newstring, + end_in_both=v2obj.end_in_both) + + # Need to load/dump because map keys get escaped. + expected = json.loads(json.dumps(expected)) + actual = json.loads(self._serialize(v2obj).decode('ascii')) + self._assertDictEqual(expected, actual) if __name__ == '__main__': - unittest.main() + unittest.main() diff --git a/test/py/TestClient.py b/test/py/TestClient.py index 347329e08..bc7650dcc 100755 --- a/test/py/TestClient.py +++ b/test/py/TestClient.py @@ -32,42 +32,42 @@ DEFAULT_LIBDIR_GLOB = os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*') class AbstractTest(unittest.TestCase): - def setUp(self): - if options.http_path: - self.transport = THttpClient.THttpClient(options.host, port=options.port, path=options.http_path) - else: - if options.ssl: - from thrift.transport import TSSLSocket - socket = TSSLSocket.TSSLSocket(options.host, options.port, validate=False) - else: - socket = TSocket.TSocket(options.host, options.port) - # frame or buffer depending upon args - self.transport = TTransport.TBufferedTransport(socket) - if options.trans == 'framed': - self.transport = TTransport.TFramedTransport(socket) - elif options.trans == 'buffered': - self.transport = TTransport.TBufferedTransport(socket) - elif options.trans == '': - raise AssertionError('Unknown --transport option: %s' % options.trans) - if options.zlib: - self.transport = TZlibTransport.TZlibTransport(self.transport, 9) - self.transport.open() - protocol = self.get_protocol(self.transport) - self.client = ThriftTest.Client(protocol) - - def tearDown(self): - self.transport.close() - - def testVoid(self): - print('testVoid') - self.client.testVoid() - - def testString(self): - print('testString') - self.assertEqual(self.client.testString('Python' * 20), 'Python' * 20) - self.assertEqual(self.client.testString(''), '') - s1 = u'\b\t\n/\\\\\r{}:パイソン"' - s2 = u"""Afrikaans, Alemannisch, Aragonés, العربية, مصرى, + def setUp(self): + if options.http_path: + self.transport = THttpClient.THttpClient(options.host, port=options.port, path=options.http_path) + else: + if options.ssl: + from thrift.transport import TSSLSocket + socket = TSSLSocket.TSSLSocket(options.host, options.port, validate=False) + else: + socket = TSocket.TSocket(options.host, options.port) + # frame or buffer depending upon args + self.transport = TTransport.TBufferedTransport(socket) + if options.trans == 'framed': + self.transport = TTransport.TFramedTransport(socket) + elif options.trans == 'buffered': + self.transport = TTransport.TBufferedTransport(socket) + elif options.trans == '': + raise AssertionError('Unknown --transport option: %s' % options.trans) + if options.zlib: + self.transport = TZlibTransport.TZlibTransport(self.transport, 9) + self.transport.open() + protocol = self.get_protocol(self.transport) + self.client = ThriftTest.Client(protocol) + + def tearDown(self): + self.transport.close() + + def testVoid(self): + print('testVoid') + self.client.testVoid() + + def testString(self): + print('testString') + self.assertEqual(self.client.testString('Python' * 20), 'Python' * 20) + self.assertEqual(self.client.testString(''), '') + s1 = u'\b\t\n/\\\\\r{}:パイソン"' + s2 = u"""Afrikaans, Alemannisch, Aragonés, العربية, مصرى, Asturianu, Aymar aru, Azərbaycan, Башҡорт, Boarisch, Žemaitėška, Беларуская, Беларуская (тарашкевіца), Български, Bamanankan, বাংলা, Brezhoneg, Bosanski, Català, Mìng-dĕ̤ng-ngṳ̄, Нохчийн, @@ -92,199 +92,199 @@ class AbstractTest(unittest.TestCase): Türkçe, Татарча/Tatarça, Українська, اردو, Tiếng Việt, Volapük, Walon, Winaray, 吴语, isiXhosa, ייִדיש, Yorùbá, Zeêuws, 中文, Bân-lâm-gú, 粵語""" - if sys.version_info[0] == 2 and os.environ.get('THRIFT_TEST_PY_NO_UTF8STRINGS'): - s1 = s1.encode('utf8') - s2 = s2.encode('utf8') - self.assertEqual(self.client.testString(s1), s1) - self.assertEqual(self.client.testString(s2), s2) - - def testBool(self): - print('testBool') - self.assertEqual(self.client.testBool(True), True) - self.assertEqual(self.client.testBool(False), False) - - def testByte(self): - print('testByte') - self.assertEqual(self.client.testByte(63), 63) - self.assertEqual(self.client.testByte(-127), -127) - - def testI32(self): - print('testI32') - self.assertEqual(self.client.testI32(-1), -1) - self.assertEqual(self.client.testI32(0), 0) - - def testI64(self): - print('testI64') - self.assertEqual(self.client.testI64(1), 1) - self.assertEqual(self.client.testI64(-34359738368), -34359738368) - - def testDouble(self): - print('testDouble') - self.assertEqual(self.client.testDouble(-5.235098235), -5.235098235) - self.assertEqual(self.client.testDouble(0), 0) - self.assertEqual(self.client.testDouble(-1), -1) - self.assertEqual(self.client.testDouble(-0.000341012439638598279), -0.000341012439638598279) - - def testBinary(self): - print('testBinary') - val = bytearray([i for i in range(0, 256)]) - self.assertEqual(bytearray(self.client.testBinary(bytes(val))), val) - - def testStruct(self): - print('testStruct') - x = Xtruct() - x.string_thing = "Zero" - x.byte_thing = 1 - x.i32_thing = -3 - x.i64_thing = -5 - y = self.client.testStruct(x) - self.assertEqual(y, x) - - def testNest(self): - print('testNest') - inner = Xtruct(string_thing="Zero", byte_thing=1, i32_thing=-3, i64_thing=-5) - x = Xtruct2(struct_thing=inner, byte_thing=0, i32_thing=0) - y = self.client.testNest(x) - self.assertEqual(y, x) - - def testMap(self): - print('testMap') - x = {0: 1, 1: 2, 2: 3, 3: 4, -1: -2} - y = self.client.testMap(x) - self.assertEqual(y, x) - - def testSet(self): - print('testSet') - x = set([8, 1, 42]) - y = self.client.testSet(x) - self.assertEqual(y, x) - - def testList(self): - print('testList') - x = [1, 4, 9, -42] - y = self.client.testList(x) - self.assertEqual(y, x) - - def testEnum(self): - print('testEnum') - x = Numberz.FIVE - y = self.client.testEnum(x) - self.assertEqual(y, x) - - def testTypedef(self): - print('testTypedef') - x = 0xffffffffffffff # 7 bytes of 0xff - y = self.client.testTypedef(x) - self.assertEqual(y, x) - - def testMapMap(self): - print('testMapMap') - x = { - -4: {-4: -4, -3: -3, -2: -2, -1: -1}, - 4: {4: 4, 3: 3, 2: 2, 1: 1}, - } - y = self.client.testMapMap(42) - self.assertEqual(y, x) - - def testMulti(self): - print('testMulti') - xpected = Xtruct(string_thing='Hello2', byte_thing=74, i32_thing=0xff00ff, i64_thing=0xffffffffd0d0) - y = self.client.testMulti(xpected.byte_thing, - xpected.i32_thing, - xpected.i64_thing, - {0: 'abc'}, - Numberz.FIVE, - 0xf0f0f0) - self.assertEqual(y, xpected) - - def testException(self): - print('testException') - self.client.testException('Safe') - try: - self.client.testException('Xception') - self.fail("should have gotten exception") - except Xception as x: - self.assertEqual(x.errorCode, 1001) - self.assertEqual(x.message, 'Xception') - # TODO ensure same behavior for repr within generated python variants - # ensure exception's repr method works - # x_repr = repr(x) - # self.assertEqual(x_repr, 'Xception(errorCode=1001, message=\'Xception\')') - - try: - self.client.testException('TException') - self.fail("should have gotten exception") - except TException as x: - pass - - # Should not throw - self.client.testException('success') - - def testMultiException(self): - print('testMultiException') - try: - self.client.testMultiException('Xception', 'ignore') - except Xception as ex: - self.assertEqual(ex.errorCode, 1001) - self.assertEqual(ex.message, 'This is an Xception') - - try: - self.client.testMultiException('Xception2', 'ignore') - except Xception2 as ex: - self.assertEqual(ex.errorCode, 2002) - self.assertEqual(ex.struct_thing.string_thing, 'This is an Xception2') - - y = self.client.testMultiException('success', 'foobar') - self.assertEqual(y.string_thing, 'foobar') - - def testOneway(self): - print('testOneway') - start = time.time() - self.client.testOneway(1) # type is int, not float - end = time.time() - self.assertTrue(end - start < 3, - "oneway sleep took %f sec" % (end - start)) - - def testOnewayThenNormal(self): - print('testOnewayThenNormal') - self.client.testOneway(1) # type is int, not float - self.assertEqual(self.client.testString('Python'), 'Python') + if sys.version_info[0] == 2 and os.environ.get('THRIFT_TEST_PY_NO_UTF8STRINGS'): + s1 = s1.encode('utf8') + s2 = s2.encode('utf8') + self.assertEqual(self.client.testString(s1), s1) + self.assertEqual(self.client.testString(s2), s2) + + def testBool(self): + print('testBool') + self.assertEqual(self.client.testBool(True), True) + self.assertEqual(self.client.testBool(False), False) + + def testByte(self): + print('testByte') + self.assertEqual(self.client.testByte(63), 63) + self.assertEqual(self.client.testByte(-127), -127) + + def testI32(self): + print('testI32') + self.assertEqual(self.client.testI32(-1), -1) + self.assertEqual(self.client.testI32(0), 0) + + def testI64(self): + print('testI64') + self.assertEqual(self.client.testI64(1), 1) + self.assertEqual(self.client.testI64(-34359738368), -34359738368) + + def testDouble(self): + print('testDouble') + self.assertEqual(self.client.testDouble(-5.235098235), -5.235098235) + self.assertEqual(self.client.testDouble(0), 0) + self.assertEqual(self.client.testDouble(-1), -1) + self.assertEqual(self.client.testDouble(-0.000341012439638598279), -0.000341012439638598279) + + def testBinary(self): + print('testBinary') + val = bytearray([i for i in range(0, 256)]) + self.assertEqual(bytearray(self.client.testBinary(bytes(val))), val) + + def testStruct(self): + print('testStruct') + x = Xtruct() + x.string_thing = "Zero" + x.byte_thing = 1 + x.i32_thing = -3 + x.i64_thing = -5 + y = self.client.testStruct(x) + self.assertEqual(y, x) + + def testNest(self): + print('testNest') + inner = Xtruct(string_thing="Zero", byte_thing=1, i32_thing=-3, i64_thing=-5) + x = Xtruct2(struct_thing=inner, byte_thing=0, i32_thing=0) + y = self.client.testNest(x) + self.assertEqual(y, x) + + def testMap(self): + print('testMap') + x = {0: 1, 1: 2, 2: 3, 3: 4, -1: -2} + y = self.client.testMap(x) + self.assertEqual(y, x) + + def testSet(self): + print('testSet') + x = set([8, 1, 42]) + y = self.client.testSet(x) + self.assertEqual(y, x) + + def testList(self): + print('testList') + x = [1, 4, 9, -42] + y = self.client.testList(x) + self.assertEqual(y, x) + + def testEnum(self): + print('testEnum') + x = Numberz.FIVE + y = self.client.testEnum(x) + self.assertEqual(y, x) + + def testTypedef(self): + print('testTypedef') + x = 0xffffffffffffff # 7 bytes of 0xff + y = self.client.testTypedef(x) + self.assertEqual(y, x) + + def testMapMap(self): + print('testMapMap') + x = { + -4: {-4: -4, -3: -3, -2: -2, -1: -1}, + 4: {4: 4, 3: 3, 2: 2, 1: 1}, + } + y = self.client.testMapMap(42) + self.assertEqual(y, x) + + def testMulti(self): + print('testMulti') + xpected = Xtruct(string_thing='Hello2', byte_thing=74, i32_thing=0xff00ff, i64_thing=0xffffffffd0d0) + y = self.client.testMulti(xpected.byte_thing, + xpected.i32_thing, + xpected.i64_thing, + {0: 'abc'}, + Numberz.FIVE, + 0xf0f0f0) + self.assertEqual(y, xpected) + + def testException(self): + print('testException') + self.client.testException('Safe') + try: + self.client.testException('Xception') + self.fail("should have gotten exception") + except Xception as x: + self.assertEqual(x.errorCode, 1001) + self.assertEqual(x.message, 'Xception') + # TODO ensure same behavior for repr within generated python variants + # ensure exception's repr method works + # x_repr = repr(x) + # self.assertEqual(x_repr, 'Xception(errorCode=1001, message=\'Xception\')') + + try: + self.client.testException('TException') + self.fail("should have gotten exception") + except TException as x: + pass + + # Should not throw + self.client.testException('success') + + def testMultiException(self): + print('testMultiException') + try: + self.client.testMultiException('Xception', 'ignore') + except Xception as ex: + self.assertEqual(ex.errorCode, 1001) + self.assertEqual(ex.message, 'This is an Xception') + + try: + self.client.testMultiException('Xception2', 'ignore') + except Xception2 as ex: + self.assertEqual(ex.errorCode, 2002) + self.assertEqual(ex.struct_thing.string_thing, 'This is an Xception2') + + y = self.client.testMultiException('success', 'foobar') + self.assertEqual(y.string_thing, 'foobar') + + def testOneway(self): + print('testOneway') + start = time.time() + self.client.testOneway(1) # type is int, not float + end = time.time() + self.assertTrue(end - start < 3, + "oneway sleep took %f sec" % (end - start)) + + def testOnewayThenNormal(self): + print('testOnewayThenNormal') + self.client.testOneway(1) # type is int, not float + self.assertEqual(self.client.testString('Python'), 'Python') class NormalBinaryTest(AbstractTest): - def get_protocol(self, transport): - return TBinaryProtocol.TBinaryProtocolFactory().getProtocol(transport) + def get_protocol(self, transport): + return TBinaryProtocol.TBinaryProtocolFactory().getProtocol(transport) class CompactTest(AbstractTest): - def get_protocol(self, transport): - return TCompactProtocol.TCompactProtocolFactory().getProtocol(transport) + def get_protocol(self, transport): + return TCompactProtocol.TCompactProtocolFactory().getProtocol(transport) class JSONTest(AbstractTest): - def get_protocol(self, transport): - return TJSONProtocol.TJSONProtocolFactory().getProtocol(transport) + def get_protocol(self, transport): + return TJSONProtocol.TJSONProtocolFactory().getProtocol(transport) class AcceleratedBinaryTest(AbstractTest): - def get_protocol(self, transport): - return TBinaryProtocol.TBinaryProtocolAcceleratedFactory().getProtocol(transport) + def get_protocol(self, transport): + return TBinaryProtocol.TBinaryProtocolAcceleratedFactory().getProtocol(transport) def suite(): - suite = unittest.TestSuite() - loader = unittest.TestLoader() - if options.proto == 'binary': # look for --proto on cmdline - suite.addTest(loader.loadTestsFromTestCase(NormalBinaryTest)) - elif options.proto == 'accel': - suite.addTest(loader.loadTestsFromTestCase(AcceleratedBinaryTest)) - elif options.proto == 'compact': - suite.addTest(loader.loadTestsFromTestCase(CompactTest)) - elif options.proto == 'json': - suite.addTest(loader.loadTestsFromTestCase(JSONTest)) - else: - raise AssertionError('Unknown protocol given with --protocol: %s' % options.proto) - return suite + suite = unittest.TestSuite() + loader = unittest.TestLoader() + if options.proto == 'binary': # look for --proto on cmdline + suite.addTest(loader.loadTestsFromTestCase(NormalBinaryTest)) + elif options.proto == 'accel': + suite.addTest(loader.loadTestsFromTestCase(AcceleratedBinaryTest)) + elif options.proto == 'compact': + suite.addTest(loader.loadTestsFromTestCase(CompactTest)) + elif options.proto == 'json': + suite.addTest(loader.loadTestsFromTestCase(JSONTest)) + else: + raise AssertionError('Unknown protocol given with --protocol: %s' % options.proto) + return suite class OwnArgsTestProgram(unittest.TestProgram): @@ -296,50 +296,50 @@ class OwnArgsTestProgram(unittest.TestProgram): self.createTests() if __name__ == "__main__": - parser = OptionParser() - parser.add_option('--libpydir', type='string', dest='libpydir', - help='include this directory in sys.path for locating library code') - parser.add_option('--genpydir', type='string', dest='genpydir', - help='include this directory in sys.path for locating generated code') - parser.add_option("--port", type="int", dest="port", - help="connect to server at port") - parser.add_option("--host", type="string", dest="host", - help="connect to server") - parser.add_option("--zlib", action="store_true", dest="zlib", - help="use zlib wrapper for compressed transport") - parser.add_option("--ssl", action="store_true", dest="ssl", - help="use SSL for encrypted transport") - parser.add_option("--http", dest="http_path", - help="Use the HTTP transport with the specified path") - parser.add_option('-v', '--verbose', action="store_const", - dest="verbose", const=2, - help="verbose output") - parser.add_option('-q', '--quiet', action="store_const", - dest="verbose", const=0, - help="minimal output") - parser.add_option('--protocol', dest="proto", type="string", - help="protocol to use, one of: accel, binary, compact, json") - parser.add_option('--transport', dest="trans", type="string", - help="transport to use, one of: buffered, framed") - parser.set_defaults(framed=False, http_path=None, verbose=1, host='localhost', port=9090, proto='binary') - options, args = parser.parse_args() - - if options.genpydir: - sys.path.insert(0, os.path.join(SCRIPT_DIR, options.genpydir)) - if options.libpydir: - sys.path.insert(0, glob.glob(options.libpydir)[0]) - else: - sys.path.insert(0, glob.glob(DEFAULT_LIBDIR_GLOB)[0]) - - from ThriftTest import ThriftTest - from ThriftTest.ttypes import Xtruct, Xtruct2, Numberz, Xception, Xception2 - from thrift.Thrift import TException - from thrift.transport import TTransport - from thrift.transport import TSocket - from thrift.transport import THttpClient - from thrift.transport import TZlibTransport - from thrift.protocol import TBinaryProtocol - from thrift.protocol import TCompactProtocol - from thrift.protocol import TJSONProtocol - - OwnArgsTestProgram(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=1)) + parser = OptionParser() + parser.add_option('--libpydir', type='string', dest='libpydir', + help='include this directory in sys.path for locating library code') + parser.add_option('--genpydir', type='string', dest='genpydir', + help='include this directory in sys.path for locating generated code') + parser.add_option("--port", type="int", dest="port", + help="connect to server at port") + parser.add_option("--host", type="string", dest="host", + help="connect to server") + parser.add_option("--zlib", action="store_true", dest="zlib", + help="use zlib wrapper for compressed transport") + parser.add_option("--ssl", action="store_true", dest="ssl", + help="use SSL for encrypted transport") + parser.add_option("--http", dest="http_path", + help="Use the HTTP transport with the specified path") + parser.add_option('-v', '--verbose', action="store_const", + dest="verbose", const=2, + help="verbose output") + parser.add_option('-q', '--quiet', action="store_const", + dest="verbose", const=0, + help="minimal output") + parser.add_option('--protocol', dest="proto", type="string", + help="protocol to use, one of: accel, binary, compact, json") + parser.add_option('--transport', dest="trans", type="string", + help="transport to use, one of: buffered, framed") + parser.set_defaults(framed=False, http_path=None, verbose=1, host='localhost', port=9090, proto='binary') + options, args = parser.parse_args() + + if options.genpydir: + sys.path.insert(0, os.path.join(SCRIPT_DIR, options.genpydir)) + if options.libpydir: + sys.path.insert(0, glob.glob(options.libpydir)[0]) + else: + sys.path.insert(0, glob.glob(DEFAULT_LIBDIR_GLOB)[0]) + + from ThriftTest import ThriftTest + from ThriftTest.ttypes import Xtruct, Xtruct2, Numberz, Xception, Xception2 + from thrift.Thrift import TException + from thrift.transport import TTransport + from thrift.transport import TSocket + from thrift.transport import THttpClient + from thrift.transport import TZlibTransport + from thrift.protocol import TBinaryProtocol + from thrift.protocol import TCompactProtocol + from thrift.protocol import TJSONProtocol + + OwnArgsTestProgram(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=1)) diff --git a/test/py/TestEof.py b/test/py/TestEof.py index 661463822..0239fc621 100755 --- a/test/py/TestEof.py +++ b/test/py/TestEof.py @@ -28,99 +28,99 @@ import unittest class TestEof(unittest.TestCase): - def make_data(self, pfactory=None): - trans = TTransport.TMemoryBuffer() - if pfactory: - prot = pfactory.getProtocol(trans) - else: - prot = TBinaryProtocol.TBinaryProtocol(trans) - - x = Xtruct() - x.string_thing = "Zero" - x.byte_thing = 0 - - x.write(prot) - - x = Xtruct() - x.string_thing = "One" - x.byte_thing = 1 - - x.write(prot) - - return trans.getvalue() - - def testTransportReadAll(self): - """Test that readAll on any type of transport throws an EOFError""" - trans = TTransport.TMemoryBuffer(self.make_data()) - trans.readAll(1) - - try: - trans.readAll(10000) - except EOFError: - return - - self.fail("Should have gotten EOFError") - - def eofTestHelper(self, pfactory): - trans = TTransport.TMemoryBuffer(self.make_data(pfactory)) - prot = pfactory.getProtocol(trans) - - x = Xtruct() - x.read(prot) - self.assertEqual(x.string_thing, "Zero") - self.assertEqual(x.byte_thing, 0) - - x = Xtruct() - x.read(prot) - self.assertEqual(x.string_thing, "One") - self.assertEqual(x.byte_thing, 1) - - try: - x = Xtruct() - x.read(prot) - except EOFError: - return - - self.fail("Should have gotten EOFError") - - def eofTestHelperStress(self, pfactory): - """Teest the ability of TBinaryProtocol to deal with the removal of every byte in the file""" - # TODO: we should make sure this covers more of the code paths - - data = self.make_data(pfactory) - for i in range(0, len(data) + 1): - trans = TTransport.TMemoryBuffer(data[0:i]) - prot = pfactory.getProtocol(trans) - try: + def make_data(self, pfactory=None): + trans = TTransport.TMemoryBuffer() + if pfactory: + prot = pfactory.getProtocol(trans) + else: + prot = TBinaryProtocol.TBinaryProtocol(trans) + x = Xtruct() - x.read(prot) - x.read(prot) - x.read(prot) - except EOFError: - continue - self.fail("Should have gotten an EOFError") + x.string_thing = "Zero" + x.byte_thing = 0 + + x.write(prot) + + x = Xtruct() + x.string_thing = "One" + x.byte_thing = 1 + + x.write(prot) - def testBinaryProtocolEof(self): - """Test that TBinaryProtocol throws an EOFError when it reaches the end of the stream""" - self.eofTestHelper(TBinaryProtocol.TBinaryProtocolFactory()) - self.eofTestHelperStress(TBinaryProtocol.TBinaryProtocolFactory()) + return trans.getvalue() - def testBinaryProtocolAcceleratedEof(self): - """Test that TBinaryProtocolAccelerated throws an EOFError when it reaches the end of the stream""" - self.eofTestHelper(TBinaryProtocol.TBinaryProtocolAcceleratedFactory()) - self.eofTestHelperStress(TBinaryProtocol.TBinaryProtocolAcceleratedFactory()) + def testTransportReadAll(self): + """Test that readAll on any type of transport throws an EOFError""" + trans = TTransport.TMemoryBuffer(self.make_data()) + trans.readAll(1) - def testCompactProtocolEof(self): - """Test that TCompactProtocol throws an EOFError when it reaches the end of the stream""" - self.eofTestHelper(TCompactProtocol.TCompactProtocolFactory()) - self.eofTestHelperStress(TCompactProtocol.TCompactProtocolFactory()) + try: + trans.readAll(10000) + except EOFError: + return + + self.fail("Should have gotten EOFError") + + def eofTestHelper(self, pfactory): + trans = TTransport.TMemoryBuffer(self.make_data(pfactory)) + prot = pfactory.getProtocol(trans) + + x = Xtruct() + x.read(prot) + self.assertEqual(x.string_thing, "Zero") + self.assertEqual(x.byte_thing, 0) + + x = Xtruct() + x.read(prot) + self.assertEqual(x.string_thing, "One") + self.assertEqual(x.byte_thing, 1) + + try: + x = Xtruct() + x.read(prot) + except EOFError: + return + + self.fail("Should have gotten EOFError") + + def eofTestHelperStress(self, pfactory): + """Teest the ability of TBinaryProtocol to deal with the removal of every byte in the file""" + # TODO: we should make sure this covers more of the code paths + + data = self.make_data(pfactory) + for i in range(0, len(data) + 1): + trans = TTransport.TMemoryBuffer(data[0:i]) + prot = pfactory.getProtocol(trans) + try: + x = Xtruct() + x.read(prot) + x.read(prot) + x.read(prot) + except EOFError: + continue + self.fail("Should have gotten an EOFError") + + def testBinaryProtocolEof(self): + """Test that TBinaryProtocol throws an EOFError when it reaches the end of the stream""" + self.eofTestHelper(TBinaryProtocol.TBinaryProtocolFactory()) + self.eofTestHelperStress(TBinaryProtocol.TBinaryProtocolFactory()) + + def testBinaryProtocolAcceleratedEof(self): + """Test that TBinaryProtocolAccelerated throws an EOFError when it reaches the end of the stream""" + self.eofTestHelper(TBinaryProtocol.TBinaryProtocolAcceleratedFactory()) + self.eofTestHelperStress(TBinaryProtocol.TBinaryProtocolAcceleratedFactory()) + + def testCompactProtocolEof(self): + """Test that TCompactProtocol throws an EOFError when it reaches the end of the stream""" + self.eofTestHelper(TCompactProtocol.TCompactProtocolFactory()) + self.eofTestHelperStress(TCompactProtocol.TCompactProtocolFactory()) def suite(): - suite = unittest.TestSuite() - loader = unittest.TestLoader() - suite.addTest(loader.loadTestsFromTestCase(TestEof)) - return suite + suite = unittest.TestSuite() + loader = unittest.TestLoader() + suite.addTest(loader.loadTestsFromTestCase(TestEof)) + return suite if __name__ == "__main__": - unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2)) + unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2)) diff --git a/test/py/TestFrozen.py b/test/py/TestFrozen.py index 76750ad88..30a6a557f 100755 --- a/test/py/TestFrozen.py +++ b/test/py/TestFrozen.py @@ -28,89 +28,89 @@ import unittest class TestFrozenBase(unittest.TestCase): - def _roundtrip(self, src, dst): - otrans = TTransport.TMemoryBuffer() - optoro = self.protocol(otrans) - src.write(optoro) - itrans = TTransport.TMemoryBuffer(otrans.getvalue()) - iproto = self.protocol(itrans) - return dst.read(iproto) or dst - - def test_dict_is_hashable_only_after_frozen(self): - d0 = {} - self.assertFalse(isinstance(d0, collections.Hashable)) - d1 = TFrozenDict(d0) - self.assertTrue(isinstance(d1, collections.Hashable)) - - def test_struct_with_collection_fields(self): - pass - - def test_set(self): - """Test that annotated set field can be serialized and deserialized""" - x = CompactProtoTestStruct(set_byte_map={ - frozenset([42, 100, -100]): 99, - frozenset([0]): 100, - frozenset([]): 0, - }) - x2 = self._roundtrip(x, CompactProtoTestStruct()) - self.assertEqual(x2.set_byte_map[frozenset([42, 100, -100])], 99) - self.assertEqual(x2.set_byte_map[frozenset([0])], 100) - self.assertEqual(x2.set_byte_map[frozenset([])], 0) - - def test_map(self): - """Test that annotated map field can be serialized and deserialized""" - x = CompactProtoTestStruct(map_byte_map={ - TFrozenDict({42: 42, 100: -100}): 99, - TFrozenDict({0: 0}): 100, - TFrozenDict({}): 0, - }) - x2 = self._roundtrip(x, CompactProtoTestStruct()) - self.assertEqual(x2.map_byte_map[TFrozenDict({42: 42, 100: -100})], 99) - self.assertEqual(x2.map_byte_map[TFrozenDict({0: 0})], 100) - self.assertEqual(x2.map_byte_map[TFrozenDict({})], 0) - - def test_list(self): - """Test that annotated list field can be serialized and deserialized""" - x = CompactProtoTestStruct(list_byte_map={ - (42, 100, -100): 99, - (0,): 100, - (): 0, - }) - x2 = self._roundtrip(x, CompactProtoTestStruct()) - self.assertEqual(x2.list_byte_map[(42, 100, -100)], 99) - self.assertEqual(x2.list_byte_map[(0,)], 100) - self.assertEqual(x2.list_byte_map[()], 0) - - def test_empty_struct(self): - """Test that annotated empty struct can be serialized and deserialized""" - x = CompactProtoTestStruct(empty_struct_field=Empty()) - x2 = self._roundtrip(x, CompactProtoTestStruct()) - self.assertEqual(x2.empty_struct_field, Empty()) - - def test_struct(self): - """Test that annotated struct can be serialized and deserialized""" - x = Wrapper(foo=Empty()) - self.assertEqual(x.foo, Empty()) - x2 = self._roundtrip(x, Wrapper) - self.assertEqual(x2.foo, Empty()) + def _roundtrip(self, src, dst): + otrans = TTransport.TMemoryBuffer() + optoro = self.protocol(otrans) + src.write(optoro) + itrans = TTransport.TMemoryBuffer(otrans.getvalue()) + iproto = self.protocol(itrans) + return dst.read(iproto) or dst + + def test_dict_is_hashable_only_after_frozen(self): + d0 = {} + self.assertFalse(isinstance(d0, collections.Hashable)) + d1 = TFrozenDict(d0) + self.assertTrue(isinstance(d1, collections.Hashable)) + + def test_struct_with_collection_fields(self): + pass + + def test_set(self): + """Test that annotated set field can be serialized and deserialized""" + x = CompactProtoTestStruct(set_byte_map={ + frozenset([42, 100, -100]): 99, + frozenset([0]): 100, + frozenset([]): 0, + }) + x2 = self._roundtrip(x, CompactProtoTestStruct()) + self.assertEqual(x2.set_byte_map[frozenset([42, 100, -100])], 99) + self.assertEqual(x2.set_byte_map[frozenset([0])], 100) + self.assertEqual(x2.set_byte_map[frozenset([])], 0) + + def test_map(self): + """Test that annotated map field can be serialized and deserialized""" + x = CompactProtoTestStruct(map_byte_map={ + TFrozenDict({42: 42, 100: -100}): 99, + TFrozenDict({0: 0}): 100, + TFrozenDict({}): 0, + }) + x2 = self._roundtrip(x, CompactProtoTestStruct()) + self.assertEqual(x2.map_byte_map[TFrozenDict({42: 42, 100: -100})], 99) + self.assertEqual(x2.map_byte_map[TFrozenDict({0: 0})], 100) + self.assertEqual(x2.map_byte_map[TFrozenDict({})], 0) + + def test_list(self): + """Test that annotated list field can be serialized and deserialized""" + x = CompactProtoTestStruct(list_byte_map={ + (42, 100, -100): 99, + (0,): 100, + (): 0, + }) + x2 = self._roundtrip(x, CompactProtoTestStruct()) + self.assertEqual(x2.list_byte_map[(42, 100, -100)], 99) + self.assertEqual(x2.list_byte_map[(0,)], 100) + self.assertEqual(x2.list_byte_map[()], 0) + + def test_empty_struct(self): + """Test that annotated empty struct can be serialized and deserialized""" + x = CompactProtoTestStruct(empty_struct_field=Empty()) + x2 = self._roundtrip(x, CompactProtoTestStruct()) + self.assertEqual(x2.empty_struct_field, Empty()) + + def test_struct(self): + """Test that annotated struct can be serialized and deserialized""" + x = Wrapper(foo=Empty()) + self.assertEqual(x.foo, Empty()) + x2 = self._roundtrip(x, Wrapper) + self.assertEqual(x2.foo, Empty()) class TestFrozen(TestFrozenBase): - def protocol(self, trans): - return TBinaryProtocol.TBinaryProtocolFactory().getProtocol(trans) + def protocol(self, trans): + return TBinaryProtocol.TBinaryProtocolFactory().getProtocol(trans) class TestFrozenAccelerated(TestFrozenBase): - def protocol(self, trans): - return TBinaryProtocol.TBinaryProtocolAcceleratedFactory().getProtocol(trans) + def protocol(self, trans): + return TBinaryProtocol.TBinaryProtocolAcceleratedFactory().getProtocol(trans) def suite(): - suite = unittest.TestSuite() - loader = unittest.TestLoader() - suite.addTest(loader.loadTestsFromTestCase(TestFrozen)) - suite.addTest(loader.loadTestsFromTestCase(TestFrozenAccelerated)) - return suite + suite = unittest.TestSuite() + loader = unittest.TestLoader() + suite.addTest(loader.loadTestsFromTestCase(TestFrozen)) + suite.addTest(loader.loadTestsFromTestCase(TestFrozenAccelerated)) + return suite if __name__ == "__main__": - unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2)) + unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2)) diff --git a/test/py/TestServer.py b/test/py/TestServer.py index f12a9fe76..ef93509b2 100755 --- a/test/py/TestServer.py +++ b/test/py/TestServer.py @@ -32,287 +32,287 @@ DEFAULT_LIBDIR_GLOB = os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*') class TestHandler(object): - def testVoid(self): - if options.verbose > 1: - logging.info('testVoid()') - - def testString(self, str): - if options.verbose > 1: - logging.info('testString(%s)' % str) - return str - - def testBool(self, boolean): - if options.verbose > 1: - logging.info('testBool(%s)' % str(boolean).lower()) - return boolean - - def testByte(self, byte): - if options.verbose > 1: - logging.info('testByte(%d)' % byte) - return byte - - def testI16(self, i16): - if options.verbose > 1: - logging.info('testI16(%d)' % i16) - return i16 - - def testI32(self, i32): - if options.verbose > 1: - logging.info('testI32(%d)' % i32) - return i32 - - def testI64(self, i64): - if options.verbose > 1: - logging.info('testI64(%d)' % i64) - return i64 - - def testDouble(self, dub): - if options.verbose > 1: - logging.info('testDouble(%f)' % dub) - return dub - - def testBinary(self, thing): - if options.verbose > 1: - logging.info('testBinary()') # TODO: hex output - return thing - - def testStruct(self, thing): - if options.verbose > 1: - logging.info('testStruct({%s, %s, %s, %s})' % (thing.string_thing, thing.byte_thing, thing.i32_thing, thing.i64_thing)) - return thing - - def testException(self, arg): - # if options.verbose > 1: - logging.info('testException(%s)' % arg) - if arg == 'Xception': - raise Xception(errorCode=1001, message=arg) - elif arg == 'TException': - raise TException(message='This is a TException') - - def testMultiException(self, arg0, arg1): - if options.verbose > 1: - logging.info('testMultiException(%s, %s)' % (arg0, arg1)) - if arg0 == 'Xception': - raise Xception(errorCode=1001, message='This is an Xception') - elif arg0 == 'Xception2': - raise Xception2( - errorCode=2002, - struct_thing=Xtruct(string_thing='This is an Xception2')) - return Xtruct(string_thing=arg1) - - def testOneway(self, seconds): - if options.verbose > 1: - logging.info('testOneway(%d) => sleeping...' % seconds) - time.sleep(seconds / 3) # be quick - if options.verbose > 1: - logging.info('done sleeping') - - def testNest(self, thing): - if options.verbose > 1: - logging.info('testNest(%s)' % thing) - return thing - - def testMap(self, thing): - if options.verbose > 1: - logging.info('testMap(%s)' % thing) - return thing - - def testStringMap(self, thing): - if options.verbose > 1: - logging.info('testStringMap(%s)' % thing) - return thing - - def testSet(self, thing): - if options.verbose > 1: - logging.info('testSet(%s)' % thing) - return thing - - def testList(self, thing): - if options.verbose > 1: - logging.info('testList(%s)' % thing) - return thing - - def testEnum(self, thing): - if options.verbose > 1: - logging.info('testEnum(%s)' % thing) - return thing - - def testTypedef(self, thing): - if options.verbose > 1: - logging.info('testTypedef(%s)' % thing) - return thing - - def testMapMap(self, thing): - if options.verbose > 1: - logging.info('testMapMap(%s)' % thing) - return { - -4: { - -4: -4, - -3: -3, - -2: -2, - -1: -1, - }, - 4: { - 4: 4, - 3: 3, - 2: 2, - 1: 1, - }, - } - - def testInsanity(self, argument): - if options.verbose > 1: - logging.info('testInsanity(%s)' % argument) - return { - 1: { - 2: argument, - 3: argument, - }, - 2: {6: Insanity()}, - } - - def testMulti(self, arg0, arg1, arg2, arg3, arg4, arg5): - if options.verbose > 1: - logging.info('testMulti(%s)' % [arg0, arg1, arg2, arg3, arg4, arg5]) - return Xtruct(string_thing='Hello2', - byte_thing=arg0, i32_thing=arg1, i64_thing=arg2) + def testVoid(self): + if options.verbose > 1: + logging.info('testVoid()') + + def testString(self, str): + if options.verbose > 1: + logging.info('testString(%s)' % str) + return str + + def testBool(self, boolean): + if options.verbose > 1: + logging.info('testBool(%s)' % str(boolean).lower()) + return boolean + + def testByte(self, byte): + if options.verbose > 1: + logging.info('testByte(%d)' % byte) + return byte + + def testI16(self, i16): + if options.verbose > 1: + logging.info('testI16(%d)' % i16) + return i16 + + def testI32(self, i32): + if options.verbose > 1: + logging.info('testI32(%d)' % i32) + return i32 + + def testI64(self, i64): + if options.verbose > 1: + logging.info('testI64(%d)' % i64) + return i64 + + def testDouble(self, dub): + if options.verbose > 1: + logging.info('testDouble(%f)' % dub) + return dub + + def testBinary(self, thing): + if options.verbose > 1: + logging.info('testBinary()') # TODO: hex output + return thing + + def testStruct(self, thing): + if options.verbose > 1: + logging.info('testStruct({%s, %s, %s, %s})' % (thing.string_thing, thing.byte_thing, thing.i32_thing, thing.i64_thing)) + return thing + + def testException(self, arg): + # if options.verbose > 1: + logging.info('testException(%s)' % arg) + if arg == 'Xception': + raise Xception(errorCode=1001, message=arg) + elif arg == 'TException': + raise TException(message='This is a TException') + + def testMultiException(self, arg0, arg1): + if options.verbose > 1: + logging.info('testMultiException(%s, %s)' % (arg0, arg1)) + if arg0 == 'Xception': + raise Xception(errorCode=1001, message='This is an Xception') + elif arg0 == 'Xception2': + raise Xception2( + errorCode=2002, + struct_thing=Xtruct(string_thing='This is an Xception2')) + return Xtruct(string_thing=arg1) + + def testOneway(self, seconds): + if options.verbose > 1: + logging.info('testOneway(%d) => sleeping...' % seconds) + time.sleep(seconds / 3) # be quick + if options.verbose > 1: + logging.info('done sleeping') + + def testNest(self, thing): + if options.verbose > 1: + logging.info('testNest(%s)' % thing) + return thing + + def testMap(self, thing): + if options.verbose > 1: + logging.info('testMap(%s)' % thing) + return thing + + def testStringMap(self, thing): + if options.verbose > 1: + logging.info('testStringMap(%s)' % thing) + return thing + + def testSet(self, thing): + if options.verbose > 1: + logging.info('testSet(%s)' % thing) + return thing + + def testList(self, thing): + if options.verbose > 1: + logging.info('testList(%s)' % thing) + return thing + + def testEnum(self, thing): + if options.verbose > 1: + logging.info('testEnum(%s)' % thing) + return thing + + def testTypedef(self, thing): + if options.verbose > 1: + logging.info('testTypedef(%s)' % thing) + return thing + + def testMapMap(self, thing): + if options.verbose > 1: + logging.info('testMapMap(%s)' % thing) + return { + -4: { + -4: -4, + -3: -3, + -2: -2, + -1: -1, + }, + 4: { + 4: 4, + 3: 3, + 2: 2, + 1: 1, + }, + } + + def testInsanity(self, argument): + if options.verbose > 1: + logging.info('testInsanity(%s)' % argument) + return { + 1: { + 2: argument, + 3: argument, + }, + 2: {6: Insanity()}, + } + + def testMulti(self, arg0, arg1, arg2, arg3, arg4, arg5): + if options.verbose > 1: + logging.info('testMulti(%s)' % [arg0, arg1, arg2, arg3, arg4, arg5]) + return Xtruct(string_thing='Hello2', + byte_thing=arg0, i32_thing=arg1, i64_thing=arg2) def main(options): - # set up the protocol factory form the --protocol option - prot_factories = { - 'binary': TBinaryProtocol.TBinaryProtocolFactory, - 'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory, - 'compact': TCompactProtocol.TCompactProtocolFactory, - 'json': TJSONProtocol.TJSONProtocolFactory, - } - pfactory_cls = prot_factories.get(options.proto, None) - if pfactory_cls is None: - raise AssertionError('Unknown --protocol option: %s' % options.proto) - pfactory = pfactory_cls() - try: - pfactory.string_length_limit = options.string_limit - pfactory.container_length_limit = options.container_limit - except: - # Ignore errors for those protocols that does not support length limit - pass - - # get the server type (TSimpleServer, TNonblockingServer, etc...) - if len(args) > 1: - raise AssertionError('Only one server type may be specified, not multiple types.') - server_type = args[0] - - # Set up the handler and processor objects - handler = TestHandler() - processor = ThriftTest.Processor(handler) - - # Handle THttpServer as a special case - if server_type == 'THttpServer': - server = THttpServer.THttpServer(processor, ('', options.port), pfactory) - server.serve() - sys.exit(0) - - # set up server transport and transport factory - - abs_key_path = os.path.join(os.path.dirname(SCRIPT_DIR), 'keys', 'server.pem') - - host = None - if options.ssl: - from thrift.transport import TSSLSocket - transport = TSSLSocket.TSSLServerSocket(host, options.port, certfile=abs_key_path) - else: - transport = TSocket.TServerSocket(host, options.port) - tfactory = TTransport.TBufferedTransportFactory() - if options.trans == 'buffered': - tfactory = TTransport.TBufferedTransportFactory() - elif options.trans == 'framed': - tfactory = TTransport.TFramedTransportFactory() - elif options.trans == '': - raise AssertionError('Unknown --transport option: %s' % options.trans) - else: + # set up the protocol factory form the --protocol option + prot_factories = { + 'binary': TBinaryProtocol.TBinaryProtocolFactory, + 'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory, + 'compact': TCompactProtocol.TCompactProtocolFactory, + 'json': TJSONProtocol.TJSONProtocolFactory, + } + pfactory_cls = prot_factories.get(options.proto, None) + if pfactory_cls is None: + raise AssertionError('Unknown --protocol option: %s' % options.proto) + pfactory = pfactory_cls() + try: + pfactory.string_length_limit = options.string_limit + pfactory.container_length_limit = options.container_limit + except: + # Ignore errors for those protocols that does not support length limit + pass + + # get the server type (TSimpleServer, TNonblockingServer, etc...) + if len(args) > 1: + raise AssertionError('Only one server type may be specified, not multiple types.') + server_type = args[0] + + # Set up the handler and processor objects + handler = TestHandler() + processor = ThriftTest.Processor(handler) + + # Handle THttpServer as a special case + if server_type == 'THttpServer': + server = THttpServer.THttpServer(processor, ('', options.port), pfactory) + server.serve() + sys.exit(0) + + # set up server transport and transport factory + + abs_key_path = os.path.join(os.path.dirname(SCRIPT_DIR), 'keys', 'server.pem') + + host = None + if options.ssl: + from thrift.transport import TSSLSocket + transport = TSSLSocket.TSSLServerSocket(host, options.port, certfile=abs_key_path) + else: + transport = TSocket.TServerSocket(host, options.port) tfactory = TTransport.TBufferedTransportFactory() - # if --zlib, then wrap server transport, and use a different transport factory - if options.zlib: - transport = TZlibTransport.TZlibTransport(transport) # wrap with zlib - tfactory = TZlibTransport.TZlibTransportFactory() - - # do server-specific setup here: - if server_type == "TNonblockingServer": - server = TNonblockingServer.TNonblockingServer(processor, transport, inputProtocolFactory=pfactory) - elif server_type == "TProcessPoolServer": - import signal - from thrift.server import TProcessPoolServer - server = TProcessPoolServer.TProcessPoolServer(processor, transport, tfactory, pfactory) - server.setNumWorkers(5) - - def set_alarm(): - def clean_shutdown(signum, frame): - for worker in server.workers: - if options.verbose > 0: - logging.info('Terminating worker: %s' % worker) - worker.terminate() - if options.verbose > 0: - logging.info('Requesting server to stop()') - try: - server.stop() - except: - pass - signal.signal(signal.SIGALRM, clean_shutdown) - signal.alarm(4) - set_alarm() - else: - # look up server class dynamically to instantiate server - ServerClass = getattr(TServer, server_type) - server = ServerClass(processor, transport, tfactory, pfactory) - # enter server main loop - server.serve() + if options.trans == 'buffered': + tfactory = TTransport.TBufferedTransportFactory() + elif options.trans == 'framed': + tfactory = TTransport.TFramedTransportFactory() + elif options.trans == '': + raise AssertionError('Unknown --transport option: %s' % options.trans) + else: + tfactory = TTransport.TBufferedTransportFactory() + # if --zlib, then wrap server transport, and use a different transport factory + if options.zlib: + transport = TZlibTransport.TZlibTransport(transport) # wrap with zlib + tfactory = TZlibTransport.TZlibTransportFactory() + + # do server-specific setup here: + if server_type == "TNonblockingServer": + server = TNonblockingServer.TNonblockingServer(processor, transport, inputProtocolFactory=pfactory) + elif server_type == "TProcessPoolServer": + import signal + from thrift.server import TProcessPoolServer + server = TProcessPoolServer.TProcessPoolServer(processor, transport, tfactory, pfactory) + server.setNumWorkers(5) + + def set_alarm(): + def clean_shutdown(signum, frame): + for worker in server.workers: + if options.verbose > 0: + logging.info('Terminating worker: %s' % worker) + worker.terminate() + if options.verbose > 0: + logging.info('Requesting server to stop()') + try: + server.stop() + except: + pass + signal.signal(signal.SIGALRM, clean_shutdown) + signal.alarm(4) + set_alarm() + else: + # look up server class dynamically to instantiate server + ServerClass = getattr(TServer, server_type) + server = ServerClass(processor, transport, tfactory, pfactory) + # enter server main loop + server.serve() if __name__ == '__main__': - parser = OptionParser() - parser.add_option('--libpydir', type='string', dest='libpydir', - help='include this directory to sys.path for locating library code') - parser.add_option('--genpydir', type='string', dest='genpydir', - default='gen-py', - help='include this directory to sys.path for locating generated code') - parser.add_option("--port", type="int", dest="port", - help="port number for server to listen on") - parser.add_option("--zlib", action="store_true", dest="zlib", - help="use zlib wrapper for compressed transport") - parser.add_option("--ssl", action="store_true", dest="ssl", - help="use SSL for encrypted transport") - parser.add_option('-v', '--verbose', action="store_const", - dest="verbose", const=2, - help="verbose output") - parser.add_option('-q', '--quiet', action="store_const", - dest="verbose", const=0, - help="minimal output") - parser.add_option('--protocol', dest="proto", type="string", - help="protocol to use, one of: accel, binary, compact, json") - parser.add_option('--transport', dest="trans", type="string", - help="transport to use, one of: buffered, framed") - parser.add_option('--container-limit', dest='container_limit', type='int', default=None) - parser.add_option('--string-limit', dest='string_limit', type='int', default=None) - parser.set_defaults(port=9090, verbose=1, proto='binary') - options, args = parser.parse_args() - - # Print TServer log to stdout so that the test-runner can redirect it to log files - logging.basicConfig(level=options.verbose) - - sys.path.insert(0, os.path.join(SCRIPT_DIR, options.genpydir)) - if options.libpydir: - sys.path.insert(0, glob.glob(options.libpydir)[0]) - else: - sys.path.insert(0, glob.glob(DEFAULT_LIBDIR_GLOB)[0]) - - from ThriftTest import ThriftTest - from ThriftTest.ttypes import Xtruct, Xception, Xception2, Insanity - from thrift.Thrift import TException - from thrift.transport import TTransport - from thrift.transport import TSocket - from thrift.transport import TZlibTransport - from thrift.protocol import TBinaryProtocol - from thrift.protocol import TCompactProtocol - from thrift.protocol import TJSONProtocol - from thrift.server import TServer, TNonblockingServer, THttpServer - - sys.exit(main(options)) + parser = OptionParser() + parser.add_option('--libpydir', type='string', dest='libpydir', + help='include this directory to sys.path for locating library code') + parser.add_option('--genpydir', type='string', dest='genpydir', + default='gen-py', + help='include this directory to sys.path for locating generated code') + parser.add_option("--port", type="int", dest="port", + help="port number for server to listen on") + parser.add_option("--zlib", action="store_true", dest="zlib", + help="use zlib wrapper for compressed transport") + parser.add_option("--ssl", action="store_true", dest="ssl", + help="use SSL for encrypted transport") + parser.add_option('-v', '--verbose', action="store_const", + dest="verbose", const=2, + help="verbose output") + parser.add_option('-q', '--quiet', action="store_const", + dest="verbose", const=0, + help="minimal output") + parser.add_option('--protocol', dest="proto", type="string", + help="protocol to use, one of: accel, binary, compact, json") + parser.add_option('--transport', dest="trans", type="string", + help="transport to use, one of: buffered, framed") + parser.add_option('--container-limit', dest='container_limit', type='int', default=None) + parser.add_option('--string-limit', dest='string_limit', type='int', default=None) + parser.set_defaults(port=9090, verbose=1, proto='binary') + options, args = parser.parse_args() + + # Print TServer log to stdout so that the test-runner can redirect it to log files + logging.basicConfig(level=options.verbose) + + sys.path.insert(0, os.path.join(SCRIPT_DIR, options.genpydir)) + if options.libpydir: + sys.path.insert(0, glob.glob(options.libpydir)[0]) + else: + sys.path.insert(0, glob.glob(DEFAULT_LIBDIR_GLOB)[0]) + + from ThriftTest import ThriftTest + from ThriftTest.ttypes import Xtruct, Xception, Xception2, Insanity + from thrift.Thrift import TException + from thrift.transport import TTransport + from thrift.transport import TSocket + from thrift.transport import TZlibTransport + from thrift.protocol import TBinaryProtocol + from thrift.protocol import TCompactProtocol + from thrift.protocol import TJSONProtocol + from thrift.server import TServer, TNonblockingServer, THttpServer + + sys.exit(main(options)) diff --git a/test/py/TestSocket.py b/test/py/TestSocket.py index a01be85ac..9b578cca4 100755 --- a/test/py/TestSocket.py +++ b/test/py/TestSocket.py @@ -68,10 +68,10 @@ class TimeoutTest(unittest.TestCase): self.assert_(time.time() - starttime < 5.0) if __name__ == '__main__': - suite = unittest.TestSuite() - loader = unittest.TestLoader() + suite = unittest.TestSuite() + loader = unittest.TestLoader() - suite.addTest(loader.loadTestsFromTestCase(TimeoutTest)) + suite.addTest(loader.loadTestsFromTestCase(TimeoutTest)) - testRunner = unittest.TextTestRunner(verbosity=2) - testRunner.run(suite) + testRunner = unittest.TextTestRunner(verbosity=2) + testRunner.run(suite) diff --git a/test/test.py b/test/test.py index a5bcd9bb2..42babebb3 100755 --- a/test/test.py +++ b/test/test.py @@ -46,124 +46,124 @@ CONFIG_FILE = 'tests.json' def run_cross_tests(server_match, client_match, jobs, skip_known_failures, retry_count): - logger = multiprocessing.get_logger() - logger.debug('Collecting tests') - with open(path_join(TEST_DIR, CONFIG_FILE), 'r') as fp: - j = json.load(fp) - tests = crossrunner.collect_cross_tests(j, server_match, client_match) - if not tests: - print('No test found that matches the criteria', file=sys.stderr) - print(' servers: %s' % server_match, file=sys.stderr) - print(' clients: %s' % client_match, file=sys.stderr) - return False - if skip_known_failures: - logger.debug('Skipping known failures') - known = crossrunner.load_known_failures(TEST_DIR) - tests = list(filter(lambda t: crossrunner.test_name(**t) not in known, tests)) - - dispatcher = crossrunner.TestDispatcher(TEST_DIR, ROOT_DIR, TEST_DIR_RELATIVE, jobs) - logger.debug('Executing %d tests' % len(tests)) - try: - for r in [dispatcher.dispatch(test, retry_count) for test in tests]: - r.wait() - logger.debug('Waiting for completion') - return dispatcher.wait() - except (KeyboardInterrupt, SystemExit): - logger.debug('Interrupted, shutting down') - dispatcher.terminate() - return False + logger = multiprocessing.get_logger() + logger.debug('Collecting tests') + with open(path_join(TEST_DIR, CONFIG_FILE), 'r') as fp: + j = json.load(fp) + tests = crossrunner.collect_cross_tests(j, server_match, client_match) + if not tests: + print('No test found that matches the criteria', file=sys.stderr) + print(' servers: %s' % server_match, file=sys.stderr) + print(' clients: %s' % client_match, file=sys.stderr) + return False + if skip_known_failures: + logger.debug('Skipping known failures') + known = crossrunner.load_known_failures(TEST_DIR) + tests = list(filter(lambda t: crossrunner.test_name(**t) not in known, tests)) + + dispatcher = crossrunner.TestDispatcher(TEST_DIR, ROOT_DIR, TEST_DIR_RELATIVE, jobs) + logger.debug('Executing %d tests' % len(tests)) + try: + for r in [dispatcher.dispatch(test, retry_count) for test in tests]: + r.wait() + logger.debug('Waiting for completion') + return dispatcher.wait() + except (KeyboardInterrupt, SystemExit): + logger.debug('Interrupted, shutting down') + dispatcher.terminate() + return False def run_feature_tests(server_match, feature_match, jobs, skip_known_failures, retry_count): - basedir = path_join(ROOT_DIR, FEATURE_DIR_RELATIVE) - logger = multiprocessing.get_logger() - logger.debug('Collecting tests') - with open(path_join(TEST_DIR, CONFIG_FILE), 'r') as fp: - j = json.load(fp) - with open(path_join(basedir, CONFIG_FILE), 'r') as fp: - j2 = json.load(fp) - tests = crossrunner.collect_feature_tests(j, j2, server_match, feature_match) - if not tests: - print('No test found that matches the criteria', file=sys.stderr) - print(' servers: %s' % server_match, file=sys.stderr) - print(' features: %s' % feature_match, file=sys.stderr) - return False - if skip_known_failures: - logger.debug('Skipping known failures') - known = crossrunner.load_known_failures(basedir) - tests = list(filter(lambda t: crossrunner.test_name(**t) not in known, tests)) - - dispatcher = crossrunner.TestDispatcher(TEST_DIR, ROOT_DIR, FEATURE_DIR_RELATIVE, jobs) - logger.debug('Executing %d tests' % len(tests)) - try: - for r in [dispatcher.dispatch(test, retry_count) for test in tests]: - r.wait() - logger.debug('Waiting for completion') - return dispatcher.wait() - except (KeyboardInterrupt, SystemExit): - logger.debug('Interrupted, shutting down') - dispatcher.terminate() - return False + basedir = path_join(ROOT_DIR, FEATURE_DIR_RELATIVE) + logger = multiprocessing.get_logger() + logger.debug('Collecting tests') + with open(path_join(TEST_DIR, CONFIG_FILE), 'r') as fp: + j = json.load(fp) + with open(path_join(basedir, CONFIG_FILE), 'r') as fp: + j2 = json.load(fp) + tests = crossrunner.collect_feature_tests(j, j2, server_match, feature_match) + if not tests: + print('No test found that matches the criteria', file=sys.stderr) + print(' servers: %s' % server_match, file=sys.stderr) + print(' features: %s' % feature_match, file=sys.stderr) + return False + if skip_known_failures: + logger.debug('Skipping known failures') + known = crossrunner.load_known_failures(basedir) + tests = list(filter(lambda t: crossrunner.test_name(**t) not in known, tests)) + + dispatcher = crossrunner.TestDispatcher(TEST_DIR, ROOT_DIR, FEATURE_DIR_RELATIVE, jobs) + logger.debug('Executing %d tests' % len(tests)) + try: + for r in [dispatcher.dispatch(test, retry_count) for test in tests]: + r.wait() + logger.debug('Waiting for completion') + return dispatcher.wait() + except (KeyboardInterrupt, SystemExit): + logger.debug('Interrupted, shutting down') + dispatcher.terminate() + return False def default_concurrenty(): - try: - return int(os.environ.get('THRIFT_CROSSTEST_CONCURRENCY')) - except (TypeError, ValueError): - # Since much time is spent sleeping, use many threads - return int(multiprocessing.cpu_count() * 1.25) + 1 + try: + return int(os.environ.get('THRIFT_CROSSTEST_CONCURRENCY')) + except (TypeError, ValueError): + # Since much time is spent sleeping, use many threads + return int(multiprocessing.cpu_count() * 1.25) + 1 def main(argv): - parser = argparse.ArgumentParser() - parser.add_argument('--server', default='', nargs='*', - help='list of servers to test') - parser.add_argument('--client', default='', nargs='*', - help='list of clients to test') - parser.add_argument('-F', '--features', nargs='*', default=None, - help='run server feature tests instead of cross language tests') - parser.add_argument('-s', '--skip-known-failures', action='store_true', dest='skip_known_failures', - help='do not execute tests that are known to fail') - parser.add_argument('-r', '--retry-count', type=int, - default=0, help='maximum retry on failure') - parser.add_argument('-j', '--jobs', type=int, - default=default_concurrenty(), - help='number of concurrent test executions') - - g = parser.add_argument_group(title='Advanced') - g.add_argument('-v', '--verbose', action='store_const', - dest='log_level', const=logging.DEBUG, default=logging.WARNING, - help='show debug output for test runner') - g.add_argument('-P', '--print-expected-failures', choices=['merge', 'overwrite'], - dest='print_failures', - help="generate expected failures based on last result and print to stdout") - g.add_argument('-U', '--update-expected-failures', choices=['merge', 'overwrite'], - dest='update_failures', - help="generate expected failures based on last result and save to default file location") - options = parser.parse_args(argv) - - logger = multiprocessing.log_to_stderr() - logger.setLevel(options.log_level) - - if options.features is not None and options.client: - print('Cannot specify both --features and --client ', file=sys.stderr) - return 1 - - # Allow multiple args separated with ',' for backward compatibility - server_match = list(chain(*[x.split(',') for x in options.server])) - client_match = list(chain(*[x.split(',') for x in options.client])) - - if options.update_failures or options.print_failures: - dire = path_join(ROOT_DIR, FEATURE_DIR_RELATIVE) if options.features is not None else TEST_DIR - res = crossrunner.generate_known_failures( - dire, options.update_failures == 'overwrite', - options.update_failures, options.print_failures) - elif options.features is not None: - features = options.features or ['.*'] - res = run_feature_tests(server_match, features, options.jobs, options.skip_known_failures, options.retry_count) - else: - res = run_cross_tests(server_match, client_match, options.jobs, options.skip_known_failures, options.retry_count) - return 0 if res else 1 + parser = argparse.ArgumentParser() + parser.add_argument('--server', default='', nargs='*', + help='list of servers to test') + parser.add_argument('--client', default='', nargs='*', + help='list of clients to test') + parser.add_argument('-F', '--features', nargs='*', default=None, + help='run server feature tests instead of cross language tests') + parser.add_argument('-s', '--skip-known-failures', action='store_true', dest='skip_known_failures', + help='do not execute tests that are known to fail') + parser.add_argument('-r', '--retry-count', type=int, + default=0, help='maximum retry on failure') + parser.add_argument('-j', '--jobs', type=int, + default=default_concurrenty(), + help='number of concurrent test executions') + + g = parser.add_argument_group(title='Advanced') + g.add_argument('-v', '--verbose', action='store_const', + dest='log_level', const=logging.DEBUG, default=logging.WARNING, + help='show debug output for test runner') + g.add_argument('-P', '--print-expected-failures', choices=['merge', 'overwrite'], + dest='print_failures', + help="generate expected failures based on last result and print to stdout") + g.add_argument('-U', '--update-expected-failures', choices=['merge', 'overwrite'], + dest='update_failures', + help="generate expected failures based on last result and save to default file location") + options = parser.parse_args(argv) + + logger = multiprocessing.log_to_stderr() + logger.setLevel(options.log_level) + + if options.features is not None and options.client: + print('Cannot specify both --features and --client ', file=sys.stderr) + return 1 + + # Allow multiple args separated with ',' for backward compatibility + server_match = list(chain(*[x.split(',') for x in options.server])) + client_match = list(chain(*[x.split(',') for x in options.client])) + + if options.update_failures or options.print_failures: + dire = path_join(ROOT_DIR, FEATURE_DIR_RELATIVE) if options.features is not None else TEST_DIR + res = crossrunner.generate_known_failures( + dire, options.update_failures == 'overwrite', + options.update_failures, options.print_failures) + elif options.features is not None: + features = options.features or ['.*'] + res = run_feature_tests(server_match, features, options.jobs, options.skip_known_failures, options.retry_count) + else: + res = run_cross_tests(server_match, client_match, options.jobs, options.skip_known_failures, options.retry_count) + return 0 if res else 1 if __name__ == '__main__': - sys.exit(main(sys.argv[1:])) + sys.exit(main(sys.argv[1:])) diff --git a/tutorial/php/runserver.py b/tutorial/php/runserver.py index ae29fed9c..077daa102 100755 --- a/tutorial/php/runserver.py +++ b/tutorial/php/runserver.py @@ -26,7 +26,8 @@ import CGIHTTPServer # chdir(2) into the tutorial directory. os.chdir(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) + class Handler(CGIHTTPServer.CGIHTTPRequestHandler): - cgi_directories = ['/php'] + cgi_directories = ['/php'] BaseHTTPServer.HTTPServer(('', 8080), Handler).serve_forever() |