diff options
Diffstat (limited to 'lib/cpp/src')
78 files changed, 18900 insertions, 0 deletions
// diff --git a/lib/cpp/src/TLogging.h b/lib/cpp/src/TLogging.h new file mode 100644 index 000000000..2df82dd7e --- /dev/null +++ b/lib/cpp/src/TLogging.h @@ -0,0 +1,163 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#ifndef _THRIFT_TLOGGING_H_
#define _THRIFT_TLOGGING_H_ 1

#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

/**
 * Contains utility macros for debugging and logging.
 *
 * Every printing macro expands to a call to fprintf(); the timestamped ones
 * also call time() and ctime_r(), and T_ERROR_ABORT calls exit(). The headers
 * that declare those functions are included unconditionally so that this
 * header can be included on its own.
 *
 * Usage notes that apply to all macros below:
 *  - format_string is stringized (#format_string) before being pasted into
 *    the format, so a quoted literal is printed together with its quotes.
 *  - T_DEBUG and T_DEBUG_L expand to a bare `if` statement and the other
 *    printing macros expand to a brace block; put them inside braces when
 *    they form the body of an if/else.
 */

#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#ifdef HAVE_CLOCK_GETTIME
#include <sys/time.h>
#endif

#ifdef HAVE_STDINT_H
#include <stdint.h>
#endif

/**
 * T_GLOBAL_DEBUGGING_LEVEL = 0: all debugging turned off, debug macros expand to nothing
 * T_GLOBAL_DEBUGGING_LEVEL = 1: all debugging turned on
 */
#define T_GLOBAL_DEBUGGING_LEVEL 0


/**
 * T_GLOBAL_LOGGING_LEVEL = 0: all logging turned off, logging macros expand to nothing
 * T_GLOBAL_LOGGING_LEVEL = 1: all logging turned on
 */
#define T_GLOBAL_LOGGING_LEVEL 1


/**
 * Standard wrapper around fprintf that prefixes the file name and line
 * number to the line. Uses T_GLOBAL_DEBUGGING_LEVEL to control whether it is
 * turned on or off.
 *
 * @param format_string printf style format, followed by its arguments
 */
#if T_GLOBAL_DEBUGGING_LEVEL > 0
  #define T_DEBUG(format_string,...)                                        \
    if (T_GLOBAL_DEBUGGING_LEVEL > 0) {                                     \
      fprintf(stderr,"[%s,%d] " #format_string " \n", __FILE__, __LINE__,##__VA_ARGS__); \
  }
#else
  #define T_DEBUG(format_string,...)
#endif


/**
 * Analogous to T_DEBUG but also prints the time.
 *
 * @param string  format_string input: printf style format string
 */
#if T_GLOBAL_DEBUGGING_LEVEL > 0
  #define T_DEBUG_T(format_string,...)                                    \
    {                                                                     \
      if (T_GLOBAL_DEBUGGING_LEVEL > 0) {                                 \
        time_t now;                                                       \
        char dbgtime[26] ;                                                \
        time(&now);                                                       \
        ctime_r(&now, dbgtime);                                           \
        dbgtime[24] = '\0';                                               \
        fprintf(stderr,"[%s,%d] [%s] " #format_string " \n", __FILE__, __LINE__,dbgtime,##__VA_ARGS__); \
      }                                                                   \
    }
#else
  #define T_DEBUG_T(format_string,...)
#endif


/**
 * Analogous to T_DEBUG but uses the input level to determine whether or not
 * the string should be logged. This macro is always defined, independent of
 * T_GLOBAL_DEBUGGING_LEVEL.
 *
 * @param int     level: specified debug level; the line is printed when > 0
 * @param string  format_string input: format string
 */
#define T_DEBUG_L(level, format_string,...)                               \
  if ((level) > 0) {                                                      \
    fprintf(stderr,"[%s,%d] " #format_string " \n", __FILE__, __LINE__,##__VA_ARGS__); \
  }


/**
 * Explicit error logging. Prints time, file name and line number.
 *
 * @param string  format_string input: printf style format string
 */
#define T_ERROR(format_string,...)                                      \
  {                                                                     \
    time_t now;                                                         \
    char dbgtime[26] ;                                                  \
    time(&now);                                                         \
    ctime_r(&now, dbgtime);                                             \
    dbgtime[24] = '\0';                                                 \
    fprintf(stderr,"[%s,%d] [%s] ERROR: " #format_string " \n", __FILE__, __LINE__,dbgtime,##__VA_ARGS__); \
  }


/**
 * Analogous to T_ERROR, additionally terminating the process.
 * WARNING: despite its name, this macro ends program execution by calling
 * exit(1), not abort().
 *
 * @param string  format_string input: printf style format string
 */
#define T_ERROR_ABORT(format_string,...)                                \
  {                                                                     \
    time_t now;                                                         \
    char dbgtime[26] ;                                                  \
    time(&now);                                                         \
    ctime_r(&now, dbgtime);                                             \
    dbgtime[24] = '\0';                                                 \
    fprintf(stderr,"[%s,%d] [%s] ERROR: Going to abort " #format_string " \n", __FILE__, __LINE__,dbgtime,##__VA_ARGS__); \
    exit(1);                                                            \
  }


/**
 * Log input message, prefixed with the time only (no file name or line).
 *
 * @param string  format_string input: printf style format string
 */
#if T_GLOBAL_LOGGING_LEVEL > 0
  #define T_LOG_OPER(format_string,...)                                       \
    {                                                                         \
      if (T_GLOBAL_LOGGING_LEVEL > 0) {                                       \
        time_t now;                                                           \
        char dbgtime[26] ;                                                    \
        time(&now);                                                           \
        ctime_r(&now, dbgtime);                                               \
        dbgtime[24] = '\0';                                                   \
        fprintf(stderr,"[%s] " #format_string " \n", dbgtime,##__VA_ARGS__);  \
      }                                                                       \
    }
#else
  #define T_LOG_OPER(format_string,...)
#endif

#endif // #ifndef _THRIFT_TLOGGING_H_

// (verbatim from the dump) header of the next file; its license comment continues on the following line:
// diff --git a/lib/cpp/src/TProcessor.h b/lib/cpp/src/TProcessor.h new file mode 100644 index 000000000..f2d5279a2 --- /dev/null +++ b/lib/cpp/src/TProcessor.h @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +#ifndef _THRIFT_TPROCESSOR_H_ +#define _THRIFT_TPROCESSOR_H_ 1 + +#include <string> +#include <protocol/TProtocol.h> +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { + +/** + * A processor is a generic object that acts upon two streams of data, one + * an input and the other an output. The definition of this object is loose, + * though the typical case is for some sort of server that either generates + * responses to an input stream or forwards data from one pipe onto another. + * + */ +class TProcessor { + public: + virtual ~TProcessor() {} + + virtual bool process(boost::shared_ptr<protocol::TProtocol> in, + boost::shared_ptr<protocol::TProtocol> out) = 0; + + bool process(boost::shared_ptr<apache::thrift::protocol::TProtocol> io) { + return process(io, io); + } + + protected: + TProcessor() {} +}; + +}} // apache::thrift + +#endif // #ifndef _THRIFT_PROCESSOR_H_ diff --git a/lib/cpp/src/TReflectionLocal.h b/lib/cpp/src/TReflectionLocal.h new file mode 100644 index 000000000..e83e47530 --- /dev/null +++ b/lib/cpp/src/TReflectionLocal.h @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef _THRIFT_TREFLECTIONLOCAL_H_ +#define _THRIFT_TREFLECTIONLOCAL_H_ 1 + +#include <stdint.h> +#include <cstring> +#include <protocol/TProtocol.h> + +/** + * Local Reflection is a blanket term referring to the the structure + * and generation of this particular representation of Thrift types. + * (It is called local because it cannot be serialized by Thrift). + * + */ + +namespace apache { namespace thrift { namespace reflection { namespace local { + +using apache::thrift::protocol::TType; + +// We include this many bytes of the structure's fingerprint when serializing +// a top-level structure. Long enough to make collisions unlikely, short +// enough to not significantly affect the amount of memory used. +const int FP_PREFIX_LEN = 4; + +struct FieldMeta { + int16_t tag; + bool is_optional; +}; + +struct TypeSpec { + TType ttype; + uint8_t fp_prefix[FP_PREFIX_LEN]; + + // Use an anonymous union here so we can fit two TypeSpecs in one cache line. + union { + struct { + // Use parallel arrays here for denser packing (of the arrays). + FieldMeta* metas; + TypeSpec** specs; + } tstruct; + struct { + TypeSpec *subtype1; + TypeSpec *subtype2; + } tcontainer; + }; + + // Static initialization of unions isn't really possible, + // so take the plunge and use constructors. + // Hopefully they'll be evaluated at compile time. 
+ + TypeSpec(TType ttype) : ttype(ttype) { + std::memset(fp_prefix, 0, FP_PREFIX_LEN); + } + + TypeSpec(TType ttype, + const uint8_t* fingerprint, + FieldMeta* metas, + TypeSpec** specs) : + ttype(ttype) + { + std::memcpy(fp_prefix, fingerprint, FP_PREFIX_LEN); + tstruct.metas = metas; + tstruct.specs = specs; + } + + TypeSpec(TType ttype, TypeSpec* subtype1, TypeSpec* subtype2) : + ttype(ttype) + { + std::memset(fp_prefix, 0, FP_PREFIX_LEN); + tcontainer.subtype1 = subtype1; + tcontainer.subtype2 = subtype2; + } + +}; + +}}}} // apache::thrift::reflection::local + +#endif // #ifndef _THRIFT_TREFLECTIONLOCAL_H_ diff --git a/lib/cpp/src/Thrift.cpp b/lib/cpp/src/Thrift.cpp new file mode 100644 index 000000000..ed99205b8 --- /dev/null +++ b/lib/cpp/src/Thrift.cpp @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <Thrift.h> +#include <cstring> +#include <boost/lexical_cast.hpp> +#include <protocol/TProtocol.h> +#include <stdarg.h> +#include <stdio.h> + +namespace apache { namespace thrift { + +TOutput GlobalOutput; + +void TOutput::printf(const char *message, ...) { + // Try to reduce heap usage, even if printf is called rarely. 
+ static const int STACK_BUF_SIZE = 256; + char stack_buf[STACK_BUF_SIZE]; + va_list ap; + + va_start(ap, message); + int need = vsnprintf(stack_buf, STACK_BUF_SIZE, message, ap); + va_end(ap); + + if (need < STACK_BUF_SIZE) { + f_(stack_buf); + return; + } + + char *heap_buf = (char*)malloc((need+1) * sizeof(char)); + if (heap_buf == NULL) { + // Malloc failed. We might as well print the stack buffer. + f_(stack_buf); + return; + } + + va_start(ap, message); + int rval = vsnprintf(heap_buf, need+1, message, ap); + va_end(ap); + // TODO(shigin): inform user + if (rval != -1) { + f_(heap_buf); + } + free(heap_buf); +} + +void TOutput::perror(const char *message, int errno_copy) { + std::string out = message + strerror_s(errno_copy); + f_(out.c_str()); +} + +std::string TOutput::strerror_s(int errno_copy) { +#ifndef HAVE_STRERROR_R + return "errno = " + boost::lexical_cast<std::string>(errno_copy); +#else // HAVE_STRERROR_R + + char b_errbuf[1024] = { '\0' }; +#ifdef STRERROR_R_CHAR_P + char *b_error = strerror_r(errno_copy, b_errbuf, sizeof(b_errbuf)); +#else + char *b_error = b_errbuf; + int rv = strerror_r(errno_copy, b_errbuf, sizeof(b_errbuf)); + if (rv == -1) { + // strerror_r failed. omgwtfbbq. + return "XSI-compliant strerror_r() failed with errno = " + + boost::lexical_cast<std::string>(errno_copy); + } +#endif + // Can anyone prove that explicit cast is probably not necessary + // to ensure that the string object is constructed before + // b_error becomes invalid? 
+ return std::string(b_error); + +#endif // HAVE_STRERROR_R +} + +uint32_t TApplicationException::read(apache::thrift::protocol::TProtocol* iprot) { + uint32_t xfer = 0; + std::string fname; + apache::thrift::protocol::TType ftype; + int16_t fid; + + xfer += iprot->readStructBegin(fname); + + while (true) { + xfer += iprot->readFieldBegin(fname, ftype, fid); + if (ftype == apache::thrift::protocol::T_STOP) { + break; + } + switch (fid) { + case 1: + if (ftype == apache::thrift::protocol::T_STRING) { + xfer += iprot->readString(message_); + } else { + xfer += iprot->skip(ftype); + } + break; + case 2: + if (ftype == apache::thrift::protocol::T_I32) { + int32_t type; + xfer += iprot->readI32(type); + type_ = (TApplicationExceptionType)type; + } else { + xfer += iprot->skip(ftype); + } + break; + default: + xfer += iprot->skip(ftype); + break; + } + xfer += iprot->readFieldEnd(); + } + + xfer += iprot->readStructEnd(); + return xfer; +} + +uint32_t TApplicationException::write(apache::thrift::protocol::TProtocol* oprot) const { + uint32_t xfer = 0; + xfer += oprot->writeStructBegin("TApplicationException"); + xfer += oprot->writeFieldBegin("message", apache::thrift::protocol::T_STRING, 1); + xfer += oprot->writeString(message_); + xfer += oprot->writeFieldEnd(); + xfer += oprot->writeFieldBegin("type", apache::thrift::protocol::T_I32, 2); + xfer += oprot->writeI32(type_); + xfer += oprot->writeFieldEnd(); + xfer += oprot->writeFieldStop(); + xfer += oprot->writeStructEnd(); + return xfer; +} + +}} // apache::thrift diff --git a/lib/cpp/src/Thrift.h b/lib/cpp/src/Thrift.h new file mode 100644 index 000000000..26d2b0fcd --- /dev/null +++ b/lib/cpp/src/Thrift.h @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_THRIFT_H_ +#define _THRIFT_THRIFT_H_ 1 + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif +#include <stdio.h> + +#include <netinet/in.h> +#ifdef HAVE_INTTYPES_H +#include <inttypes.h> +#endif +#include <string> +#include <map> +#include <list> +#include <set> +#include <vector> +#include <exception> + +#include "TLogging.h" + +namespace apache { namespace thrift { + +class TOutput { + public: + TOutput() : f_(&errorTimeWrapper) {} + + inline void setOutputFunction(void (*function)(const char *)){ + f_ = function; + } + + inline void operator()(const char *message){ + f_(message); + } + + // It is important to have a const char* overload here instead of + // just the string version, otherwise errno could be corrupted + // if there is some problem allocating memory when constructing + // the string. + void perror(const char *message, int errno_copy); + inline void perror(const std::string &message, int errno_copy) { + perror(message.c_str(), errno_copy); + } + + void printf(const char *message, ...); + + inline static void errorTimeWrapper(const char* msg) { + time_t now; + char dbgtime[25]; + time(&now); + ctime_r(&now, dbgtime); + dbgtime[24] = 0; + fprintf(stderr, "Thrift: %s %s\n", dbgtime, msg); + } + + /** Just like strerror_r but returns a C++ string object. 
*/ + static std::string strerror_s(int errno_copy); + + private: + void (*f_)(const char *); +}; + +extern TOutput GlobalOutput; + +namespace protocol { + class TProtocol; +} + +class TException : public std::exception { + public: + TException() {} + + TException(const std::string& message) : + message_(message) {} + + virtual ~TException() throw() {} + + virtual const char* what() const throw() { + if (message_.empty()) { + return "Default TException."; + } else { + return message_.c_str(); + } + } + + protected: + std::string message_; + +}; + +class TApplicationException : public TException { + public: + + /** + * Error codes for the various types of exceptions. + */ + enum TApplicationExceptionType + { UNKNOWN = 0 + , UNKNOWN_METHOD = 1 + , INVALID_MESSAGE_TYPE = 2 + , WRONG_METHOD_NAME = 3 + , BAD_SEQUENCE_ID = 4 + , MISSING_RESULT = 5 + }; + + TApplicationException() : + TException(), + type_(UNKNOWN) {} + + TApplicationException(TApplicationExceptionType type) : + TException(), + type_(type) {} + + TApplicationException(const std::string& message) : + TException(message), + type_(UNKNOWN) {} + + TApplicationException(TApplicationExceptionType type, + const std::string& message) : + TException(message), + type_(type) {} + + virtual ~TApplicationException() throw() {} + + /** + * Returns an error code that provides information about the type of error + * that has occurred. 
+ * + * @return Error code + */ + TApplicationExceptionType getType() { + return type_; + } + + virtual const char* what() const throw() { + if (message_.empty()) { + switch (type_) { + case UNKNOWN : return "TApplicationException: Unknown application exception"; + case UNKNOWN_METHOD : return "TApplicationException: Unknown method"; + case INVALID_MESSAGE_TYPE : return "TApplicationException: Invalid message type"; + case WRONG_METHOD_NAME : return "TApplicationException: Wrong method name"; + case BAD_SEQUENCE_ID : return "TApplicationException: Bad sequence identifier"; + case MISSING_RESULT : return "TApplicationException: Missing result"; + default : return "TApplicationException: (Invalid exception type)"; + }; + } else { + return message_.c_str(); + } + } + + uint32_t read(protocol::TProtocol* iprot); + uint32_t write(protocol::TProtocol* oprot) const; + + protected: + /** + * Error code + */ + TApplicationExceptionType type_; + +}; + + +// Forward declare this structure used by TDenseProtocol +namespace reflection { namespace local { +struct TypeSpec; +}} + + +}} // apache::thrift + +#endif // #ifndef _THRIFT_THRIFT_H_ diff --git a/lib/cpp/src/concurrency/Exception.h b/lib/cpp/src/concurrency/Exception.h new file mode 100644 index 000000000..ec4662976 --- /dev/null +++ b/lib/cpp/src/concurrency/Exception.h @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_EXCEPTION_H_ +#define _THRIFT_CONCURRENCY_EXCEPTION_H_ 1 + +#include <exception> +#include <Thrift.h> + +namespace apache { namespace thrift { namespace concurrency { + +class NoSuchTaskException : public apache::thrift::TException {}; + +class UncancellableTaskException : public apache::thrift::TException {}; + +class InvalidArgumentException : public apache::thrift::TException {}; + +class IllegalStateException : public apache::thrift::TException {}; + +class TimedOutException : public apache::thrift::TException { +public: + TimedOutException():TException("TimedOutException"){}; + TimedOutException(const std::string& message ) : + TException(message) {} +}; + +class TooManyPendingTasksException : public apache::thrift::TException { +public: + TooManyPendingTasksException():TException("TooManyPendingTasksException"){}; + TooManyPendingTasksException(const std::string& message ) : + TException(message) {} +}; + +class SystemResourceException : public apache::thrift::TException { +public: + SystemResourceException() {} + + SystemResourceException(const std::string& message) : + TException(message) {} +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_EXCEPTION_H_ diff --git a/lib/cpp/src/concurrency/FunctionRunner.h b/lib/cpp/src/concurrency/FunctionRunner.h new file mode 100644 index 000000000..221692767 --- /dev/null +++ b/lib/cpp/src/concurrency/FunctionRunner.h @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_FUNCTION_RUNNER_H +#define _THRIFT_CONCURRENCY_FUNCTION_RUNNER_H 1 + +#include <tr1/functional> +#include "thrift/lib/cpp/concurrency/Thread.h" + +namespace apache { namespace thrift { namespace concurrency { + +/** + * Convenient implementation of Runnable that will execute arbitrary callbacks. + * Interfaces are provided to accept both a generic 'void(void)' callback, and + * a 'void* (void*)' pthread_create-style callback. + * + * Example use: + * void* my_thread_main(void* arg); + * shared_ptr<ThreadFactory> factory = ...; + * shared_ptr<Thread> thread = + * factory->newThread(shared_ptr<FunctionRunner>( + * new FunctionRunner(my_thread_main, some_argument))); + * thread->start(); + * + * + */ + +class FunctionRunner : public Runnable { + public: + // This is the type of callback 'pthread_create()' expects. + typedef void* (*PthreadFuncPtr)(void *arg); + // This a fully-generic void(void) callback for custom bindings. + typedef std::tr1::function<void()> VoidFunc; + + /** + * Given a 'pthread_create' style callback, this FunctionRunner will + * execute the given callback. Note that the 'void*' return value is ignored. 
+ */ + FunctionRunner(PthreadFuncPtr func, void* arg) + : func_(std::tr1::bind(func, arg)) + { } + + /** + * Given a generic callback, this FunctionRunner will execute it. + */ + FunctionRunner(const VoidFunc& cob) + : func_(cob) + { } + + + void run() { + func_(); + } + + private: + VoidFunc func_; +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_FUNCTION_RUNNER_H diff --git a/lib/cpp/src/concurrency/Monitor.cpp b/lib/cpp/src/concurrency/Monitor.cpp new file mode 100644 index 000000000..2055caa95 --- /dev/null +++ b/lib/cpp/src/concurrency/Monitor.cpp @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "Monitor.h" +#include "Exception.h" +#include "Util.h" + +#include <assert.h> +#include <errno.h> + +#include <iostream> + +#include <pthread.h> + +namespace apache { namespace thrift { namespace concurrency { + +/** + * Monitor implementation using the POSIX pthread library + * + * @version $Id:$ + */ +class Monitor::Impl { + + public: + + Impl() : + mutexInitialized_(false), + condInitialized_(false) { + + if (pthread_mutex_init(&pthread_mutex_, NULL) == 0) { + mutexInitialized_ = true; + + if (pthread_cond_init(&pthread_cond_, NULL) == 0) { + condInitialized_ = true; + } + } + + if (!mutexInitialized_ || !condInitialized_) { + cleanup(); + throw SystemResourceException(); + } + } + + ~Impl() { cleanup(); } + + void lock() const { pthread_mutex_lock(&pthread_mutex_); } + + void unlock() const { pthread_mutex_unlock(&pthread_mutex_); } + + void wait(int64_t timeout) const { + + // XXX Need to assert that caller owns mutex + assert(timeout >= 0LL); + if (timeout == 0LL) { + int iret = pthread_cond_wait(&pthread_cond_, &pthread_mutex_); + assert(iret == 0); + } else { + struct timespec abstime; + int64_t now = Util::currentTime(); + Util::toTimespec(abstime, now + timeout); + int result = pthread_cond_timedwait(&pthread_cond_, + &pthread_mutex_, + &abstime); + if (result == ETIMEDOUT) { + // pthread_cond_timedwait has been observed to return early on + // various platforms, so comment out this assert. 
+ //assert(Util::currentTime() >= (now + timeout)); + throw TimedOutException(); + } + } + } + + void notify() { + // XXX Need to assert that caller owns mutex + int iret = pthread_cond_signal(&pthread_cond_); + assert(iret == 0); + } + + void notifyAll() { + // XXX Need to assert that caller owns mutex + int iret = pthread_cond_broadcast(&pthread_cond_); + assert(iret == 0); + } + + private: + + void cleanup() { + if (mutexInitialized_) { + mutexInitialized_ = false; + int iret = pthread_mutex_destroy(&pthread_mutex_); + assert(iret == 0); + } + + if (condInitialized_) { + condInitialized_ = false; + int iret = pthread_cond_destroy(&pthread_cond_); + assert(iret == 0); + } + } + + mutable pthread_mutex_t pthread_mutex_; + mutable bool mutexInitialized_; + mutable pthread_cond_t pthread_cond_; + mutable bool condInitialized_; +}; + +Monitor::Monitor() : impl_(new Monitor::Impl()) {} + +Monitor::~Monitor() { delete impl_; } + +void Monitor::lock() const { impl_->lock(); } + +void Monitor::unlock() const { impl_->unlock(); } + +void Monitor::wait(int64_t timeout) const { impl_->wait(timeout); } + +void Monitor::notify() const { impl_->notify(); } + +void Monitor::notifyAll() const { impl_->notifyAll(); } + +}}} // apache::thrift::concurrency diff --git a/lib/cpp/src/concurrency/Monitor.h b/lib/cpp/src/concurrency/Monitor.h new file mode 100644 index 000000000..234bf3269 --- /dev/null +++ b/lib/cpp/src/concurrency/Monitor.h @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_MONITOR_H_ +#define _THRIFT_CONCURRENCY_MONITOR_H_ 1 + +#include "Exception.h" + +namespace apache { namespace thrift { namespace concurrency { + +/** + * A monitor is a combination mutex and condition-event. Waiting and + * notifying condition events requires that the caller own the mutex. Mutex + * lock and unlock operations can be performed independently of condition + * events. This is more or less analogous to java.lang.Object multi-thread + * operations + * + * Note that all methods are const. Monitors implement logical constness, not + * bit constness. This allows const methods to call monitor methods without + * needing to cast away constness or change to non-const signatures. 
+ * + * @version $Id:$ + */ +class Monitor { + + public: + + Monitor(); + + virtual ~Monitor(); + + virtual void lock() const; + + virtual void unlock() const; + + virtual void wait(int64_t timeout=0LL) const; + + virtual void notify() const; + + virtual void notifyAll() const; + + private: + + class Impl; + + Impl* impl_; +}; + +class Synchronized { + public: + + Synchronized(const Monitor& value) : + monitor_(value) { + monitor_.lock(); + } + + ~Synchronized() { + monitor_.unlock(); + } + + private: + const Monitor& monitor_; +}; + + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_MONITOR_H_ diff --git a/lib/cpp/src/concurrency/Mutex.cpp b/lib/cpp/src/concurrency/Mutex.cpp new file mode 100644 index 000000000..045dbdfe2 --- /dev/null +++ b/lib/cpp/src/concurrency/Mutex.cpp @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "Mutex.h" + +#include <assert.h> +#include <pthread.h> + +using boost::shared_ptr; + +namespace apache { namespace thrift { namespace concurrency { + +/** + * Implementation of Mutex class using POSIX mutex + * + * @version $Id:$ + */ +class Mutex::impl { + public: + impl(Initializer init) : initialized_(false) { + init(&pthread_mutex_); + initialized_ = true; + } + + ~impl() { + if (initialized_) { + initialized_ = false; + int ret = pthread_mutex_destroy(&pthread_mutex_); + assert(ret == 0); + } + } + + void lock() const { pthread_mutex_lock(&pthread_mutex_); } + + bool trylock() const { return (0 == pthread_mutex_trylock(&pthread_mutex_)); } + + void unlock() const { pthread_mutex_unlock(&pthread_mutex_); } + + private: + mutable pthread_mutex_t pthread_mutex_; + mutable bool initialized_; +}; + +Mutex::Mutex(Initializer init) : impl_(new Mutex::impl(init)) {} + +void Mutex::lock() const { impl_->lock(); } + +bool Mutex::trylock() const { return impl_->trylock(); } + +void Mutex::unlock() const { impl_->unlock(); } + +void Mutex::DEFAULT_INITIALIZER(void* arg) { + pthread_mutex_t* pthread_mutex = (pthread_mutex_t*)arg; + int ret = pthread_mutex_init(pthread_mutex, NULL); + assert(ret == 0); +} + +static void init_with_kind(pthread_mutex_t* mutex, int kind) { + pthread_mutexattr_t mutexattr; + int ret = pthread_mutexattr_init(&mutexattr); + assert(ret == 0); + + // Apparently, this can fail. Should we really be aborting? + ret = pthread_mutexattr_settype(&mutexattr, kind); + assert(ret == 0); + + ret = pthread_mutex_init(mutex, &mutexattr); + assert(ret == 0); + + ret = pthread_mutexattr_destroy(&mutexattr); + assert(ret == 0); +} + +#ifdef PTHREAD_ADAPTIVE_MUTEX_INITIALIZER_NP +void Mutex::ADAPTIVE_INITIALIZER(void* arg) { + // From mysql source: mysys/my_thr_init.c + // Set mutex type to "fast" a.k.a "adaptive" + // + // In this case the thread may steal the mutex from some other thread + // that is waiting for the same mutex. 
This will save us some + // context switches but may cause a thread to 'starve forever' while + // waiting for the mutex (not likely if the code within the mutex is + // short). + init_with_kind((pthread_mutex_t*)arg, PTHREAD_MUTEX_ADAPTIVE_NP); +} +#endif + +#ifdef PTHREAD_RECURSIVE_MUTEX_INITIALIZER_NP +void Mutex::RECURSIVE_INITIALIZER(void* arg) { + init_with_kind((pthread_mutex_t*)arg, PTHREAD_MUTEX_RECURSIVE_NP); +} +#endif + + +/** + * Implementation of ReadWriteMutex class using POSIX rw lock + * + * @version $Id:$ + */ +class ReadWriteMutex::impl { +public: + impl() : initialized_(false) { + int ret = pthread_rwlock_init(&rw_lock_, NULL); + assert(ret == 0); + initialized_ = true; + } + + ~impl() { + if(initialized_) { + initialized_ = false; + int ret = pthread_rwlock_destroy(&rw_lock_); + assert(ret == 0); + } + } + + void acquireRead() const { pthread_rwlock_rdlock(&rw_lock_); } + + void acquireWrite() const { pthread_rwlock_wrlock(&rw_lock_); } + + bool attemptRead() const { return pthread_rwlock_tryrdlock(&rw_lock_); } + + bool attemptWrite() const { return pthread_rwlock_trywrlock(&rw_lock_); } + + void release() const { pthread_rwlock_unlock(&rw_lock_); } + +private: + mutable pthread_rwlock_t rw_lock_; + mutable bool initialized_; +}; + +ReadWriteMutex::ReadWriteMutex() : impl_(new ReadWriteMutex::impl()) {} + +void ReadWriteMutex::acquireRead() const { impl_->acquireRead(); } + +void ReadWriteMutex::acquireWrite() const { impl_->acquireWrite(); } + +bool ReadWriteMutex::attemptRead() const { return impl_->attemptRead(); } + +bool ReadWriteMutex::attemptWrite() const { return impl_->attemptWrite(); } + +void ReadWriteMutex::release() const { impl_->release(); } + +}}} // apache::thrift::concurrency + diff --git a/lib/cpp/src/concurrency/Mutex.h b/lib/cpp/src/concurrency/Mutex.h new file mode 100644 index 000000000..884412bea --- /dev/null +++ b/lib/cpp/src/concurrency/Mutex.h @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation 
/**
 * A simple mutex class
 *
 * Thin handle around a private implementation object (Mutex::impl, defined
 * in Mutex.cpp on top of a pthread mutex). The implementation is held
 * through a boost::shared_ptr, so copies of a Mutex refer to the same
 * implementation object.
 *
 * @version $Id:$
 */
class Mutex {
 public:
  // Callback used to initialize the underlying mutex. Mutex.cpp invokes it
  // with the address of the pthread_mutex_t to set up.
  typedef void (*Initializer)(void*);

  // Creates the mutex, initializing it with `init`.
  Mutex(Initializer init = DEFAULT_INITIALIZER);
  virtual ~Mutex() {}

  // Acquires the mutex.
  virtual void lock() const;

  // Attempts to acquire the mutex; returns true if it was acquired.
  virtual bool trylock() const;

  // Releases the mutex.
  virtual void unlock() const;

  // Initializes with default attributes (pthread_mutex_init with NULL).
  static void DEFAULT_INITIALIZER(void*);

  // NOTE(review): Mutex.cpp defines the two initializers below only when
  // PTHREAD_ADAPTIVE_MUTEX_INITIALIZER_NP / PTHREAD_RECURSIVE_MUTEX_INITIALIZER_NP
  // are defined, while they are declared here unconditionally. Using them
  // where those macros are absent presumably fails at link time; confirm.
  static void ADAPTIVE_INITIALIZER(void*);
  static void RECURSIVE_INITIALIZER(void*);

 private:

  class impl;
  boost::shared_ptr<impl> impl_;
};
Mutex& value) : mutex_(value) { + mutex_.lock(); + } + ~Guard() { + mutex_.unlock(); + } + + private: + const Mutex& mutex_; +}; + +class RWGuard { + public: + RWGuard(const ReadWriteMutex& value, bool write = 0) : rw_mutex_(value) { + if (write) { + rw_mutex_.acquireWrite(); + } else { + rw_mutex_.acquireRead(); + } + } + ~RWGuard() { + rw_mutex_.release(); + } + private: + const ReadWriteMutex& rw_mutex_; +}; + + +// A little hack to prevent someone from trying to do "Guard(m);" +// Sorry for polluting the global namespace, but I think it's worth it. +#define Guard(m) incorrect_use_of_Guard(m) +#define RWGuard(m) incorrect_use_of_RWGuard(m) + + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_MUTEX_H_ diff --git a/lib/cpp/src/concurrency/PosixThreadFactory.cpp b/lib/cpp/src/concurrency/PosixThreadFactory.cpp new file mode 100644 index 000000000..e48dce39e --- /dev/null +++ b/lib/cpp/src/concurrency/PosixThreadFactory.cpp @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "PosixThreadFactory.h" +#include "Exception.h" + +#if GOOGLE_PERFTOOLS_REGISTER_THREAD +# include <google/profiler.h> +#endif + +#include <assert.h> +#include <pthread.h> + +#include <iostream> + +#include <boost/weak_ptr.hpp> + +namespace apache { namespace thrift { namespace concurrency { + +using boost::shared_ptr; +using boost::weak_ptr; + +/** + * The POSIX thread class. + * + * @version $Id:$ + */ +class PthreadThread: public Thread { + public: + + enum STATE { + uninitialized, + starting, + started, + stopping, + stopped + }; + + static const int MB = 1024 * 1024; + + static void* threadMain(void* arg); + + private: + pthread_t pthread_; + STATE state_; + int policy_; + int priority_; + int stackSize_; + weak_ptr<PthreadThread> self_; + bool detached_; + + public: + + PthreadThread(int policy, int priority, int stackSize, bool detached, shared_ptr<Runnable> runnable) : + pthread_(0), + state_(uninitialized), + policy_(policy), + priority_(priority), + stackSize_(stackSize), + detached_(detached) { + + this->Thread::runnable(runnable); + } + + ~PthreadThread() { + /* Nothing references this thread, if is is not detached, do a join + now, otherwise the thread-id and, possibly, other resources will + be leaked. */ + if(!detached_) { + try { + join(); + } catch(...) { + // We're really hosed. + } + } + } + + void start() { + if (state_ != uninitialized) { + return; + } + + pthread_attr_t thread_attr; + if (pthread_attr_init(&thread_attr) != 0) { + throw SystemResourceException("pthread_attr_init failed"); + } + + if(pthread_attr_setdetachstate(&thread_attr, + detached_ ? 
+ PTHREAD_CREATE_DETACHED : + PTHREAD_CREATE_JOINABLE) != 0) { + throw SystemResourceException("pthread_attr_setdetachstate failed"); + } + + // Set thread stack size + if (pthread_attr_setstacksize(&thread_attr, MB * stackSize_) != 0) { + throw SystemResourceException("pthread_attr_setstacksize failed"); + } + + // Set thread policy + if (pthread_attr_setschedpolicy(&thread_attr, policy_) != 0) { + throw SystemResourceException("pthread_attr_setschedpolicy failed"); + } + + struct sched_param sched_param; + sched_param.sched_priority = priority_; + + // Set thread priority + if (pthread_attr_setschedparam(&thread_attr, &sched_param) != 0) { + throw SystemResourceException("pthread_attr_setschedparam failed"); + } + + // Create reference + shared_ptr<PthreadThread>* selfRef = new shared_ptr<PthreadThread>(); + *selfRef = self_.lock(); + + state_ = starting; + + if (pthread_create(&pthread_, &thread_attr, threadMain, (void*)selfRef) != 0) { + throw SystemResourceException("pthread_create failed"); + } + } + + void join() { + if (!detached_ && state_ != uninitialized) { + void* ignore; + /* XXX + If join fails it is most likely due to the fact + that the last reference was the thread itself and cannot + join. This results in leaked threads and will eventually + cause the process to run out of thread resources. + We're beyond the point of throwing an exception. Not clear how + best to handle this. 
*/ + detached_ = pthread_join(pthread_, &ignore) == 0; + } + } + + Thread::id_t getId() { + return (Thread::id_t)pthread_; + } + + shared_ptr<Runnable> runnable() const { return Thread::runnable(); } + + void runnable(shared_ptr<Runnable> value) { Thread::runnable(value); } + + void weakRef(shared_ptr<PthreadThread> self) { + assert(self.get() == this); + self_ = weak_ptr<PthreadThread>(self); + } +}; + +void* PthreadThread::threadMain(void* arg) { + shared_ptr<PthreadThread> thread = *(shared_ptr<PthreadThread>*)arg; + delete reinterpret_cast<shared_ptr<PthreadThread>*>(arg); + + if (thread == NULL) { + return (void*)0; + } + + if (thread->state_ != starting) { + return (void*)0; + } + +#if GOOGLE_PERFTOOLS_REGISTER_THREAD + ProfilerRegisterThread(); +#endif + + thread->state_ = starting; + thread->runnable()->run(); + if (thread->state_ != stopping && thread->state_ != stopped) { + thread->state_ = stopping; + } + + return (void*)0; +} + +/** + * POSIX Thread factory implementation + */ +class PosixThreadFactory::Impl { + + private: + POLICY policy_; + PRIORITY priority_; + int stackSize_; + bool detached_; + + /** + * Converts generic posix thread schedule policy enums into pthread + * API values. + */ + static int toPthreadPolicy(POLICY policy) { + switch (policy) { + case OTHER: + return SCHED_OTHER; + case FIFO: + return SCHED_FIFO; + case ROUND_ROBIN: + return SCHED_RR; + } + return SCHED_OTHER; + } + + /** + * Converts relative thread priorities to absolute value based on posix + * thread scheduler policy + * + * The idea is simply to divide up the priority range for the given policy + * into the correpsonding relative priority level (lowest..highest) and + * then pro-rate accordingly. 
+ */ + static int toPthreadPriority(POLICY policy, PRIORITY priority) { + int pthread_policy = toPthreadPolicy(policy); + int min_priority = sched_get_priority_min(pthread_policy); + int max_priority = sched_get_priority_max(pthread_policy); + int quanta = (HIGHEST - LOWEST) + 1; + float stepsperquanta = (max_priority - min_priority) / quanta; + + if (priority <= HIGHEST) { + return (int)(min_priority + stepsperquanta * priority); + } else { + // should never get here for priority increments. + assert(false); + return (int)(min_priority + stepsperquanta * NORMAL); + } + } + + public: + + Impl(POLICY policy, PRIORITY priority, int stackSize, bool detached) : + policy_(policy), + priority_(priority), + stackSize_(stackSize), + detached_(detached) {} + + /** + * Creates a new POSIX thread to run the runnable object + * + * @param runnable A runnable object + */ + shared_ptr<Thread> newThread(shared_ptr<Runnable> runnable) const { + shared_ptr<PthreadThread> result = shared_ptr<PthreadThread>(new PthreadThread(toPthreadPolicy(policy_), toPthreadPriority(policy_, priority_), stackSize_, detached_, runnable)); + result->weakRef(result); + runnable->thread(result); + return result; + } + + int getStackSize() const { return stackSize_; } + + void setStackSize(int value) { stackSize_ = value; } + + PRIORITY getPriority() const { return priority_; } + + /** + * Sets priority. + * + * XXX + * Need to handle incremental priorities properly. + */ + void setPriority(PRIORITY value) { priority_ = value; } + + bool isDetached() const { return detached_; } + + void setDetached(bool value) { detached_ = value; } + + Thread::id_t getCurrentThreadId() const { + // TODO(dreiss): Stop using C-style casts. 
+ return (id_t)pthread_self(); + } + +}; + +PosixThreadFactory::PosixThreadFactory(POLICY policy, PRIORITY priority, int stackSize, bool detached) : + impl_(new PosixThreadFactory::Impl(policy, priority, stackSize, detached)) {} + +shared_ptr<Thread> PosixThreadFactory::newThread(shared_ptr<Runnable> runnable) const { return impl_->newThread(runnable); } + +int PosixThreadFactory::getStackSize() const { return impl_->getStackSize(); } + +void PosixThreadFactory::setStackSize(int value) { impl_->setStackSize(value); } + +PosixThreadFactory::PRIORITY PosixThreadFactory::getPriority() const { return impl_->getPriority(); } + +void PosixThreadFactory::setPriority(PosixThreadFactory::PRIORITY value) { impl_->setPriority(value); } + +bool PosixThreadFactory::isDetached() const { return impl_->isDetached(); } + +void PosixThreadFactory::setDetached(bool value) { impl_->setDetached(value); } + +Thread::id_t PosixThreadFactory::getCurrentThreadId() const { return impl_->getCurrentThreadId(); } + +}}} // apache::thrift::concurrency diff --git a/lib/cpp/src/concurrency/PosixThreadFactory.h b/lib/cpp/src/concurrency/PosixThreadFactory.h new file mode 100644 index 000000000..d6d83a3a1 --- /dev/null +++ b/lib/cpp/src/concurrency/PosixThreadFactory.h @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
/**
 * A thread factory to create posix threads
 *
 * The factory stores a scheduling policy, a relative priority, a stack size
 * and a detached flag, and applies them to every thread it creates. The
 * state lives in a private Impl object held through a boost::shared_ptr.
 *
 * @version $Id:$
 */
class PosixThreadFactory : public ThreadFactory {

 public:

  /**
   * POSIX Thread scheduler policies
   *
   * Mapped in PosixThreadFactory.cpp to SCHED_OTHER, SCHED_FIFO and SCHED_RR
   * respectively.
   */
  enum POLICY {
    OTHER,
    FIFO,
    ROUND_ROBIN
  };

  /**
   * POSIX Thread scheduler relative priorities,
   *
   * Absolute priority is determined by scheduler policy and OS. This
   * enumeration specifies relative priorities such that one can specify a
   * priority within a given scheduler policy without knowing the absolute
   * value of the priority.
   *
   * INCREMENT and DECREMENT are not handled by the priority conversion in
   * PosixThreadFactory.cpp: values above HIGHEST trigger an assert there.
   */
  enum PRIORITY {
    LOWEST = 0,
    LOWER = 1,
    LOW = 2,
    NORMAL = 3,
    HIGH = 4,
    HIGHER = 5,
    HIGHEST = 6,
    INCREMENT = 7,
    DECREMENT = 8
  };

  /**
   * Posix thread (pthread) factory.  All threads created by a factory are reference-counted
   * via boost::shared_ptr and boost::weak_ptr.  The factory guarantees that threads and
   * the Runnable tasks they host will be properly cleaned up once the last strong reference
   * to both is given up.
   *
   * Threads are created with the specified policy, priority, stack size and detached mode.
   * Detached means the thread is free-running and will release all system resources
   * when it completes.  A detached thread is not joinable.  The join method
   * of a detached thread will return immediately with no error.
   *
   * By default threads are not joinable.
   *
   * @param policy    scheduler policy for created threads
   * @param priority  priority relative to the policy
   * @param stackSize stack size in megabytes
   * @param detached  true to create detached (non-joinable) threads
   */

  PosixThreadFactory(POLICY policy=ROUND_ROBIN, PRIORITY priority=NORMAL, int stackSize=1, bool detached=true);

  // From ThreadFactory;
  boost::shared_ptr<Thread> newThread(boost::shared_ptr<Runnable> runnable) const;

  // From ThreadFactory;
  Thread::id_t getCurrentThreadId() const;

  /**
   * Gets stack size for created threads
   *
   * @return int size in megabytes
   */
  virtual int getStackSize() const;

  /**
   * Sets stack size for created threads
   *
   * @param value size in megabytes
   */
  virtual void setStackSize(int value);

  /**
   * Gets priority relative to current policy
   */
  virtual PRIORITY getPriority() const;

  /**
   * Sets priority relative to current policy
   */
  virtual void setPriority(PRIORITY priority);

  /**
   * Sets detached mode of threads
   */
  virtual void setDetached(bool detached);

  /**
   * Gets current detached mode
   */
  virtual bool isDetached() const;

 private:
  // Private implementation; defined in PosixThreadFactory.cpp.
  class Impl;
  boost::shared_ptr<Impl> impl_;
};
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_THREAD_H_ +#define _THRIFT_CONCURRENCY_THREAD_H_ 1 + +#include <stdint.h> +#include <boost/shared_ptr.hpp> +#include <boost/weak_ptr.hpp> + +namespace apache { namespace thrift { namespace concurrency { + +class Thread; + +/** + * Minimal runnable class. More or less analogous to java.lang.Runnable. + * + * @version $Id:$ + */ +class Runnable { + + public: + virtual ~Runnable() {}; + virtual void run() = 0; + + /** + * Gets the thread object that is hosting this runnable object - can return + * an empty boost::shared pointer if no references remain on thet thread object + */ + virtual boost::shared_ptr<Thread> thread() { return thread_.lock(); } + + /** + * Sets the thread that is executing this object. This is only meant for + * use by concrete implementations of Thread. + */ + virtual void thread(boost::shared_ptr<Thread> value) { thread_ = value; } + + private: + boost::weak_ptr<Thread> thread_; +}; + +/** + * Minimal thread class. Returned by thread factory bound to a Runnable object + * and ready to start execution. More or less analogous to java.lang.Thread + * (minus all the thread group, priority, mode and other baggage, since that + * is difficult to abstract across platforms and is left for platform-specific + * ThreadFactory implemtations to deal with + * + * @see apache::thrift::concurrency::ThreadFactory) + */ +class Thread { + + public: + + typedef uint64_t id_t; + + virtual ~Thread() {}; + + /** + * Starts the thread. 
/**
 * Factory to create platform-specific thread object and bind them to Runnable
 * object for execution
 */
class ThreadFactory {

 public:
  virtual ~ThreadFactory() {}

  /**
   * Creates a thread object bound to the given runnable.
   *
   * @param runnable the task the new thread will host
   */
  virtual boost::shared_ptr<Thread> newThread(boost::shared_ptr<Runnable> runnable) const = 0;

  /** Id value that stands for "not a thrift thread"; defined outside this header. */
  static const Thread::id_t unknown_thread_id;

  /** Gets the current thread id or unknown_thread_id if the current thread is not a thrift thread */
  virtual Thread::id_t getCurrentThreadId() const = 0;
};
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "ThreadManager.h" +#include "Exception.h" +#include "Monitor.h" + +#include <boost/shared_ptr.hpp> + +#include <assert.h> +#include <queue> +#include <set> + +#if defined(DEBUG) +#include <iostream> +#endif //defined(DEBUG) + +namespace apache { namespace thrift { namespace concurrency { + +using boost::shared_ptr; +using boost::dynamic_pointer_cast; + +/** + * ThreadManager class + * + * This class manages a pool of threads. It uses a ThreadFactory to create + * threads. It never actually creates or destroys worker threads, rather + * it maintains statistics on number of idle threads, number of active threads, + * task backlog, and average wait and service times. 
/**
 * Concrete ThreadManager.
 *
 * Holds the task queue, the worker bookkeeping counters and the two
 * monitors used to coordinate with Worker threads:
 *  - monitor_ guards the task queue, the counters and threadFactory_, and is
 *    what idle workers and blocked add() callers wait on;
 *  - workerMonitor_ is what addWorker()/removeWorker() wait on until the
 *    worker count reaches the requested maximum.
 */
class ThreadManager::Impl : public ThreadManager  {

 public:
  Impl() :
    workerCount_(0),
    workerMaxCount_(0),
    idleCount_(0),
    pendingTaskCountMax_(0),
    state_(ThreadManager::UNINITIALIZED) {}

  // Stops the manager (without joining) if it is still running.
  ~Impl() { stop(); }

  void start();

  // Transitions through STOPPING; see stopImpl().
  void stop() { stopImpl(false); }

  // Transitions through JOINING; see stopImpl().
  void join() { stopImpl(true); }

  // NOTE(review): reads state_ without taking monitor_, unlike most
  // accessors below. Confirm whether that is intentional.
  const ThreadManager::STATE state() const {
    return state_;
  }

  shared_ptr<ThreadFactory> threadFactory() const {
    Synchronized s(monitor_);
    return threadFactory_;
  }

  void threadFactory(shared_ptr<ThreadFactory> value) {
    Synchronized s(monitor_);
    threadFactory_ = value;
  }

  void addWorker(size_t value);

  void removeWorker(size_t value);

  // NOTE(review): reads idleCount_ without taking monitor_, while the
  // sibling counters are read under the lock.
  size_t idleWorkerCount() const {
    return idleCount_;
  }

  size_t workerCount() const {
    Synchronized s(monitor_);
    return workerCount_;
  }

  // Number of tasks queued and not yet picked up by a worker.
  size_t pendingTaskCount() const {
    Synchronized s(monitor_);
    return tasks_.size();
  }

  // Queued tasks plus workers that are not idle.
  size_t totalTaskCount() const {
    Synchronized s(monitor_);
    return tasks_.size() + workerCount_ - idleCount_;
  }

  // Maximum queue length; add() treats 0 as "no limit".
  size_t pendingTaskCountMax() const {
    Synchronized s(monitor_);
    return pendingTaskCountMax_;
  }

  void pendingTaskCountMax(const size_t value) {
    Synchronized s(monitor_);
    pendingTaskCountMax_ = value;
  }

  // True when the calling thread's id is not in idMap_, i.e. the caller is
  // not one of this manager's registered worker threads.
  bool canSleep();

  void add(shared_ptr<Runnable> value, int64_t timeout);

  void remove(shared_ptr<Runnable> task);

private:
  void stopImpl(bool join);

  size_t workerCount_;          // workers currently counted as running
  size_t workerMaxCount_;       // target number of workers
  size_t idleCount_;            // workers currently waiting for a task
  size_t pendingTaskCountMax_;  // queue limit, 0 = unlimited

  ThreadManager::STATE state_;
  shared_ptr<ThreadFactory> threadFactory_;


  friend class ThreadManager::Task;
  std::queue<shared_ptr<Task> > tasks_;
  Monitor monitor_;
  Monitor workerMonitor_;

  friend class ThreadManager::Worker;
  std::set<shared_ptr<Thread> > workers_;      // all threads created by addWorker()
  std::set<shared_ptr<Thread> > deadWorkers_;  // exited workers awaiting cleanup in removeWorker()
  std::map<const Thread::id_t, shared_ptr<Thread> > idMap_;  // thread id -> worker thread
};
/**
 * Runnable executed by each pool thread.
 *
 * A worker registers itself with the manager's counters, then repeatedly
 * takes tasks from the manager's queue and runs them until the manager's
 * target worker count drops below the current count. On exit it records its
 * thread in the manager's deadWorkers_ set.
 */
class ThreadManager::Worker: public Runnable {
  enum STATE {
    UNINITIALIZED,
    STARTING,
    STARTED,
    STOPPING,
    STOPPED
  };

 public:
  Worker(ThreadManager::Impl* manager) :
    manager_(manager),
    state_(UNINITIALIZED),
    idle_(false) {}

  ~Worker() {}

 private:
  /**
   * True while this worker should keep running: either the pool is not over
   * its target size, or the manager is JOINING and tasks are still queued.
   * Reads manager state without locking; run() calls it with monitor_ held.
   */
  bool isActive() const {
    return
      (manager_->workerCount_ <= manager_->workerMaxCount_) ||
      (manager_->state_ == JOINING && !manager_->tasks_.empty());
  }

 public:
  /**
   * Worker entry point
   *
   * As long as worker thread is running, pull tasks off the task queue and
   * execute.
   */
  void run() {
    bool active = false;
    bool notifyManager = false;

    /**
     * Increment worker semaphore and notify manager if worker count reached
     * desired max
     *
     * Note: We have to release the monitor and acquire the workerMonitor
     * since that is what the manager blocks on for worker add/remove
     */
    {
      Synchronized s(manager_->monitor_);
      // A worker started while the pool is already at its target size never
      // becomes active and skips the task loop entirely.
      active = manager_->workerCount_ < manager_->workerMaxCount_;
      if (active) {
        manager_->workerCount_++;
        notifyManager = manager_->workerCount_ == manager_->workerMaxCount_;
      }
    }

    if (notifyManager) {
      Synchronized s(manager_->workerMonitor_);
      manager_->workerMonitor_.notify();
      notifyManager = false;
    }

    while (active) {
      shared_ptr<ThreadManager::Task> task;

      /**
       * While holding manager monitor block for non-empty task queue (Also
       * check that the thread hasn't been requested to stop). Once the queue
       * is non-empty, dequeue a task, release monitor, and execute. If the
       * worker max count has been decremented such that we exceed it, mark
       * ourself inactive, decrement the worker count and notify the manager
       * (technically we're notifying the next blocked thread but eventually
       * the manager will see it.
       */
      {
        Synchronized s(manager_->monitor_);
        active = isActive();

        // Idle wait: counted in idleCount_ for as long as we are blocked.
        while (active && manager_->tasks_.empty()) {
          manager_->idleCount_++;
          idle_ = true;
          manager_->monitor_.wait();
          active = isActive();
          idle_ = false;
          manager_->idleCount_--;
        }

        if (active) {
          if (!manager_->tasks_.empty()) {
            task = manager_->tasks_.front();
            manager_->tasks_.pop();
            // Only WAITING tasks are promoted; a task in any other state is
            // dequeued but not run below.
            if (task->state_ == ThreadManager::Task::WAITING) {
              task->state_ = ThreadManager::Task::EXECUTING;
            }

            /* If we have a pending task max and we just dropped below it, wakeup any
               thread that might be blocked on add. */
            // NOTE(review): notify() targets monitor_, which idle workers
            // also wait on, so the woken thread is not necessarily the
            // blocked add() caller. Confirm.
            if (manager_->pendingTaskCountMax_ != 0 &&
                manager_->tasks_.size() == manager_->pendingTaskCountMax_ - 1) {
              manager_->monitor_.notify();
            }
          }
        } else {
          // Pool shrank below the current count: this worker retires.
          idle_ = true;
          manager_->workerCount_--;
          notifyManager = (manager_->workerCount_ == manager_->workerMaxCount_);
        }
      }

      // Run the task outside the monitor so other workers can dequeue.
      if (task != NULL) {
        if (task->state_ == ThreadManager::Task::EXECUTING) {
          try {
            task->run();
          } catch(...) {
            // XXX need to log this
          }
        }
      }
    }

    // Hand our thread object to the manager for cleanup in removeWorker().
    {
      Synchronized s(manager_->workerMonitor_);
      manager_->deadWorkers_.insert(this->thread());
      if (notifyManager) {
        manager_->workerMonitor_.notify();
      }
    }

    return;
  }

  private:
    ThreadManager::Impl* manager_;
    friend class ThreadManager::Impl;
    STATE state_;
    bool idle_;
};
+ state_ = join ? ThreadManager::JOINING : ThreadManager::STOPPING; + } + } + + if (doStop) { + removeWorker(workerCount_); + } + + // XXX + // should be able to block here for transition to STOPPED since we're no + // using shared_ptrs + + { + Synchronized s(monitor_); + state_ = ThreadManager::STOPPED; + } + +} + +void ThreadManager::Impl::removeWorker(size_t value) { + std::set<shared_ptr<Thread> > removedThreads; + { + Synchronized s(monitor_); + if (value > workerMaxCount_) { + throw InvalidArgumentException(); + } + + workerMaxCount_ -= value; + + if (idleCount_ < value) { + for (size_t ix = 0; ix < idleCount_; ix++) { + monitor_.notify(); + } + } else { + monitor_.notifyAll(); + } + } + + { + Synchronized s(workerMonitor_); + + while (workerCount_ != workerMaxCount_) { + workerMonitor_.wait(); + } + + for (std::set<shared_ptr<Thread> >::iterator ix = deadWorkers_.begin(); ix != deadWorkers_.end(); ix++) { + workers_.erase(*ix); + idMap_.erase((*ix)->getId()); + } + + deadWorkers_.clear(); + } +} + + bool ThreadManager::Impl::canSleep() { + const Thread::id_t id = threadFactory_->getCurrentThreadId(); + return idMap_.find(id) == idMap_.end(); + } + + void ThreadManager::Impl::add(shared_ptr<Runnable> value, int64_t timeout) { + Synchronized s(monitor_); + + if (state_ != ThreadManager::STARTED) { + throw IllegalStateException(); + } + + if (pendingTaskCountMax_ > 0 && (tasks_.size() >= pendingTaskCountMax_)) { + if (canSleep() && timeout >= 0) { + while (pendingTaskCountMax_ > 0 && tasks_.size() >= pendingTaskCountMax_) { + monitor_.wait(timeout); + } + } else { + throw TooManyPendingTasksException(); + } + } + + tasks_.push(shared_ptr<ThreadManager::Task>(new ThreadManager::Task(value))); + + // If idle thread is available notify it, otherwise all worker threads are + // running and will get around to this task in time. 
+ if (idleCount_ > 0) { + monitor_.notify(); + } + } + +void ThreadManager::Impl::remove(shared_ptr<Runnable> task) { + Synchronized s(monitor_); + if (state_ != ThreadManager::STARTED) { + throw IllegalStateException(); + } +} + +class SimpleThreadManager : public ThreadManager::Impl { + + public: + SimpleThreadManager(size_t workerCount=4, size_t pendingTaskCountMax=0) : + workerCount_(workerCount), + pendingTaskCountMax_(pendingTaskCountMax), + firstTime_(true) { + } + + void start() { + ThreadManager::Impl::pendingTaskCountMax(pendingTaskCountMax_); + ThreadManager::Impl::start(); + addWorker(workerCount_); + } + + private: + const size_t workerCount_; + const size_t pendingTaskCountMax_; + bool firstTime_; + Monitor monitor_; +}; + + +shared_ptr<ThreadManager> ThreadManager::newThreadManager() { + return shared_ptr<ThreadManager>(new ThreadManager::Impl()); +} + +shared_ptr<ThreadManager> ThreadManager::newSimpleThreadManager(size_t count, size_t pendingTaskCountMax) { + return shared_ptr<ThreadManager>(new SimpleThreadManager(count, pendingTaskCountMax)); +} + +}}} // apache::thrift::concurrency + diff --git a/lib/cpp/src/concurrency/ThreadManager.h b/lib/cpp/src/concurrency/ThreadManager.h new file mode 100644 index 000000000..6e5a17817 --- /dev/null +++ b/lib/cpp/src/concurrency/ThreadManager.h @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_THREADMANAGER_H_ +#define _THRIFT_CONCURRENCY_THREADMANAGER_H_ 1 + +#include <boost/shared_ptr.hpp> +#include <sys/types.h> +#include "Thread.h" + +namespace apache { namespace thrift { namespace concurrency { + +/** + * Thread Pool Manager and related classes + * + * @version $Id:$ + */ +class ThreadManager; + +/** + * ThreadManager class + * + * This class manages a pool of threads. It uses a ThreadFactory to create + * threads. It never actually creates or destroys worker threads, rather + * It maintains statistics on number of idle threads, number of active threads, + * task backlog, and average wait and service times and informs the PoolPolicy + * object bound to instances of this manager of interesting transitions. It is + * then up the PoolPolicy object to decide if the thread pool size needs to be + * adjusted and call this object addWorker and removeWorker methods to make + * changes. + * + * This design allows different policy implementations to used this code to + * handle basic worker thread management and worker task execution and focus on + * policy issues. The simplest policy, StaticPolicy, does nothing other than + * create a fixed number of threads. + */ +class ThreadManager { + + protected: + ThreadManager() {} + + public: + virtual ~ThreadManager() {} + + /** + * Starts the thread manager. Verifies all attributes have been properly + * initialized, then allocates necessary resources to begin operation + */ + virtual void start() = 0; + + /** + * Stops the thread manager. 
Aborts all remaining unprocessed task, shuts + * down all created worker threads, and realeases all allocated resources. + * This method blocks for all worker threads to complete, thus it can + * potentially block forever if a worker thread is running a task that + * won't terminate. + */ + virtual void stop() = 0; + + /** + * Joins the thread manager. This is the same as stop, except that it will + * block until all the workers have finished their work. At that point + * the ThreadManager will transition into the STOPPED state. + */ + virtual void join() = 0; + + enum STATE { + UNINITIALIZED, + STARTING, + STARTED, + JOINING, + STOPPING, + STOPPED + }; + + virtual const STATE state() const = 0; + + virtual boost::shared_ptr<ThreadFactory> threadFactory() const = 0; + + virtual void threadFactory(boost::shared_ptr<ThreadFactory> value) = 0; + + virtual void addWorker(size_t value=1) = 0; + + virtual void removeWorker(size_t value=1) = 0; + + /** + * Gets the current number of idle worker threads + */ + virtual size_t idleWorkerCount() const = 0; + + /** + * Gets the current number of total worker threads + */ + virtual size_t workerCount() const = 0; + + /** + * Gets the current number of pending tasks + */ + virtual size_t pendingTaskCount() const = 0; + + /** + * Gets the current number of pending and executing tasks + */ + virtual size_t totalTaskCount() const = 0; + + /** + * Gets the maximum pending task count. 0 indicates no maximum + */ + virtual size_t pendingTaskCountMax() const = 0; + + /** + * Adds a task to be executed at some time in the future by a worker thread. + * + * This method will block if pendingTaskCountMax() in not zero and pendingTaskCount() + * is greater than or equalt to pendingTaskCountMax(). 
If this method is called in the + * context of a ThreadManager worker thread it will throw a + * TooManyPendingTasksException + * + * @param task The task to queue for execution + * + * @param timeout Time to wait in milliseconds to add a task when a pending-task-count + * is specified. Specific cases: + * timeout = 0 : Wait forever to queue task. + * timeout = -1 : Return immediately if pending task count exceeds specified max + * + * @throws TooManyPendingTasksException Pending task count exceeds max pending task count + */ + virtual void add(boost::shared_ptr<Runnable>task, int64_t timeout=0LL) = 0; + + /** + * Removes a pending task + */ + virtual void remove(boost::shared_ptr<Runnable> task) = 0; + + static boost::shared_ptr<ThreadManager> newThreadManager(); + + /** + * Creates a simple thread manager the uses count number of worker threads and has + * a pendingTaskCountMax maximum pending tasks. The default, 0, specified no limit + * on pending tasks + */ + static boost::shared_ptr<ThreadManager> newSimpleThreadManager(size_t count=4, size_t pendingTaskCountMax=0); + + class Task; + + class Worker; + + class Impl; +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_THREADMANAGER_H_ diff --git a/lib/cpp/src/concurrency/TimerManager.cpp b/lib/cpp/src/concurrency/TimerManager.cpp new file mode 100644 index 000000000..25515dc82 --- /dev/null +++ b/lib/cpp/src/concurrency/TimerManager.cpp @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TimerManager.h" +#include "Exception.h" +#include "Util.h" + +#include <assert.h> +#include <iostream> +#include <set> + +namespace apache { namespace thrift { namespace concurrency { + +using boost::shared_ptr; + +typedef std::multimap<int64_t, shared_ptr<TimerManager::Task> >::iterator task_iterator; +typedef std::pair<task_iterator, task_iterator> task_range; + +/** + * TimerManager class + * + * @version $Id:$ + */ +class TimerManager::Task : public Runnable { + + public: + enum STATE { + WAITING, + EXECUTING, + CANCELLED, + COMPLETE + }; + + Task(shared_ptr<Runnable> runnable) : + runnable_(runnable), + state_(WAITING) {} + + ~Task() { + } + + void run() { + if (state_ == EXECUTING) { + runnable_->run(); + state_ = COMPLETE; + } + } + + private: + shared_ptr<Runnable> runnable_; + class TimerManager::Dispatcher; + friend class TimerManager::Dispatcher; + STATE state_; +}; + +class TimerManager::Dispatcher: public Runnable { + + public: + Dispatcher(TimerManager* manager) : + manager_(manager) {} + + ~Dispatcher() {} + + /** + * Dispatcher entry point + * + * As long as dispatcher thread is running, pull tasks off the task taskMap_ + * and execute. 
+ */ + void run() { + { + Synchronized s(manager_->monitor_); + if (manager_->state_ == TimerManager::STARTING) { + manager_->state_ = TimerManager::STARTED; + manager_->monitor_.notifyAll(); + } + } + + do { + std::set<shared_ptr<TimerManager::Task> > expiredTasks; + { + Synchronized s(manager_->monitor_); + task_iterator expiredTaskEnd; + int64_t now = Util::currentTime(); + while (manager_->state_ == TimerManager::STARTED && + (expiredTaskEnd = manager_->taskMap_.upper_bound(now)) == manager_->taskMap_.begin()) { + int64_t timeout = 0LL; + if (!manager_->taskMap_.empty()) { + timeout = manager_->taskMap_.begin()->first - now; + } + assert((timeout != 0 && manager_->taskCount_ > 0) || (timeout == 0 && manager_->taskCount_ == 0)); + try { + manager_->monitor_.wait(timeout); + } catch (TimedOutException &e) {} + now = Util::currentTime(); + } + + if (manager_->state_ == TimerManager::STARTED) { + for (task_iterator ix = manager_->taskMap_.begin(); ix != expiredTaskEnd; ix++) { + shared_ptr<TimerManager::Task> task = ix->second; + expiredTasks.insert(task); + if (task->state_ == TimerManager::Task::WAITING) { + task->state_ = TimerManager::Task::EXECUTING; + } + manager_->taskCount_--; + } + manager_->taskMap_.erase(manager_->taskMap_.begin(), expiredTaskEnd); + } + } + + for (std::set<shared_ptr<Task> >::iterator ix = expiredTasks.begin(); ix != expiredTasks.end(); ix++) { + (*ix)->run(); + } + + } while (manager_->state_ == TimerManager::STARTED); + + { + Synchronized s(manager_->monitor_); + if (manager_->state_ == TimerManager::STOPPING) { + manager_->state_ = TimerManager::STOPPED; + manager_->monitor_.notify(); + } + } + return; + } + + private: + TimerManager* manager_; + friend class TimerManager; +}; + +TimerManager::TimerManager() : + taskCount_(0), + state_(TimerManager::UNINITIALIZED), + dispatcher_(shared_ptr<Dispatcher>(new Dispatcher(this))) { +} + + +TimerManager::~TimerManager() { + + // If we haven't been explicitly stopped, do so now. 
We don't need to grab + // the monitor here, since stop already takes care of reentrancy. + + if (state_ != STOPPED) { + try { + stop(); + } catch(...) { + throw; + // uhoh + } + } +} + +void TimerManager::start() { + bool doStart = false; + { + Synchronized s(monitor_); + if (threadFactory_ == NULL) { + throw InvalidArgumentException(); + } + if (state_ == TimerManager::UNINITIALIZED) { + state_ = TimerManager::STARTING; + doStart = true; + } + } + + if (doStart) { + dispatcherThread_ = threadFactory_->newThread(dispatcher_); + dispatcherThread_->start(); + } + + { + Synchronized s(monitor_); + while (state_ == TimerManager::STARTING) { + monitor_.wait(); + } + assert(state_ != TimerManager::STARTING); + } +} + +void TimerManager::stop() { + bool doStop = false; + { + Synchronized s(monitor_); + if (state_ == TimerManager::UNINITIALIZED) { + state_ = TimerManager::STOPPED; + } else if (state_ != STOPPING && state_ != STOPPED) { + doStop = true; + state_ = STOPPING; + monitor_.notifyAll(); + } + while (state_ != STOPPED) { + monitor_.wait(); + } + } + + if (doStop) { + // Clean up any outstanding tasks + for (task_iterator ix = taskMap_.begin(); ix != taskMap_.end(); ix++) { + taskMap_.erase(ix); + } + + // Remove dispatcher's reference to us. 
+ dispatcher_->manager_ = NULL; + } +} + +shared_ptr<const ThreadFactory> TimerManager::threadFactory() const { + Synchronized s(monitor_); + return threadFactory_; +} + +void TimerManager::threadFactory(shared_ptr<const ThreadFactory> value) { + Synchronized s(monitor_); + threadFactory_ = value; +} + +size_t TimerManager::taskCount() const { + return taskCount_; +} + +void TimerManager::add(shared_ptr<Runnable> task, int64_t timeout) { + int64_t now = Util::currentTime(); + timeout += now; + + { + Synchronized s(monitor_); + if (state_ != TimerManager::STARTED) { + throw IllegalStateException(); + } + + taskCount_++; + taskMap_.insert(std::pair<int64_t, shared_ptr<Task> >(timeout, shared_ptr<Task>(new Task(task)))); + + // If the task map was empty, or if we have an expiration that is earlier + // than any previously seen, kick the dispatcher so it can update its + // timeout + if (taskCount_ == 1 || timeout < taskMap_.begin()->first) { + monitor_.notify(); + } + } +} + +void TimerManager::add(shared_ptr<Runnable> task, const struct timespec& value) { + + int64_t expiration; + Util::toMilliseconds(expiration, value); + + int64_t now = Util::currentTime(); + + if (expiration < now) { + throw InvalidArgumentException(); + } + + add(task, expiration - now); +} + + +void TimerManager::remove(shared_ptr<Runnable> task) { + Synchronized s(monitor_); + if (state_ != TimerManager::STARTED) { + throw IllegalStateException(); + } +} + +const TimerManager::STATE TimerManager::state() const { return state_; } + +}}} // apache::thrift::concurrency + diff --git a/lib/cpp/src/concurrency/TimerManager.h b/lib/cpp/src/concurrency/TimerManager.h new file mode 100644 index 000000000..f3f799f93 --- /dev/null +++ b/lib/cpp/src/concurrency/TimerManager.h @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_TIMERMANAGER_H_ +#define _THRIFT_CONCURRENCY_TIMERMANAGER_H_ 1 + +#include "Exception.h" +#include "Monitor.h" +#include "Thread.h" + +#include <boost/shared_ptr.hpp> +#include <map> +#include <time.h> + +namespace apache { namespace thrift { namespace concurrency { + +/** + * Timer Manager + * + * This class dispatches timer tasks when they fall due. + * + * @version $Id:$ + */ +class TimerManager { + + public: + + TimerManager(); + + virtual ~TimerManager(); + + virtual boost::shared_ptr<const ThreadFactory> threadFactory() const; + + virtual void threadFactory(boost::shared_ptr<const ThreadFactory> value); + + /** + * Starts the timer manager service + * + * @throws IllegalArgumentException Missing thread factory attribute + */ + virtual void start(); + + /** + * Stops the timer manager service + */ + virtual void stop(); + + virtual size_t taskCount() const ; + + /** + * Adds a task to be executed at some time in the future by a worker thread. + * + * @param task The task to execute + * @param timeout Time in milliseconds to delay before executing task + */ + virtual void add(boost::shared_ptr<Runnable> task, int64_t timeout); + + /** + * Adds a task to be executed at some time in the future by a worker thread. 
+ * + * @param task The task to execute + * @param timeout Absolute time in the future to execute task. + */ + virtual void add(boost::shared_ptr<Runnable> task, const struct timespec& timeout); + + /** + * Removes a pending task + * + * @throws NoSuchTaskException Specified task doesn't exist. It was either + * processed already or this call was made for a + * task that was never added to this timer + * + * @throws UncancellableTaskException Specified task is already being + * executed or has completed execution. + */ + virtual void remove(boost::shared_ptr<Runnable> task); + + enum STATE { + UNINITIALIZED, + STARTING, + STARTED, + STOPPING, + STOPPED + }; + + virtual const STATE state() const; + + private: + boost::shared_ptr<const ThreadFactory> threadFactory_; + class Task; + friend class Task; + std::multimap<int64_t, boost::shared_ptr<Task> > taskMap_; + size_t taskCount_; + Monitor monitor_; + STATE state_; + class Dispatcher; + friend class Dispatcher; + boost::shared_ptr<Dispatcher> dispatcher_; + boost::shared_ptr<Thread> dispatcherThread_; +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_TIMERMANAGER_H_ diff --git a/lib/cpp/src/concurrency/Util.cpp b/lib/cpp/src/concurrency/Util.cpp new file mode 100644 index 000000000..1c4493716 --- /dev/null +++ b/lib/cpp/src/concurrency/Util.cpp @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "Util.h" + +#ifdef HAVE_CONFIG_H +#include <config.h> +#endif + +#if defined(HAVE_CLOCK_GETTIME) +#include <time.h> +#elif defined(HAVE_GETTIMEOFDAY) +#include <sys/time.h> +#endif // defined(HAVE_CLOCK_GETTIME) + +namespace apache { namespace thrift { namespace concurrency { + +const int64_t Util::currentTime() { + int64_t result; + +#if defined(HAVE_CLOCK_GETTIME) + struct timespec now; + int ret = clock_gettime(CLOCK_REALTIME, &now); + assert(ret == 0); + toMilliseconds(result, now); +#elif defined(HAVE_GETTIMEOFDAY) + struct timeval now; + int ret = gettimeofday(&now, NULL); + assert(ret == 0); + toMilliseconds(result, now); +#else +#error "No high-precision clock is available." +#endif // defined(HAVE_CLOCK_GETTIME) + + return result; +} + + +}}} // apache::thrift::concurrency diff --git a/lib/cpp/src/concurrency/Util.h b/lib/cpp/src/concurrency/Util.h new file mode 100644 index 000000000..25fcc2086 --- /dev/null +++ b/lib/cpp/src/concurrency/Util.h @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_UTIL_H_ +#define _THRIFT_CONCURRENCY_UTIL_H_ 1 + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> +#include <time.h> +#include <sys/time.h> + +namespace apache { namespace thrift { namespace concurrency { + +/** + * Utility methods + * + * This class contains basic utility methods for converting time formats, + * and other common platform-dependent concurrency operations. + * It should not be included in API headers for other concurrency library + * headers, since it will, by definition, pull in all sorts of horrid + * platform dependent crap. Rather it should be inluded directly in + * concurrency library implementation source. 
+ * + * @version $Id:$ + */ +class Util { + + static const int64_t NS_PER_S = 1000000000LL; + static const int64_t US_PER_S = 1000000LL; + static const int64_t MS_PER_S = 1000LL; + + static const int64_t NS_PER_MS = NS_PER_S / MS_PER_S; + static const int64_t US_PER_MS = US_PER_S / MS_PER_S; + + public: + + /** + * Converts millisecond timestamp into a timespec struct + * + * @param struct timespec& result + * @param time or duration in milliseconds + */ + static void toTimespec(struct timespec& result, int64_t value) { + result.tv_sec = value / MS_PER_S; // ms to s + result.tv_nsec = (value % MS_PER_S) * NS_PER_MS; // ms to ns + } + + static void toTimeval(struct timeval& result, int64_t value) { + result.tv_sec = value / MS_PER_S; // ms to s + result.tv_usec = (value % MS_PER_S) * US_PER_MS; // ms to us + } + + /** + * Converts struct timespec to milliseconds + */ + static const void toMilliseconds(int64_t& result, const struct timespec& value) { + result = (value.tv_sec * MS_PER_S) + (value.tv_nsec / NS_PER_MS); + // round up -- int64_t cast is to avoid a compiler error for some GCCs + if (int64_t(value.tv_nsec) % NS_PER_MS >= (NS_PER_MS / 2)) { + ++result; + } + } + + /** + * Converts struct timeval to milliseconds + */ + static const void toMilliseconds(int64_t& result, const struct timeval& value) { + result = (value.tv_sec * MS_PER_S) + (value.tv_usec / US_PER_MS); + // round up -- int64_t cast is to avoid a compiler error for some GCCs + if (int64_t(value.tv_usec) % US_PER_MS >= (US_PER_MS / 2)) { + ++result; + } + } + + /** + * Get current time as milliseconds from epoch + */ + static const int64_t currentTime(); +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_UTIL_H_ diff --git a/lib/cpp/src/concurrency/test/Tests.cpp b/lib/cpp/src/concurrency/test/Tests.cpp new file mode 100644 index 000000000..c80bb883f --- /dev/null +++ b/lib/cpp/src/concurrency/test/Tests.cpp @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <iostream> +#include <vector> +#include <string> + +#include "ThreadFactoryTests.h" +#include "TimerManagerTests.h" +#include "ThreadManagerTests.h" + +int main(int argc, char** argv) { + + std::string arg; + + std::vector<std::string> args(argc - 1 > 1 ? argc - 1 : 1); + + args[0] = "all"; + + for (int ix = 1; ix < argc; ix++) { + args[ix - 1] = std::string(argv[ix]); + } + + bool runAll = args[0].compare("all") == 0; + + if (runAll || args[0].compare("thread-factory") == 0) { + + ThreadFactoryTests threadFactoryTests; + + std::cout << "ThreadFactory tests..." 
<< std::endl; + + size_t count = 1000; + size_t floodLoops = 1; + size_t floodCount = 100000; + + std::cout << "\t\tThreadFactory reap N threads test: N = " << count << std::endl; + + assert(threadFactoryTests.reapNThreads(count)); + + std::cout << "\t\tThreadFactory floodN threads test: N = " << floodCount << std::endl; + + assert(threadFactoryTests.floodNTest(floodLoops, floodCount)); + + std::cout << "\t\tThreadFactory synchronous start test" << std::endl; + + assert(threadFactoryTests.synchStartTest()); + + std::cout << "\t\tThreadFactory monitor timeout test" << std::endl; + + assert(threadFactoryTests.monitorTimeoutTest()); + } + + if (runAll || args[0].compare("util") == 0) { + + std::cout << "Util tests..." << std::endl; + + std::cout << "\t\tUtil minimum time" << std::endl; + + int64_t time00 = Util::currentTime(); + int64_t time01 = Util::currentTime(); + + std::cout << "\t\t\tMinimum time: " << time01 - time00 << "ms" << std::endl; + + time00 = Util::currentTime(); + time01 = time00; + size_t count = 0; + + while (time01 < time00 + 10) { + count++; + time01 = Util::currentTime(); + } + + std::cout << "\t\t\tscall per ms: " << count / (time01 - time00) << std::endl; + } + + + if (runAll || args[0].compare("timer-manager") == 0) { + + std::cout << "TimerManager tests..." << std::endl; + + std::cout << "\t\tTimerManager test00" << std::endl; + + TimerManagerTests timerManagerTests; + + assert(timerManagerTests.test00()); + } + + if (runAll || args[0].compare("thread-manager") == 0) { + + std::cout << "ThreadManager tests..." 
<< std::endl; + + { + + size_t workerCount = 100; + + size_t taskCount = 100000; + + int64_t delay = 10LL; + + std::cout << "\t\tThreadManager load test: worker count: " << workerCount << " task count: " << taskCount << " delay: " << delay << std::endl; + + ThreadManagerTests threadManagerTests; + + assert(threadManagerTests.loadTest(taskCount, delay, workerCount)); + + std::cout << "\t\tThreadManager block test: worker count: " << workerCount << " delay: " << delay << std::endl; + + assert(threadManagerTests.blockTest(delay, workerCount)); + + } + } + + if (runAll || args[0].compare("thread-manager-benchmark") == 0) { + + std::cout << "ThreadManager benchmark tests..." << std::endl; + + { + + size_t minWorkerCount = 2; + + size_t maxWorkerCount = 512; + + size_t tasksPerWorker = 1000; + + int64_t delay = 10LL; + + for (size_t workerCount = minWorkerCount; workerCount < maxWorkerCount; workerCount*= 2) { + + size_t taskCount = workerCount * tasksPerWorker; + + std::cout << "\t\tThreadManager load test: worker count: " << workerCount << " task count: " << taskCount << " delay: " << delay << std::endl; + + ThreadManagerTests threadManagerTests; + + threadManagerTests.loadTest(taskCount, delay, workerCount); + } + } + } +} diff --git a/lib/cpp/src/concurrency/test/ThreadFactoryTests.h b/lib/cpp/src/concurrency/test/ThreadFactoryTests.h new file mode 100644 index 000000000..859fbaf51 --- /dev/null +++ b/lib/cpp/src/concurrency/test/ThreadFactoryTests.h @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <config.h> +#include <concurrency/Thread.h> +#include <concurrency/PosixThreadFactory.h> +#include <concurrency/Monitor.h> +#include <concurrency/Util.h> + +#include <assert.h> +#include <iostream> +#include <set> + +namespace apache { namespace thrift { namespace concurrency { namespace test { + +using boost::shared_ptr; +using namespace apache::thrift::concurrency; + +/** + * ThreadManagerTests class + * + * @version $Id:$ + */ +class ThreadFactoryTests { + +public: + + static const double ERROR; + + class Task: public Runnable { + + public: + + Task() {} + + void run() { + std::cout << "\t\t\tHello World" << std::endl; + } + }; + + /** + * Hello world test + */ + bool helloWorldTest() { + + PosixThreadFactory threadFactory = PosixThreadFactory(); + + shared_ptr<Task> task = shared_ptr<Task>(new ThreadFactoryTests::Task()); + + shared_ptr<Thread> thread = threadFactory.newThread(task); + + thread->start(); + + thread->join(); + + std::cout << "\t\t\tSuccess!" 
<< std::endl; + + return true; + } + + /** + * Reap N threads + */ + class ReapNTask: public Runnable { + + public: + + ReapNTask(Monitor& monitor, int& activeCount) : + _monitor(monitor), + _count(activeCount) {} + + void run() { + Synchronized s(_monitor); + + _count--; + + //std::cout << "\t\t\tthread count: " << _count << std::endl; + + if (_count == 0) { + _monitor.notify(); + } + } + + Monitor& _monitor; + + int& _count; + }; + + bool reapNThreads(int loop=1, int count=10) { + + PosixThreadFactory threadFactory = PosixThreadFactory(); + + Monitor* monitor = new Monitor(); + + for(int lix = 0; lix < loop; lix++) { + + int* activeCount = new int(count); + + std::set<shared_ptr<Thread> > threads; + + int tix; + + for (tix = 0; tix < count; tix++) { + try { + threads.insert(threadFactory.newThread(shared_ptr<Runnable>(new ReapNTask(*monitor, *activeCount)))); + } catch(SystemResourceException& e) { + std::cout << "\t\t\tfailed to create " << lix * count + tix << " thread " << e.what() << std::endl; + throw e; + } + } + + tix = 0; + for (std::set<shared_ptr<Thread> >::const_iterator thread = threads.begin(); thread != threads.end(); tix++, ++thread) { + + try { + (*thread)->start(); + } catch(SystemResourceException& e) { + std::cout << "\t\t\tfailed to start " << lix * count + tix << " thread " << e.what() << std::endl; + throw e; + } + } + + { + Synchronized s(*monitor); + while (*activeCount > 0) { + monitor->wait(1000); + } + } + + for (std::set<shared_ptr<Thread> >::const_iterator thread = threads.begin(); thread != threads.end(); thread++) { + threads.erase(*thread); + } + + std::cout << "\t\t\treaped " << lix * count << " threads" << std::endl; + } + + std::cout << "\t\t\tSuccess!" 
<< std::endl; + + return true; + } + + class SynchStartTask: public Runnable { + + public: + + enum STATE { + UNINITIALIZED, + STARTING, + STARTED, + STOPPING, + STOPPED + }; + + SynchStartTask(Monitor& monitor, volatile STATE& state) : + _monitor(monitor), + _state(state) {} + + void run() { + { + Synchronized s(_monitor); + if (_state == SynchStartTask::STARTING) { + _state = SynchStartTask::STARTED; + _monitor.notify(); + } + } + + { + Synchronized s(_monitor); + while (_state == SynchStartTask::STARTED) { + _monitor.wait(); + } + + if (_state == SynchStartTask::STOPPING) { + _state = SynchStartTask::STOPPED; + _monitor.notifyAll(); + } + } + } + + private: + Monitor& _monitor; + volatile STATE& _state; + }; + + bool synchStartTest() { + + Monitor monitor; + + SynchStartTask::STATE state = SynchStartTask::UNINITIALIZED; + + shared_ptr<SynchStartTask> task = shared_ptr<SynchStartTask>(new SynchStartTask(monitor, state)); + + PosixThreadFactory threadFactory = PosixThreadFactory(); + + shared_ptr<Thread> thread = threadFactory.newThread(task); + + if (state == SynchStartTask::UNINITIALIZED) { + + state = SynchStartTask::STARTING; + + thread->start(); + } + + { + Synchronized s(monitor); + while (state == SynchStartTask::STARTING) { + monitor.wait(); + } + } + + assert(state != SynchStartTask::STARTING); + + { + Synchronized s(monitor); + + try { + monitor.wait(100); + } catch(TimedOutException& e) { + } + + if (state == SynchStartTask::STARTED) { + + state = SynchStartTask::STOPPING; + + monitor.notify(); + } + + while (state == SynchStartTask::STOPPING) { + monitor.wait(); + } + } + + assert(state == SynchStartTask::STOPPED); + + bool success = true; + + std::cout << "\t\t\t" << (success ? "Success" : "Failure") << "!" << std::endl; + + return true; + } + + /** See how accurate monitor timeout is. 
*/ + + bool monitorTimeoutTest(size_t count=1000, int64_t timeout=10) { + + Monitor monitor; + + int64_t startTime = Util::currentTime(); + + for (size_t ix = 0; ix < count; ix++) { + { + Synchronized s(monitor); + try { + monitor.wait(timeout); + } catch(TimedOutException& e) { + } + } + } + + int64_t endTime = Util::currentTime(); + + double error = ((endTime - startTime) - (count * timeout)) / (double)(count * timeout); + + if (error < 0.0) { + + error *= 1.0; + } + + bool success = error < ThreadFactoryTests::ERROR; + + std::cout << "\t\t\t" << (success ? "Success" : "Failure") << "! expected time: " << count * timeout << "ms elapsed time: "<< endTime - startTime << "ms error%: " << error * 100.0 << std::endl; + + return success; + } + + + class FloodTask : public Runnable { + public: + + FloodTask(const size_t id) :_id(id) {} + ~FloodTask(){ + if(_id % 1000 == 0) { + std::cout << "\t\tthread " << _id << " done" << std::endl; + } + } + + void run(){ + if(_id % 1000 == 0) { + std::cout << "\t\tthread " << _id << " started" << std::endl; + } + + usleep(1); + } + const size_t _id; + }; + + void foo(PosixThreadFactory *tf) { + } + + bool floodNTest(size_t loop=1, size_t count=100000) { + + bool success = false; + + for(size_t lix = 0; lix < loop; lix++) { + + PosixThreadFactory threadFactory = PosixThreadFactory(); + threadFactory.setDetached(true); + + for(size_t tix = 0; tix < count; tix++) { + + try { + + shared_ptr<FloodTask> task(new FloodTask(lix * count + tix )); + + shared_ptr<Thread> thread = threadFactory.newThread(task); + + thread->start(); + + usleep(1); + + } catch (TException& e) { + + std::cout << "\t\t\tfailed to start " << lix * count + tix << " thread " << e.what() << std::endl; + + return success; + } + } + + std::cout << "\t\t\tflooded " << (lix + 1) * count << " threads" << std::endl; + + success = true; + } + + return success; + } +}; + +const double ThreadFactoryTests::ERROR = .20; + +}}}} // apache::thrift::concurrency::test + diff --git 
a/lib/cpp/src/concurrency/test/ThreadManagerTests.h b/lib/cpp/src/concurrency/test/ThreadManagerTests.h new file mode 100644 index 000000000..e7b517431 --- /dev/null +++ b/lib/cpp/src/concurrency/test/ThreadManagerTests.h @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <config.h> +#include <concurrency/ThreadManager.h> +#include <concurrency/PosixThreadFactory.h> +#include <concurrency/Monitor.h> +#include <concurrency/Util.h> + +#include <assert.h> +#include <set> +#include <iostream> +#include <set> +#include <stdint.h> + +namespace apache { namespace thrift { namespace concurrency { namespace test { + +using namespace apache::thrift::concurrency; + +/** + * ThreadManagerTests class + * + * @version $Id:$ + */ +class ThreadManagerTests { + +public: + + static const double ERROR; + + class Task: public Runnable { + + public: + + Task(Monitor& monitor, size_t& count, int64_t timeout) : + _monitor(monitor), + _count(count), + _timeout(timeout), + _done(false) {} + + void run() { + + _startTime = Util::currentTime(); + + { + Synchronized s(_sleep); + + try { + _sleep.wait(_timeout); + } catch(TimedOutException& e) { + ; + }catch(...) 
{ + assert(0); + } + } + + _endTime = Util::currentTime(); + + _done = true; + + { + Synchronized s(_monitor); + + // std::cout << "Thread " << _count << " completed " << std::endl; + + _count--; + + if (_count == 0) { + + _monitor.notify(); + } + } + } + + Monitor& _monitor; + size_t& _count; + int64_t _timeout; + int64_t _startTime; + int64_t _endTime; + bool _done; + Monitor _sleep; + }; + + /** + * Dispatch count tasks, each of which blocks for timeout milliseconds then + * completes. Verify that all tasks completed and that thread manager cleans + * up properly on delete. + */ + bool loadTest(size_t count=100, int64_t timeout=100LL, size_t workerCount=4) { + + Monitor monitor; + + size_t activeCount = count; + + shared_ptr<ThreadManager> threadManager = ThreadManager::newSimpleThreadManager(workerCount); + + shared_ptr<PosixThreadFactory> threadFactory = shared_ptr<PosixThreadFactory>(new PosixThreadFactory()); + + threadFactory->setPriority(PosixThreadFactory::HIGHEST); + + threadManager->threadFactory(threadFactory); + + threadManager->start(); + + std::set<shared_ptr<ThreadManagerTests::Task> > tasks; + + for (size_t ix = 0; ix < count; ix++) { + + tasks.insert(shared_ptr<ThreadManagerTests::Task>(new ThreadManagerTests::Task(monitor, activeCount, timeout))); + } + + int64_t time00 = Util::currentTime(); + + for (std::set<shared_ptr<ThreadManagerTests::Task> >::iterator ix = tasks.begin(); ix != tasks.end(); ix++) { + + threadManager->add(*ix); + } + + { + Synchronized s(monitor); + + while(activeCount > 0) { + + monitor.wait(); + } + } + + int64_t time01 = Util::currentTime(); + + int64_t firstTime = 9223372036854775807LL; + int64_t lastTime = 0; + + double averageTime = 0; + int64_t minTime = 9223372036854775807LL; + int64_t maxTime = 0; + + for (std::set<shared_ptr<ThreadManagerTests::Task> >::iterator ix = tasks.begin(); ix != tasks.end(); ix++) { + + shared_ptr<ThreadManagerTests::Task> task = *ix; + + int64_t delta = task->_endTime - task->_startTime; 
+ + assert(delta > 0); + + if (task->_startTime < firstTime) { + firstTime = task->_startTime; + } + + if (task->_endTime > lastTime) { + lastTime = task->_endTime; + } + + if (delta < minTime) { + minTime = delta; + } + + if (delta > maxTime) { + maxTime = delta; + } + + averageTime+= delta; + } + + averageTime /= count; + + std::cout << "\t\t\tfirst start: " << firstTime << "ms Last end: " << lastTime << "ms min: " << minTime << "ms max: " << maxTime << "ms average: " << averageTime << "ms" << std::endl; + + double expectedTime = ((count + (workerCount - 1)) / workerCount) * timeout; + + double error = ((time01 - time00) - expectedTime) / expectedTime; + + if (error < 0) { + error*= -1.0; + } + + bool success = error < ERROR; + + std::cout << "\t\t\t" << (success ? "Success" : "Failure") << "! expected time: " << expectedTime << "ms elapsed time: "<< time01 - time00 << "ms error%: " << error * 100.0 << std::endl; + + return success; + } + + class BlockTask: public Runnable { + + public: + + BlockTask(Monitor& monitor, Monitor& bmonitor, size_t& count) : + _monitor(monitor), + _bmonitor(bmonitor), + _count(count) {} + + void run() { + { + Synchronized s(_bmonitor); + + _bmonitor.wait(); + + } + + { + Synchronized s(_monitor); + + _count--; + + if (_count == 0) { + + _monitor.notify(); + } + } + } + + Monitor& _monitor; + Monitor& _bmonitor; + size_t& _count; + }; + + /** + * Block test. Create pendingTaskCountMax tasks. Verify that we block adding the + * pendingTaskCountMax + 1th task. 
Verify that we unblock when a task completes */ + + bool blockTest(int64_t timeout=100LL, size_t workerCount=2) { + + bool success = false; + + try { + + Monitor bmonitor; + Monitor monitor; + + size_t pendingTaskMaxCount = workerCount; + + size_t activeCounts[] = {workerCount, pendingTaskMaxCount, 1}; + + shared_ptr<ThreadManager> threadManager = ThreadManager::newSimpleThreadManager(workerCount, pendingTaskMaxCount); + + shared_ptr<PosixThreadFactory> threadFactory = shared_ptr<PosixThreadFactory>(new PosixThreadFactory()); + + threadFactory->setPriority(PosixThreadFactory::HIGHEST); + + threadManager->threadFactory(threadFactory); + + threadManager->start(); + + std::set<shared_ptr<ThreadManagerTests::BlockTask> > tasks; + + for (size_t ix = 0; ix < workerCount; ix++) { + + tasks.insert(shared_ptr<ThreadManagerTests::BlockTask>(new ThreadManagerTests::BlockTask(monitor, bmonitor,activeCounts[0]))); + } + + for (size_t ix = 0; ix < pendingTaskMaxCount; ix++) { + + tasks.insert(shared_ptr<ThreadManagerTests::BlockTask>(new ThreadManagerTests::BlockTask(monitor, bmonitor,activeCounts[1]))); + } + + for (std::set<shared_ptr<ThreadManagerTests::BlockTask> >::iterator ix = tasks.begin(); ix != tasks.end(); ix++) { + threadManager->add(*ix); + } + + if(!(success = (threadManager->totalTaskCount() == pendingTaskMaxCount + workerCount))) { + throw TException("Unexpected pending task count"); + } + + shared_ptr<ThreadManagerTests::BlockTask> extraTask(new ThreadManagerTests::BlockTask(monitor, bmonitor, activeCounts[2])); + + try { + threadManager->add(extraTask, 1); + throw TException("Unexpected success adding task in excess of pending task count"); + } catch(TimedOutException& e) { + } + + std::cout << "\t\t\t" << "Pending tasks " << threadManager->pendingTaskCount() << std::endl; + + { + Synchronized s(bmonitor); + + bmonitor.notifyAll(); + } + + { + Synchronized s(monitor); + + while(activeCounts[0] != 0) { + monitor.wait(); + } + } + + std::cout << "\t\t\t" << 
"Pending tasks " << threadManager->pendingTaskCount() << std::endl; + + try { + threadManager->add(extraTask, 1); + } catch(TimedOutException& e) { + std::cout << "\t\t\t" << "add timed out unexpectedly" << std::endl; + throw TException("Unexpected timeout adding task"); + + } catch(TooManyPendingTasksException& e) { + std::cout << "\t\t\t" << "add encountered too many pending exepctions" << std::endl; + throw TException("Unexpected timeout adding task"); + } + + // Wake up tasks that were pending before and wait for them to complete + + { + Synchronized s(bmonitor); + + bmonitor.notifyAll(); + } + + { + Synchronized s(monitor); + + while(activeCounts[1] != 0) { + monitor.wait(); + } + } + + // Wake up the extra task and wait for it to complete + + { + Synchronized s(bmonitor); + + bmonitor.notifyAll(); + } + + { + Synchronized s(monitor); + + while(activeCounts[2] != 0) { + monitor.wait(); + } + } + + if(!(success = (threadManager->totalTaskCount() == 0))) { + throw TException("Unexpected pending task count"); + } + + } catch(TException& e) { + } + + std::cout << "\t\t\t" << (success ? "Success" : "Failure") << std::endl; + return success; + } +}; + +const double ThreadManagerTests::ERROR = .20; + +}}}} // apache::thrift::concurrency + +using namespace apache::thrift::concurrency::test; + diff --git a/lib/cpp/src/concurrency/test/TimerManagerTests.h b/lib/cpp/src/concurrency/test/TimerManagerTests.h new file mode 100644 index 000000000..e6fe6ce7e --- /dev/null +++ b/lib/cpp/src/concurrency/test/TimerManagerTests.h @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <concurrency/TimerManager.h> +#include <concurrency/PosixThreadFactory.h> +#include <concurrency/Monitor.h> +#include <concurrency/Util.h> + +#include <assert.h> +#include <iostream> + +namespace apache { namespace thrift { namespace concurrency { namespace test { + +using namespace apache::thrift::concurrency; + +/** + * ThreadManagerTests class + * + * @version $Id:$ + */ +class TimerManagerTests { + + public: + + static const double ERROR; + + class Task: public Runnable { + public: + + Task(Monitor& monitor, int64_t timeout) : + _timeout(timeout), + _startTime(Util::currentTime()), + _monitor(monitor), + _success(false), + _done(false) {} + + ~Task() { std::cerr << this << std::endl; } + + void run() { + + _endTime = Util::currentTime(); + + // Figure out error percentage + + int64_t delta = _endTime - _startTime; + + + delta = delta > _timeout ? delta - _timeout : _timeout - delta; + + float error = delta / _timeout; + + if(error < ERROR) { + _success = true; + } + + _done = true; + + std::cout << "\t\t\tTimerManagerTests::Task[" << this << "] done" << std::endl; //debug + + {Synchronized s(_monitor); + _monitor.notifyAll(); + } + } + + int64_t _timeout; + int64_t _startTime; + int64_t _endTime; + Monitor& _monitor; + bool _success; + bool _done; + }; + + /** + * This test creates two tasks and waits for the first to expire within 10% + * of the expected expiration time. 
It then verifies that the timer manager + * properly clean up itself and the remaining orphaned timeout task when the + * manager goes out of scope and its destructor is called. + */ + bool test00(int64_t timeout=1000LL) { + + shared_ptr<TimerManagerTests::Task> orphanTask = shared_ptr<TimerManagerTests::Task>(new TimerManagerTests::Task(_monitor, 10 * timeout)); + + { + + TimerManager timerManager; + + timerManager.threadFactory(shared_ptr<PosixThreadFactory>(new PosixThreadFactory())); + + timerManager.start(); + + assert(timerManager.state() == TimerManager::STARTED); + + shared_ptr<TimerManagerTests::Task> task = shared_ptr<TimerManagerTests::Task>(new TimerManagerTests::Task(_monitor, timeout)); + + { + Synchronized s(_monitor); + + timerManager.add(orphanTask, 10 * timeout); + + timerManager.add(task, timeout); + + _monitor.wait(); + } + + assert(task->_done); + + + std::cout << "\t\t\t" << (task->_success ? "Success" : "Failure") << "!" << std::endl; + } + + // timerManager.stop(); This is where it happens via destructor + + assert(!orphanTask->_done); + + return true; + } + + friend class TestTask; + + Monitor _monitor; +}; + +const double TimerManagerTests::ERROR = .20; + +}}}} // apache::thrift::concurrency + diff --git a/lib/cpp/src/processor/PeekProcessor.cpp b/lib/cpp/src/processor/PeekProcessor.cpp new file mode 100644 index 000000000..c721861bc --- /dev/null +++ b/lib/cpp/src/processor/PeekProcessor.cpp @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "PeekProcessor.h" + +using namespace apache::thrift::transport; +using namespace apache::thrift::protocol; +using namespace apache::thrift; + +namespace apache { namespace thrift { namespace processor { + +PeekProcessor::PeekProcessor() { + memoryBuffer_.reset(new TMemoryBuffer()); + targetTransport_ = memoryBuffer_; +} +PeekProcessor::~PeekProcessor() {} + +void PeekProcessor::initialize(boost::shared_ptr<TProcessor> actualProcessor, + boost::shared_ptr<TProtocolFactory> protocolFactory, + boost::shared_ptr<TPipedTransportFactory> transportFactory) { + actualProcessor_ = actualProcessor; + pipedProtocol_ = protocolFactory->getProtocol(targetTransport_); + transportFactory_ = transportFactory; + transportFactory_->initializeTargetTransport(targetTransport_); +} + +boost::shared_ptr<TTransport> PeekProcessor::getPipedTransport(boost::shared_ptr<TTransport> in) { + return transportFactory_->getTransport(in); +} + +void PeekProcessor::setTargetTransport(boost::shared_ptr<TTransport> targetTransport) { + targetTransport_ = targetTransport; + if (boost::dynamic_pointer_cast<TMemoryBuffer>(targetTransport_)) { + memoryBuffer_ = boost::dynamic_pointer_cast<TMemoryBuffer>(targetTransport); + } else if (boost::dynamic_pointer_cast<TPipedTransport>(targetTransport_)) { + memoryBuffer_ = boost::dynamic_pointer_cast<TMemoryBuffer>(boost::dynamic_pointer_cast<TPipedTransport>(targetTransport_)->getTargetTransport()); + } else { + // Unsupported transport type: drop the previously held buffer so the check below rejects it + memoryBuffer_.reset(); + } + + if (!memoryBuffer_) { + throw TException("Target transport must be a TMemoryBuffer or a TPipedTransport with 
TMemoryBuffer"); + } +} + +bool PeekProcessor::process(boost::shared_ptr<TProtocol> in, + boost::shared_ptr<TProtocol> out) { + + std::string fname; + TMessageType mtype; + int32_t seqid; + in->readMessageBegin(fname, mtype, seqid); + + if (mtype != T_CALL) { + throw TException("Unexpected message type"); + } + + // Peek at the name + peekName(fname); + + TType ftype; + int16_t fid; + while (true) { + in->readFieldBegin(fname, ftype, fid); + if (ftype == T_STOP) { + break; + } + + // Peek at the variable + peek(in, ftype, fid); + in->readFieldEnd(); + } + in->readMessageEnd(); + in->getTransport()->readEnd(); + + // + // All the data is now in memoryBuffer_ and ready to be processed + // + + // Let's first take a peek at the full data in memory + uint8_t* buffer; + uint32_t size; + memoryBuffer_->getBuffer(&buffer, &size); + peekBuffer(buffer, size); + + // Done peeking at variables + peekEnd(); + + bool ret = actualProcessor_->process(pipedProtocol_, out); + memoryBuffer_->resetBuffer(); + return ret; +} + +void PeekProcessor::peekName(const std::string& fname) { +} + +void PeekProcessor::peekBuffer(uint8_t* buffer, uint32_t size) { +} + +void PeekProcessor::peek(boost::shared_ptr<TProtocol> in, + TType ftype, + int16_t fid) { + in->skip(ftype); +} + +void PeekProcessor::peekEnd() {} + +}}} diff --git a/lib/cpp/src/processor/PeekProcessor.h b/lib/cpp/src/processor/PeekProcessor.h new file mode 100644 index 000000000..0f7c016a0 --- /dev/null +++ b/lib/cpp/src/processor/PeekProcessor.h @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef PEEKPROCESSOR_H +#define PEEKPROCESSOR_H + +#include <string> +#include <TProcessor.h> +#include <transport/TTransport.h> +#include <transport/TTransportUtils.h> +#include <transport/TBufferTransports.h> +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace processor { + +/* + * Class for peeking at the raw data that is being processed by another processor + * and gives the derived class a chance to change behavior accordingly + * + */ +class PeekProcessor : public apache::thrift::TProcessor { + + public: + PeekProcessor(); + virtual ~PeekProcessor(); + + // Input here: actualProcessor - the underlying processor + // protocolFactory - the protocol factory used to wrap the memory buffer + // transportFactory - this TPipedTransportFactory is used to wrap the source transport + // via a call to getPipedTransport + void initialize(boost::shared_ptr<apache::thrift::TProcessor> actualProcessor, + boost::shared_ptr<apache::thrift::protocol::TProtocolFactory> protocolFactory, + boost::shared_ptr<apache::thrift::transport::TPipedTransportFactory> transportFactory); + + boost::shared_ptr<apache::thrift::transport::TTransport> getPipedTransport(boost::shared_ptr<apache::thrift::transport::TTransport> in); + + void setTargetTransport(boost::shared_ptr<apache::thrift::transport::TTransport> targetTransport); + + virtual bool process(boost::shared_ptr<apache::thrift::protocol::TProtocol> in, + boost::shared_ptr<apache::thrift::protocol::TProtocol> out); + + // The following three functions can be overloaded by child 
classes to + // achieve desired peeking behavior + virtual void peekName(const std::string& fname); + virtual void peekBuffer(uint8_t* buffer, uint32_t size); + virtual void peek(boost::shared_ptr<apache::thrift::protocol::TProtocol> in, + apache::thrift::protocol::TType ftype, + int16_t fid); + virtual void peekEnd(); + + private: + boost::shared_ptr<apache::thrift::TProcessor> actualProcessor_; + boost::shared_ptr<apache::thrift::protocol::TProtocol> pipedProtocol_; + boost::shared_ptr<apache::thrift::transport::TPipedTransportFactory> transportFactory_; + boost::shared_ptr<apache::thrift::transport::TMemoryBuffer> memoryBuffer_; + boost::shared_ptr<apache::thrift::transport::TTransport> targetTransport_; +}; + +}}} // apache::thrift::processor + +#endif diff --git a/lib/cpp/src/processor/StatsProcessor.h b/lib/cpp/src/processor/StatsProcessor.h new file mode 100644 index 000000000..820b3ad4b --- /dev/null +++ b/lib/cpp/src/processor/StatsProcessor.h @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef STATSPROCESSOR_H +#define STATSPROCESSOR_H + +#include <boost/shared_ptr.hpp> +#include <transport/TTransport.h> +#include <protocol/TProtocol.h> +#include <TProcessor.h> + +namespace apache { namespace thrift { namespace processor { + +/* + * Class for keeping track of function call statistics and printing them if desired + * + */ +class StatsProcessor : public apache::thrift::TProcessor { +public: + StatsProcessor(bool print, bool frequency) + : print_(print), + frequency_(frequency) + {} + virtual ~StatsProcessor() {}; + + virtual bool process(boost::shared_ptr<apache::thrift::protocol::TProtocol> piprot, boost::shared_ptr<apache::thrift::protocol::TProtocol> poprot) { + + piprot_ = piprot; + + std::string fname; + apache::thrift::protocol::TMessageType mtype; + int32_t seqid; + + piprot_->readMessageBegin(fname, mtype, seqid); + if (mtype != apache::thrift::protocol::T_CALL) { + if (print_) { + printf("Unknown message type\n"); + } + throw apache::thrift::TException("Unexpected message type"); + } + if (print_) { + printf("%s (", fname.c_str()); + } + if (frequency_) { + if (frequency_map_.find(fname) != frequency_map_.end()) { + frequency_map_[fname]++; + } else { + frequency_map_[fname] = 1; + } + } + + apache::thrift::protocol::TType ftype; + int16_t fid; + + while (true) { + piprot_->readFieldBegin(fname, ftype, fid); + if (ftype == apache::thrift::protocol::T_STOP) { + break; + } + + printAndPassToBuffer(ftype); + if (print_) { + printf(", "); + } + } + + if (print_) { + printf("\b\b)\n"); + } + return true; + } + + const std::map<std::string, int64_t>& get_frequency_map() { + return frequency_map_; + } + +protected: + void printAndPassToBuffer(apache::thrift::protocol::TType ftype) { + switch (ftype) { + case apache::thrift::protocol::T_BOOL: + { + bool boolv; + piprot_->readBool(boolv); + if (print_) { + printf("%d", boolv); + } + } + break; + case apache::thrift::protocol::T_BYTE: + { + int8_t bytev; + piprot_->readByte(bytev); + if 
(print_) { + printf("%d", bytev); + } + } + break; + case apache::thrift::protocol::T_I16: + { + int16_t i16; + piprot_->readI16(i16); + if (print_) { + printf("%d", i16); + } + } + break; + case apache::thrift::protocol::T_I32: + { + int32_t i32; + piprot_->readI32(i32); + if (print_) { + printf("%d", i32); + } + } + break; + case apache::thrift::protocol::T_I64: + { + int64_t i64; + piprot_->readI64(i64); + if (print_) { + printf("%lld", (long long)i64); // int64_t is not long on every ABI, so %ld could mismatch the argument + } + } + break; + case apache::thrift::protocol::T_DOUBLE: + { + double dub; + piprot_->readDouble(dub); + if (print_) { + printf("%f", dub); + } + } + break; + case apache::thrift::protocol::T_STRING: + { + std::string str; + piprot_->readString(str); + if (print_) { + printf("%s", str.c_str()); + } + } + break; + case apache::thrift::protocol::T_STRUCT: + { + std::string name; + int16_t fid; + apache::thrift::protocol::TType ftype; + piprot_->readStructBegin(name); + if (print_) { + printf("<"); + } + while (true) { + piprot_->readFieldBegin(name, ftype, fid); + if (ftype == apache::thrift::protocol::T_STOP) { + break; + } + printAndPassToBuffer(ftype); + if (print_) { + printf(","); + } + piprot_->readFieldEnd(); + } + piprot_->readStructEnd(); + if (print_) { + printf("\b>"); + } + } + break; + case apache::thrift::protocol::T_MAP: + { + apache::thrift::protocol::TType keyType; + apache::thrift::protocol::TType valType; + uint32_t i, size; + piprot_->readMapBegin(keyType, valType, size); + if (print_) { + printf("{"); + } + for (i = 0; i < size; i++) { + printAndPassToBuffer(keyType); + if (print_) { + printf("=>"); + } + printAndPassToBuffer(valType); + if (print_) { + printf(","); + } + } + piprot_->readMapEnd(); + if (print_) { + printf("\b}"); + } + } + break; + case apache::thrift::protocol::T_SET: + { + apache::thrift::protocol::TType elemType; + uint32_t i, size; + piprot_->readSetBegin(elemType, size); + if (print_) { + printf("{"); + } + for (i = 0; i < size; i++) { + printAndPassToBuffer(elemType); + if 
(print_) { + printf(","); + } + } + piprot_->readSetEnd(); + if (print_) { + printf("\b}"); + } + } + break; + case apache::thrift::protocol::T_LIST: + { + apache::thrift::protocol::TType elemType; + uint32_t i, size; + piprot_->readListBegin(elemType, size); + if (print_) { + printf("["); + } + for (i = 0; i < size; i++) { + printAndPassToBuffer(elemType); + if (print_) { + printf(","); + } + } + piprot_->readListEnd(); + if (print_) { + printf("\b]"); + } + } + break; + default: + break; + } + } + + boost::shared_ptr<apache::thrift::protocol::TProtocol> piprot_; + std::map<std::string, int64_t> frequency_map_; + + bool print_; + bool frequency_; +}; + +}}} // apache::thrift::processor + +#endif diff --git a/lib/cpp/src/protocol/TBase64Utils.cpp b/lib/cpp/src/protocol/TBase64Utils.cpp new file mode 100644 index 000000000..14481c49c --- /dev/null +++ b/lib/cpp/src/protocol/TBase64Utils.cpp @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "TBase64Utils.h" + +#include <boost/static_assert.hpp> + +using std::string; + +namespace apache { namespace thrift { namespace protocol { + + +static const uint8_t *kBase64EncodeTable = (const uint8_t *) + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +void base64_encode(const uint8_t *in, uint32_t len, uint8_t *buf) { + buf[0] = kBase64EncodeTable[(in[0] >> 2) & 0x3F]; + if (len == 3) { + buf[1] = kBase64EncodeTable[((in[0] << 4) + (in[1] >> 4)) & 0x3f]; + buf[2] = kBase64EncodeTable[((in[1] << 2) + (in[2] >> 6)) & 0x3f]; + buf[3] = kBase64EncodeTable[in[2] & 0x3f]; + } else if (len == 2) { + buf[1] = kBase64EncodeTable[((in[0] << 4) + (in[1] >> 4)) & 0x3f]; + buf[2] = kBase64EncodeTable[(in[1] << 2) & 0x3f]; + } else { // len == 1 + buf[1] = kBase64EncodeTable[(in[0] << 4) & 0x3f]; + } +} + +static const uint8_t kBase64DecodeTable[256] ={ // 0xff fills every slot whose byte is not in the base64 alphabet; written as 0xff because -1 is a narrowing conversion for uint8_t + 0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff, + 0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff, + 0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,62,0xff,0xff,0xff,63, + 52,53,54,55,56,57,58,59,60,61,0xff,0xff,0xff,0xff,0xff,0xff, + 0xff, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14, + 15,16,17,18,19,20,21,22,23,24,25,0xff,0xff,0xff,0xff,0xff, + 0xff,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40, + 41,42,43,44,45,46,47,48,49,50,51,0xff,0xff,0xff,0xff,0xff, + 0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff, + 0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff, + 0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff, + 0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff, + 0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff, + 0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff, + 0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff, + 0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff, +}; + +void base64_decode(uint8_t *buf, uint32_t len) { + buf[0] = (kBase64DecodeTable[buf[0]] << 2) | + (kBase64DecodeTable[buf[1]] >> 4); + if (len > 2) { + buf[1] = ((kBase64DecodeTable[buf[1]] << 4) & 0xf0) | + (kBase64DecodeTable[buf[2]] >> 2); + if (len > 3) { + buf[2] = 
((kBase64DecodeTable[buf[2]] << 6) & 0xc0) | + (kBase64DecodeTable[buf[3]]); + } + } +} + + +}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TBase64Utils.h b/lib/cpp/src/protocol/TBase64Utils.h new file mode 100644 index 000000000..3def73350 --- /dev/null +++ b/lib/cpp/src/protocol/TBase64Utils.h @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef _THRIFT_PROTOCOL_TBASE64UTILS_H_ +#define _THRIFT_PROTOCOL_TBASE64UTILS_H_ + +#include <stdint.h> +#include <string> + +namespace apache { namespace thrift { namespace protocol { + +// in must be at least len bytes +// len must be 1, 2, or 3 +// buf must be a buffer of at least 4 bytes and may not overlap in +// the data is not padded with '='; the caller can do this if desired +void base64_encode(const uint8_t *in, uint32_t len, uint8_t *buf); + +// buf must be a buffer of at least 4 bytes and contain base64 encoded values +// buf will be changed to contain output bytes +// len is number of bytes to consume from input (must be 2, 3, or 4) +// no '=' padding should be included in the input +void base64_decode(uint8_t *buf, uint32_t len); + +}}} // apache::thrift::protocol + +#endif // #define _THRIFT_PROTOCOL_TBASE64UTILS_H_ diff --git a/lib/cpp/src/protocol/TBinaryProtocol.cpp b/lib/cpp/src/protocol/TBinaryProtocol.cpp new file mode 100644 index 000000000..6a4838b44 --- /dev/null +++ b/lib/cpp/src/protocol/TBinaryProtocol.cpp @@ -0,0 +1,394 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "TBinaryProtocol.h" + +#include <limits> + +using std::string; + +namespace apache { namespace thrift { namespace protocol { + +uint32_t TBinaryProtocol::writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + if (strict_write_) { + int32_t version = (VERSION_1) | ((int32_t)messageType); + uint32_t wsize = 0; + wsize += writeI32(version); + wsize += writeString(name); + wsize += writeI32(seqid); + return wsize; + } else { + uint32_t wsize = 0; + wsize += writeString(name); + wsize += writeByte((int8_t)messageType); + wsize += writeI32(seqid); + return wsize; + } +} + +uint32_t TBinaryProtocol::writeMessageEnd() { + return 0; +} + +uint32_t TBinaryProtocol::writeStructBegin(const char* name) { + return 0; +} + +uint32_t TBinaryProtocol::writeStructEnd() { + return 0; +} + +uint32_t TBinaryProtocol::writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + uint32_t wsize = 0; + wsize += writeByte((int8_t)fieldType); + wsize += writeI16(fieldId); + return wsize; +} + +uint32_t TBinaryProtocol::writeFieldEnd() { + return 0; +} + +uint32_t TBinaryProtocol::writeFieldStop() { + return + writeByte((int8_t)T_STOP); +} + +uint32_t TBinaryProtocol::writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + uint32_t wsize = 0; + wsize += writeByte((int8_t)keyType); + wsize += writeByte((int8_t)valType); + wsize += writeI32((int32_t)size); + return wsize; +} + +uint32_t TBinaryProtocol::writeMapEnd() { + return 0; +} + +uint32_t TBinaryProtocol::writeListBegin(const TType elemType, + const uint32_t size) { + uint32_t wsize = 0; + wsize += writeByte((int8_t) elemType); + wsize += writeI32((int32_t)size); + return wsize; +} + +uint32_t TBinaryProtocol::writeListEnd() { + return 0; +} + +uint32_t TBinaryProtocol::writeSetBegin(const TType elemType, + const uint32_t size) { + uint32_t wsize = 0; + wsize += writeByte((int8_t)elemType); + wsize += 
writeI32((int32_t)size); + return wsize; +} + +uint32_t TBinaryProtocol::writeSetEnd() { + return 0; +} + +uint32_t TBinaryProtocol::writeBool(const bool value) { + uint8_t tmp = value ? 1 : 0; + trans_->write(&tmp, 1); + return 1; +} + +uint32_t TBinaryProtocol::writeByte(const int8_t byte) { + trans_->write((uint8_t*)&byte, 1); + return 1; +} + +uint32_t TBinaryProtocol::writeI16(const int16_t i16) { + int16_t net = (int16_t)htons(i16); + trans_->write((uint8_t*)&net, 2); + return 2; +} + +uint32_t TBinaryProtocol::writeI32(const int32_t i32) { + int32_t net = (int32_t)htonl(i32); + trans_->write((uint8_t*)&net, 4); + return 4; +} + +uint32_t TBinaryProtocol::writeI64(const int64_t i64) { + int64_t net = (int64_t)htonll(i64); + trans_->write((uint8_t*)&net, 8); + return 8; +} + +uint32_t TBinaryProtocol::writeDouble(const double dub) { + BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t)); + BOOST_STATIC_ASSERT(std::numeric_limits<double>::is_iec559); + + uint64_t bits = bitwise_cast<uint64_t>(dub); + bits = htonll(bits); + trans_->write((uint8_t*)&bits, 8); + return 8; +} + + +uint32_t TBinaryProtocol::writeString(const string& str) { + uint32_t size = str.size(); + uint32_t result = writeI32((int32_t)size); + if (size > 0) { + trans_->write((uint8_t*)str.data(), size); + } + return result + size; +} + +uint32_t TBinaryProtocol::writeBinary(const string& str) { + return TBinaryProtocol::writeString(str); +} + +/** + * Reading functions + */ + +uint32_t TBinaryProtocol::readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + uint32_t result = 0; + int32_t sz; + result += readI32(sz); + + if (sz < 0) { + // Check for correct version number + int32_t version = sz & VERSION_MASK; + if (version != VERSION_1) { + throw TProtocolException(TProtocolException::BAD_VERSION, "Bad version identifier"); + } + messageType = (TMessageType)(sz & 0x000000ff); + result += readString(name); + result += readI32(seqid); + } else { + if 
(strict_read_) { + throw TProtocolException(TProtocolException::BAD_VERSION, "No version identifier... old protocol client in strict mode?"); + } else { + // Handle pre-versioned input + int8_t type; + result += readStringBody(name, sz); + result += readByte(type); + messageType = (TMessageType)type; + result += readI32(seqid); + } + } + return result; +} + +uint32_t TBinaryProtocol::readMessageEnd() { + return 0; +} + +uint32_t TBinaryProtocol::readStructBegin(string& name) { + name = ""; + return 0; +} + +uint32_t TBinaryProtocol::readStructEnd() { + return 0; +} + +uint32_t TBinaryProtocol::readFieldBegin(string& name, + TType& fieldType, + int16_t& fieldId) { + uint32_t result = 0; + int8_t type; + result += readByte(type); + fieldType = (TType)type; + if (fieldType == T_STOP) { + fieldId = 0; + return result; + } + result += readI16(fieldId); + return result; +} + +uint32_t TBinaryProtocol::readFieldEnd() { + return 0; +} + +uint32_t TBinaryProtocol::readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + int8_t k, v; + uint32_t result = 0; + int32_t sizei; + result += readByte(k); + keyType = (TType)k; + result += readByte(v); + valType = (TType)v; + result += readI32(sizei); + if (sizei < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && sizei > container_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + return result; +} + +uint32_t TBinaryProtocol::readMapEnd() { + return 0; +} + +uint32_t TBinaryProtocol::readListBegin(TType& elemType, + uint32_t& size) { + int8_t e; + uint32_t result = 0; + int32_t sizei; + result += readByte(e); + elemType = (TType)e; + result += readI32(sizei); + if (sizei < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && sizei > container_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + return result; +} + +uint32_t 
TBinaryProtocol::readListEnd() { + return 0; +} + +uint32_t TBinaryProtocol::readSetBegin(TType& elemType, + uint32_t& size) { + int8_t e; + uint32_t result = 0; + int32_t sizei; + result += readByte(e); + elemType = (TType)e; + result += readI32(sizei); + if (sizei < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && sizei > container_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + return result; +} + +uint32_t TBinaryProtocol::readSetEnd() { + return 0; +} + +uint32_t TBinaryProtocol::readBool(bool& value) { + uint8_t b[1]; + trans_->readAll(b, 1); + value = *(int8_t*)b != 0; + return 1; +} + +uint32_t TBinaryProtocol::readByte(int8_t& byte) { + uint8_t b[1]; + trans_->readAll(b, 1); + byte = *(int8_t*)b; + return 1; +} + +uint32_t TBinaryProtocol::readI16(int16_t& i16) { + uint8_t b[2]; + trans_->readAll(b, 2); + i16 = *(int16_t*)b; + i16 = (int16_t)ntohs(i16); + return 2; +} + +uint32_t TBinaryProtocol::readI32(int32_t& i32) { + uint8_t b[4]; + trans_->readAll(b, 4); + i32 = *(int32_t*)b; + i32 = (int32_t)ntohl(i32); + return 4; +} + +uint32_t TBinaryProtocol::readI64(int64_t& i64) { + uint8_t b[8]; + trans_->readAll(b, 8); + i64 = *(int64_t*)b; + i64 = (int64_t)ntohll(i64); + return 8; +} + +uint32_t TBinaryProtocol::readDouble(double& dub) { + BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t)); + BOOST_STATIC_ASSERT(std::numeric_limits<double>::is_iec559); + + uint64_t bits; + uint8_t b[8]; + trans_->readAll(b, 8); + bits = *(uint64_t*)b; + bits = ntohll(bits); + dub = bitwise_cast<double>(bits); + return 8; +} + +uint32_t TBinaryProtocol::readString(string& str) { + uint32_t result; + int32_t size; + result = readI32(size); + return result + readStringBody(str, size); +} + +uint32_t TBinaryProtocol::readBinary(string& str) { + return TBinaryProtocol::readString(str); +} + +uint32_t TBinaryProtocol::readStringBody(string& str, int32_t size) { + 
uint32_t result = 0; + + // Catch error cases + if (size < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } + if (string_limit_ > 0 && size > string_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + + // Catch empty string case + if (size == 0) { + str = ""; + return result; + } + + // Use the heap here to prevent stack overflow for v. large strings + if (size > string_buf_size_ || string_buf_ == NULL) { + void* new_string_buf = std::realloc(string_buf_, (uint32_t)size); + if (new_string_buf == NULL) { + throw TProtocolException(TProtocolException::UNKNOWN, "Out of memory in TBinaryProtocol::readString"); + } + string_buf_ = (uint8_t*)new_string_buf; + string_buf_size_ = size; + } + trans_->readAll(string_buf_, size); + str = string((char*)string_buf_, size); + return (uint32_t)size; +} + +}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TBinaryProtocol.h b/lib/cpp/src/protocol/TBinaryProtocol.h new file mode 100644 index 000000000..7fd3de673 --- /dev/null +++ b/lib/cpp/src/protocol/TBinaryProtocol.h @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_ 1 + +#include "TProtocol.h" + +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace protocol { + +/** + * The default binary protocol for thrift. Writes all data in a very basic + * binary format, essentially just spitting out the raw bytes. + * + */ +class TBinaryProtocol : public TProtocol { + protected: + static const int32_t VERSION_MASK = 0xffff0000; + static const int32_t VERSION_1 = 0x80010000; + // VERSION_2 (0x80020000) is taken by TDenseProtocol. + + public: + TBinaryProtocol(boost::shared_ptr<TTransport> trans) : + TProtocol(trans), + string_limit_(0), + container_limit_(0), + strict_read_(false), + strict_write_(true), + string_buf_(NULL), + string_buf_size_(0) {} + + TBinaryProtocol(boost::shared_ptr<TTransport> trans, + int32_t string_limit, + int32_t container_limit, + bool strict_read, + bool strict_write) : + TProtocol(trans), + string_limit_(string_limit), + container_limit_(container_limit), + strict_read_(strict_read), + strict_write_(strict_write), + string_buf_(NULL), + string_buf_size_(0) {} + + ~TBinaryProtocol() { + if (string_buf_ != NULL) { + std::free(string_buf_); + string_buf_size_ = 0; + } + } + + void setStringSizeLimit(int32_t string_limit) { + string_limit_ = string_limit; + } + + void setContainerSizeLimit(int32_t container_limit) { + container_limit_ = container_limit; + } + + void setStrict(bool strict_read, bool strict_write) { + strict_read_ = strict_read; + strict_write_ = strict_write; + } + + /** + * Writing functions. 
+ */ + + virtual uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid); + + virtual uint32_t writeMessageEnd(); + + + uint32_t writeStructBegin(const char* name); + + uint32_t writeStructEnd(); + + uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId); + + uint32_t writeFieldEnd(); + + uint32_t writeFieldStop(); + + uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size); + + uint32_t writeMapEnd(); + + uint32_t writeListBegin(const TType elemType, + const uint32_t size); + + uint32_t writeListEnd(); + + uint32_t writeSetBegin(const TType elemType, + const uint32_t size); + + uint32_t writeSetEnd(); + + uint32_t writeBool(const bool value); + + uint32_t writeByte(const int8_t byte); + + uint32_t writeI16(const int16_t i16); + + uint32_t writeI32(const int32_t i32); + + uint32_t writeI64(const int64_t i64); + + uint32_t writeDouble(const double dub); + + uint32_t writeString(const std::string& str); + + uint32_t writeBinary(const std::string& str); + + /** + * Reading functions + */ + + + uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid); + + uint32_t readMessageEnd(); + + uint32_t readStructBegin(std::string& name); + + uint32_t readStructEnd(); + + uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId); + + uint32_t readFieldEnd(); + + uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size); + + uint32_t readMapEnd(); + + uint32_t readListBegin(TType& elemType, + uint32_t& size); + + uint32_t readListEnd(); + + uint32_t readSetBegin(TType& elemType, + uint32_t& size); + + uint32_t readSetEnd(); + + uint32_t readBool(bool& value); + + uint32_t readByte(int8_t& byte); + + uint32_t readI16(int16_t& i16); + + uint32_t readI32(int32_t& i32); + + uint32_t readI64(int64_t& i64); + + uint32_t readDouble(double& dub); + + uint32_t readString(std::string& str); 
+ + uint32_t readBinary(std::string& str); + + protected: + uint32_t readStringBody(std::string& str, int32_t sz); + + int32_t string_limit_; + int32_t container_limit_; + + // Enforce presence of version identifier + bool strict_read_; + bool strict_write_; + + // Buffer for reading strings, save for the lifetime of the protocol to + // avoid memory churn allocating memory on every string read + uint8_t* string_buf_; + int32_t string_buf_size_; + +}; + +/** + * Constructs binary protocol handlers + */ +class TBinaryProtocolFactory : public TProtocolFactory { + public: + TBinaryProtocolFactory() : + string_limit_(0), + container_limit_(0), + strict_read_(false), + strict_write_(true) {} + + TBinaryProtocolFactory(int32_t string_limit, int32_t container_limit, bool strict_read, bool strict_write) : + string_limit_(string_limit), + container_limit_(container_limit), + strict_read_(strict_read), + strict_write_(strict_write) {} + + virtual ~TBinaryProtocolFactory() {} + + void setStringSizeLimit(int32_t string_limit) { + string_limit_ = string_limit; + } + + void setContainerSizeLimit(int32_t container_limit) { + container_limit_ = container_limit; + } + + void setStrict(bool strict_read, bool strict_write) { + strict_read_ = strict_read; + strict_write_ = strict_write; + } + + boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) { + return boost::shared_ptr<TProtocol>(new TBinaryProtocol(trans, string_limit_, container_limit_, strict_read_, strict_write_)); + } + + private: + int32_t string_limit_; + int32_t container_limit_; + bool strict_read_; + bool strict_write_; + +}; + +}}} // apache::thrift::protocol + +#endif // #ifndef _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_ diff --git a/lib/cpp/src/protocol/TCompactProtocol.cpp b/lib/cpp/src/protocol/TCompactProtocol.cpp new file mode 100644 index 000000000..ce2ee54d2 --- /dev/null +++ b/lib/cpp/src/protocol/TCompactProtocol.cpp @@ -0,0 +1,736 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TCompactProtocol.h" + +#include <config.h> +#include <limits> + +/* + * TCompactProtocol::i*ToZigzag depend on the fact that the right shift + * operator on a signed integer is an arithmetic (sign-extending) shift. + * If this is not the case, the current implementation will not work. + * If anyone encounters this error, we can try to figure out the best + * way to implement an arithmetic right shift on their platform. 
+ */ +#if !defined(SIGNED_RIGHT_SHIFT_IS) || !defined(ARITHMETIC_RIGHT_SHIFT) +# error "Unable to determine the behavior of a signed right shift" +#endif +#if SIGNED_RIGHT_SHIFT_IS != ARITHMETIC_RIGHT_SHIFT +# error "TCompactProtocol currenly only works if a signed right shift is arithmetic" +#endif + +#ifdef __GNUC__ +#define UNLIKELY(val) (__builtin_expect((val), 0)) +#else +#define UNLIKELY(val) (val) +#endif + +namespace apache { namespace thrift { namespace protocol { + +const int8_t TCompactProtocol::TTypeToCType[16] = { + CT_STOP, // T_STOP + 0, // unused + CT_BOOLEAN_TRUE, // T_BOOL + CT_BYTE, // T_BYTE + CT_DOUBLE, // T_DOUBLE + 0, // unused + CT_I16, // T_I16 + 0, // unused + CT_I32, // T_I32 + 0, // unused + CT_I64, // T_I64 + CT_BINARY, // T_STRING + CT_STRUCT, // T_STRUCT + CT_MAP, // T_MAP + CT_SET, // T_SET + CT_LIST, // T_LIST + }; + + +uint32_t TCompactProtocol::writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + uint32_t wsize = 0; + wsize += writeByte(PROTOCOL_ID); + wsize += writeByte((VERSION_N & VERSION_MASK) | (((int32_t)messageType << TYPE_SHIFT_AMOUNT) & TYPE_MASK)); + wsize += writeVarint32(seqid); + wsize += writeString(name); + return wsize; +} + +/** + * Write a field header containing the field id and field type. If the + * difference between the current field id and the last one is small (< 15), + * then the field id will be encoded in the 4 MSB as a delta. Otherwise, the + * field id will follow the type header as a zigzag varint. + */ +uint32_t TCompactProtocol::writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + if (fieldType == T_BOOL) { + booleanField_.name = name; + booleanField_.fieldType = fieldType; + booleanField_.fieldId = fieldId; + } else { + return writeFieldBeginInternal(name, fieldType, fieldId, -1); + } + return 0; +} + +/** + * Write the STOP symbol so we know there are no more fields in this struct. 
+ */ +uint32_t TCompactProtocol::writeFieldStop() { + return writeByte(T_STOP); +} + +/** + * Write a struct begin. This doesn't actually put anything on the wire. We + * use it as an opportunity to put special placeholder markers on the field + * stack so we can get the field id deltas correct. + */ +uint32_t TCompactProtocol::writeStructBegin(const char* name) { + lastField_.push(lastFieldId_); + lastFieldId_ = 0; + return 0; +} + +/** + * Write a struct end. This doesn't actually put anything on the wire. We use + * this as an opportunity to pop the last field from the current struct off + * of the field stack. + */ +uint32_t TCompactProtocol::writeStructEnd() { + lastFieldId_ = lastField_.top(); + lastField_.pop(); + return 0; +} + +/** + * Write a List header. + */ +uint32_t TCompactProtocol::writeListBegin(const TType elemType, + const uint32_t size) { + return writeCollectionBegin(elemType, size); +} + +/** + * Write a set header. + */ +uint32_t TCompactProtocol::writeSetBegin(const TType elemType, + const uint32_t size) { + return writeCollectionBegin(elemType, size); +} + +/** + * Write a map header. If the map is empty, omit the key and value type + * headers, as we don't need any additional information to skip it. + */ +uint32_t TCompactProtocol::writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + uint32_t wsize = 0; + + if (size == 0) { + wsize += writeByte(0); + } else { + wsize += writeVarint32(size); + wsize += writeByte(getCompactType(keyType) << 4 | getCompactType(valType)); + } + return wsize; +} + +/** + * Write a boolean value. Potentially, this could be a boolean field, in + * which case the field header info isn't written yet. If so, decide what the + * right type header is for the value and then write the field header. + * Otherwise, write a single byte. 
+ */ +uint32_t TCompactProtocol::writeBool(const bool value) { + uint32_t wsize = 0; + + if (booleanField_.name != NULL) { + // we haven't written the field header yet + wsize += writeFieldBeginInternal(booleanField_.name, + booleanField_.fieldType, + booleanField_.fieldId, + value ? CT_BOOLEAN_TRUE : CT_BOOLEAN_FALSE); + booleanField_.name = NULL; + } else { + // we're not part of a field, so just write the value + wsize += writeByte(value ? CT_BOOLEAN_TRUE : CT_BOOLEAN_FALSE); + } + return wsize; +} + +uint32_t TCompactProtocol::writeByte(const int8_t byte) { + trans_->write((uint8_t*)&byte, 1); + return 1; +} + +/** + * Write an i16 as a zigzag varint. + */ +uint32_t TCompactProtocol::writeI16(const int16_t i16) { + return writeVarint32(i32ToZigzag(i16)); +} + +/** + * Write an i32 as a zigzag varint. + */ +uint32_t TCompactProtocol::writeI32(const int32_t i32) { + return writeVarint32(i32ToZigzag(i32)); +} + +/** + * Write an i64 as a zigzag varint. + */ +uint32_t TCompactProtocol::writeI64(const int64_t i64) { + return writeVarint64(i64ToZigzag(i64)); +} + +/** + * Write a double to the wire as 8 bytes. + */ +uint32_t TCompactProtocol::writeDouble(const double dub) { + BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t)); + BOOST_STATIC_ASSERT(std::numeric_limits<double>::is_iec559); + + uint64_t bits = bitwise_cast<uint64_t>(dub); + bits = htolell(bits); + trans_->write((uint8_t*)&bits, 8); + return 8; +} + +/** + * Write a string to the wire with a varint size preceeding. + */ +uint32_t TCompactProtocol::writeString(const std::string& str) { + return writeBinary(str); +} + +uint32_t TCompactProtocol::writeBinary(const std::string& str) { + uint32_t ssize = str.size(); + uint32_t wsize = writeVarint32(ssize) + ssize; + trans_->write((uint8_t*)str.data(), ssize); + return wsize; +} + +// +// Internal Writing methods +// + +/** + * The workhorse of writeFieldBegin. It has the option of doing a + * 'type override' of the type header. 
This is used specifically in the + * boolean field case. + */ +int32_t TCompactProtocol::writeFieldBeginInternal(const char* name, + const TType fieldType, + const int16_t fieldId, + int8_t typeOverride) { + uint32_t wsize = 0; + + // if there's a type override, use that. + int8_t typeToWrite = (typeOverride == -1 ? getCompactType(fieldType) : typeOverride); + + // check if we can use delta encoding for the field id + if (fieldId > lastFieldId_ && fieldId - lastFieldId_ <= 15) { + // write them together + wsize += writeByte((fieldId - lastFieldId_) << 4 | typeToWrite); + } else { + // write them separate + wsize += writeByte(typeToWrite); + wsize += writeI16(fieldId); + } + + lastFieldId_ = fieldId; + return wsize; +} + +/** + * Abstract method for writing the start of lists and sets. List and sets on + * the wire differ only by the type indicator. + */ +uint32_t TCompactProtocol::writeCollectionBegin(int8_t elemType, int32_t size) { + uint32_t wsize = 0; + if (size <= 14) { + wsize += writeByte(size << 4 | getCompactType(elemType)); + } else { + wsize += writeByte(0xf0 | getCompactType(elemType)); + wsize += writeVarint32(size); + } + return wsize; +} + +/** + * Write an i32 as a varint. Results in 1-5 bytes on the wire. + */ +uint32_t TCompactProtocol::writeVarint32(uint32_t n) { + uint8_t buf[5]; + uint32_t wsize = 0; + + while (true) { + if ((n & ~0x7F) == 0) { + buf[wsize++] = (int8_t)n; + break; + } else { + buf[wsize++] = (int8_t)((n & 0x7F) | 0x80); + n >>= 7; + } + } + trans_->write(buf, wsize); + return wsize; +} + +/** + * Write an i64 as a varint. Results in 1-10 bytes on the wire. + */ +uint32_t TCompactProtocol::writeVarint64(uint64_t n) { + uint8_t buf[10]; + uint32_t wsize = 0; + + while (true) { + if ((n & ~0x7FL) == 0) { + buf[wsize++] = (int8_t)n; + break; + } else { + buf[wsize++] = (int8_t)((n & 0x7F) | 0x80); + n >>= 7; + } + } + trans_->write(buf, wsize); + return wsize; +} + +/** + * Convert l into a zigzag long. 
This allows negative numbers to be + * represented compactly as a varint. + */ +uint64_t TCompactProtocol::i64ToZigzag(const int64_t l) { + return (l << 1) ^ (l >> 63); +} + +/** + * Convert n into a zigzag int. This allows negative numbers to be + * represented compactly as a varint. + */ +uint32_t TCompactProtocol::i32ToZigzag(const int32_t n) { + return (n << 1) ^ (n >> 31); +} + +/** + * Given a TType value, find the appropriate TCompactProtocol.Type value + */ +int8_t TCompactProtocol::getCompactType(int8_t ttype) { + return TTypeToCType[ttype]; +} + +// +// Reading Methods +// + +/** + * Read a message header. + */ +uint32_t TCompactProtocol::readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + uint32_t rsize = 0; + int8_t protocolId; + int8_t versionAndType; + int8_t version; + + rsize += readByte(protocolId); + if (protocolId != PROTOCOL_ID) { + throw TProtocolException(TProtocolException::BAD_VERSION, "Bad protocol identifier"); + } + + rsize += readByte(versionAndType); + version = (int8_t)(versionAndType & VERSION_MASK); + if (version != VERSION_N) { + throw TProtocolException(TProtocolException::BAD_VERSION, "Bad protocol version"); + } + + messageType = (TMessageType)((versionAndType >> TYPE_SHIFT_AMOUNT) & 0x03); + rsize += readVarint32(seqid); + rsize += readString(name); + + return rsize; +} + +/** + * Read a struct begin. There's nothing on the wire for this, but it is our + * opportunity to push a new struct begin marker on the field stack. + */ +uint32_t TCompactProtocol::readStructBegin(std::string& name) { + name = ""; + lastField_.push(lastFieldId_); + lastFieldId_ = 0; + return 0; +} + +/** + * Doesn't actually consume any wire data, just removes the last field for + * this struct from the field stack. + */ +uint32_t TCompactProtocol::readStructEnd() { + lastFieldId_ = lastField_.top(); + lastField_.pop(); + return 0; +} + +/** + * Read a field header off the wire. 
+ */ +uint32_t TCompactProtocol::readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId) { + uint32_t rsize = 0; + int8_t byte; + int8_t type; + + rsize += readByte(byte); + type = (byte & 0x0f); + + // if it's a stop, then we can return immediately, as the struct is over. + if (type == T_STOP) { + fieldType = T_STOP; + fieldId = 0; + return rsize; + } + + // mask off the 4 MSB of the type header. it could contain a field id delta. + int16_t modifier = (int16_t)(((uint8_t)byte & 0xf0) >> 4); + if (modifier == 0) { + // not a delta, look ahead for the zigzag varint field id. + rsize += readI16(fieldId); + } else { + fieldId = (int16_t)(lastFieldId_ + modifier); + } + fieldType = getTType(type); + + // if this happens to be a boolean field, the value is encoded in the type + if (type == CT_BOOLEAN_TRUE || type == CT_BOOLEAN_FALSE) { + // save the boolean value in a special instance variable. + boolValue_.hasBoolValue = true; + boolValue_.boolValue = (type == CT_BOOLEAN_TRUE ? true : false); + } + + // push the new field onto the field stack so we can keep the deltas going. + lastFieldId_ = fieldId; + return rsize; +} + +/** + * Read a map header off the wire. If the size is zero, skip reading the key + * and value type. This means that 0-length maps will yield TMaps without the + * "correct" types. + */ +uint32_t TCompactProtocol::readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + uint32_t rsize = 0; + int8_t kvType = 0; + int32_t msize = 0; + + rsize += readVarint32(msize); + if (msize != 0) + rsize += readByte(kvType); + + if (msize < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && msize > container_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + + keyType = getTType((int8_t)((uint8_t)kvType >> 4)); + valType = getTType((int8_t)((uint8_t)kvType & 0xf)); + size = (uint32_t)msize; + + return rsize; +} + +/** + * Read a list header off the wire. 
If the list size is 0-14, the size will + * be packed into the element type header. If it's a longer list, the 4 MSB + * of the element type header will be 0xF, and a varint will follow with the + * true size. + */ +uint32_t TCompactProtocol::readListBegin(TType& elemType, + uint32_t& size) { + int8_t size_and_type; + uint32_t rsize = 0; + int32_t lsize; + + rsize += readByte(size_and_type); + + lsize = ((uint8_t)size_and_type >> 4) & 0x0f; + if (lsize == 15) { + rsize += readVarint32(lsize); + } + + if (lsize < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && lsize > container_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + + elemType = getTType((int8_t)(size_and_type & 0x0f)); + size = (uint32_t)lsize; + + return rsize; +} + +/** + * Read a set header off the wire. If the set size is 0-14, the size will + * be packed into the element type header. If it's a longer set, the 4 MSB + * of the element type header will be 0xF, and a varint will follow with the + * true size. + */ +uint32_t TCompactProtocol::readSetBegin(TType& elemType, + uint32_t& size) { + return readListBegin(elemType, size); +} + +/** + * Read a boolean off the wire. If this is a boolean field, the value should + * already have been read during readFieldBegin, so we'll just consume the + * pre-stored value. Otherwise, read a byte. + */ +uint32_t TCompactProtocol::readBool(bool& value) { + if (boolValue_.hasBoolValue == true) { + value = boolValue_.boolValue; + boolValue_.hasBoolValue = false; + return 0; + } else { + int8_t val; + readByte(val); + value = (val == CT_BOOLEAN_TRUE); + return 1; + } +} + +/** + * Read a single byte off the wire. Nothing interesting here. + */ +uint32_t TCompactProtocol::readByte(int8_t& byte) { + uint8_t b[1]; + trans_->readAll(b, 1); + byte = *(int8_t*)b; + return 1; +} + +/** + * Read an i16 from the wire as a zigzag varint. 
+ */ +uint32_t TCompactProtocol::readI16(int16_t& i16) { + int32_t value; + uint32_t rsize = readVarint32(value); + i16 = (int16_t)zigzagToI32(value); + return rsize; +} + +/** + * Read an i32 from the wire as a zigzag varint. + */ +uint32_t TCompactProtocol::readI32(int32_t& i32) { + int32_t value; + uint32_t rsize = readVarint32(value); + i32 = zigzagToI32(value); + return rsize; +} + +/** + * Read an i64 from the wire as a zigzag varint. + */ +uint32_t TCompactProtocol::readI64(int64_t& i64) { + int64_t value; + uint32_t rsize = readVarint64(value); + i64 = zigzagToI64(value); + return rsize; +} + +/** + * No magic here - just read a double off the wire. + */ +uint32_t TCompactProtocol::readDouble(double& dub) { + BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t)); + BOOST_STATIC_ASSERT(std::numeric_limits<double>::is_iec559); + + uint64_t bits; + uint8_t b[8]; + trans_->readAll(b, 8); + bits = *(uint64_t*)b; + bits = letohll(bits); + dub = bitwise_cast<double>(bits); + return 8; +} + +uint32_t TCompactProtocol::readString(std::string& str) { + return readBinary(str); +} + +/** + * Read a byte[] from the wire. + */ +uint32_t TCompactProtocol::readBinary(std::string& str) { + int32_t rsize = 0; + int32_t size; + + rsize += readVarint32(size); + // Catch empty string case + if (size == 0) { + str = ""; + return rsize; + } + + // Catch error cases + if (size < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } + if (string_limit_ > 0 && size > string_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + + // Use the heap here to prevent stack overflow for v. 
large strings + if (size > string_buf_size_ || string_buf_ == NULL) { + void* new_string_buf = std::realloc(string_buf_, (uint32_t)size); + if (new_string_buf == NULL) { + throw TProtocolException(TProtocolException::UNKNOWN, "Out of memory in TCompactProtocol::readString"); + } + string_buf_ = (uint8_t*)new_string_buf; + string_buf_size_ = size; + } + trans_->readAll(string_buf_, size); + str.assign((char*)string_buf_, size); + + return rsize + (uint32_t)size; +} + +/** + * Read an i32 from the wire as a varint. The MSB of each byte is set + * if there is another byte to follow. This can read up to 5 bytes. + */ +uint32_t TCompactProtocol::readVarint32(int32_t& i32) { + int64_t val; + uint32_t rsize = readVarint64(val); + i32 = (int32_t)val; + return rsize; +} + +/** + * Read an i64 from the wire as a proper varint. The MSB of each byte is set + * if there is another byte to follow. This can read up to 10 bytes. + */ +uint32_t TCompactProtocol::readVarint64(int64_t& i64) { + uint32_t rsize = 0; + uint64_t val = 0; + int shift = 0; + uint8_t buf[10]; // 64 bits / (7 bits/byte) = 10 bytes. + uint32_t buf_size = sizeof(buf); + const uint8_t* borrowed = trans_->borrow(buf, &buf_size); + + // Fast path. + if (borrowed != NULL) { + while (true) { + uint8_t byte = borrowed[rsize]; + rsize++; + val |= (uint64_t)(byte & 0x7f) << shift; + shift += 7; + if (!(byte & 0x80)) { + i64 = val; + trans_->consume(rsize); + return rsize; + } + // Have to check for invalid data so we don't crash. + if (UNLIKELY(rsize == sizeof(buf))) { + throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes."); + } + } + } + + // Slow path. + else { + while (true) { + uint8_t byte; + rsize += trans_->readAll(&byte, 1); + val |= (uint64_t)(byte & 0x7f) << shift; + shift += 7; + if (!(byte & 0x80)) { + i64 = val; + return rsize; + } + // Might as well check for invalid data on the slow path too. 
+ if (UNLIKELY(rsize >= sizeof(buf))) { + throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes."); + } + } + } +} + +/** + * Convert from zigzag int to int. + */ +int32_t TCompactProtocol::zigzagToI32(uint32_t n) { + return (n >> 1) ^ -(n & 1); +} + +/** + * Convert from zigzag long to long. + */ +int64_t TCompactProtocol::zigzagToI64(uint64_t n) { + return (n >> 1) ^ -(n & 1); +} + +TType TCompactProtocol::getTType(int8_t type) { + switch (type) { + case T_STOP: + return T_STOP; + case CT_BOOLEAN_FALSE: + case CT_BOOLEAN_TRUE: + return T_BOOL; + case CT_BYTE: + return T_BYTE; + case CT_I16: + return T_I16; + case CT_I32: + return T_I32; + case CT_I64: + return T_I64; + case CT_DOUBLE: + return T_DOUBLE; + case CT_BINARY: + return T_STRING; + case CT_LIST: + return T_LIST; + case CT_SET: + return T_SET; + case CT_MAP: + return T_MAP; + case CT_STRUCT: + return T_STRUCT; + default: + throw TException("don't know what type: " + type); + } + return T_STOP; +} + +}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TCompactProtocol.h b/lib/cpp/src/protocol/TCompactProtocol.h new file mode 100644 index 000000000..b4e06f0aa --- /dev/null +++ b/lib/cpp/src/protocol/TCompactProtocol.h @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TCOMPACTPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TCOMPACTPROTOCOL_H_ 1 + +#include "TProtocol.h" + +#include <stack> +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace protocol { + +/** + * C++ Implementation of the Compact Protocol as described in THRIFT-110 + */ +class TCompactProtocol : public TProtocol { + + protected: + static const int8_t PROTOCOL_ID = 0x82; + static const int8_t VERSION_N = 1; + static const int8_t VERSION_MASK = 0x1f; // 0001 1111 + static const int8_t TYPE_MASK = 0xE0; // 1110 0000 + static const int32_t TYPE_SHIFT_AMOUNT = 5; + + /** + * (Writing) If we encounter a boolean field begin, save the TField here + * so it can have the value incorporated. + */ + struct { + const char* name; + TType fieldType; + int16_t fieldId; + } booleanField_; + + /** + * (Reading) If we read a field header, and it's a boolean field, save + * the boolean value here so that readBool can use it. + */ + struct { + bool hasBoolValue; + bool boolValue; + } boolValue_; + + /** + * Used to keep track of the last field for the current and previous structs, + * so we can do the delta stuff. 
+ */ + + std::stack<int16_t> lastField_; + int16_t lastFieldId_; + + enum Types { + CT_STOP = 0x00, + CT_BOOLEAN_TRUE = 0x01, + CT_BOOLEAN_FALSE = 0x02, + CT_BYTE = 0x03, + CT_I16 = 0x04, + CT_I32 = 0x05, + CT_I64 = 0x06, + CT_DOUBLE = 0x07, + CT_BINARY = 0x08, + CT_LIST = 0x09, + CT_SET = 0x0A, + CT_MAP = 0x0B, + CT_STRUCT = 0x0C, + }; + + static const int8_t TTypeToCType[16]; + + public: + TCompactProtocol(boost::shared_ptr<TTransport> trans) : + TProtocol(trans), + lastFieldId_(0), + string_limit_(0), + string_buf_(NULL), + string_buf_size_(0), + container_limit_(0) { + booleanField_.name = NULL; + boolValue_.hasBoolValue = false; + } + + TCompactProtocol(boost::shared_ptr<TTransport> trans, + int32_t string_limit, + int32_t container_limit) : + TProtocol(trans), + lastFieldId_(0), + string_limit_(string_limit), + string_buf_(NULL), + string_buf_size_(0), + container_limit_(container_limit) { + booleanField_.name = NULL; + boolValue_.hasBoolValue = false; + } + + + + /** + * Writing functions + */ + + virtual uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid); + + uint32_t writeStructBegin(const char* name); + + uint32_t writeStructEnd(); + + uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId); + + uint32_t writeFieldStop(); + + uint32_t writeListBegin(const TType elemType, + const uint32_t size); + + uint32_t writeSetBegin(const TType elemType, + const uint32_t size); + + virtual uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size); + + uint32_t writeBool(const bool value); + + uint32_t writeByte(const int8_t byte); + + uint32_t writeI16(const int16_t i16); + + uint32_t writeI32(const int32_t i32); + + uint32_t writeI64(const int64_t i64); + + uint32_t writeDouble(const double dub); + + uint32_t writeString(const std::string& str); + + uint32_t writeBinary(const std::string& str); + + /** + * These methods are called by 
structs, but don't actually have any wired + * output or purpose + */ + virtual uint32_t writeMessageEnd() { return 0; } + uint32_t writeMapEnd() { return 0; } + uint32_t writeListEnd() { return 0; } + uint32_t writeSetEnd() { return 0; } + uint32_t writeFieldEnd() { return 0; } + + protected: + int32_t writeFieldBeginInternal(const char* name, + const TType fieldType, + const int16_t fieldId, + int8_t typeOverride); + uint32_t writeCollectionBegin(int8_t elemType, int32_t size); + uint32_t writeVarint32(uint32_t n); + uint32_t writeVarint64(uint64_t n); + uint64_t i64ToZigzag(const int64_t l); + uint32_t i32ToZigzag(const int32_t n); + inline int8_t getCompactType(int8_t ttype); + + public: + uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid); + + uint32_t readStructBegin(std::string& name); + + uint32_t readStructEnd(); + + uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId); + + uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size); + + uint32_t readListBegin(TType& elemType, + uint32_t& size); + + uint32_t readSetBegin(TType& elemType, + uint32_t& size); + + uint32_t readBool(bool& value); + + uint32_t readByte(int8_t& byte); + + uint32_t readI16(int16_t& i16); + + uint32_t readI32(int32_t& i32); + + uint32_t readI64(int64_t& i64); + + uint32_t readDouble(double& dub); + + uint32_t readString(std::string& str); + + uint32_t readBinary(std::string& str); + + /* + *These methods are here for the struct to call, but don't have any wire + * encoding. 
+ */ + uint32_t readMessageEnd() { return 0; } + uint32_t readFieldEnd() { return 0; } + uint32_t readMapEnd() { return 0; } + uint32_t readListEnd() { return 0; } + uint32_t readSetEnd() { return 0; } + + protected: + uint32_t readVarint32(int32_t& i32); + uint32_t readVarint64(int64_t& i64); + int32_t zigzagToI32(uint32_t n); + int64_t zigzagToI64(uint64_t n); + TType getTType(int8_t type); + + // Buffer for reading strings, save for the lifetime of the protocol to + // avoid memory churn allocating memory on every string read + int32_t string_limit_; + uint8_t* string_buf_; + int32_t string_buf_size_; + int32_t container_limit_; +}; + +/** + * Constructs compact protocol handlers + */ +class TCompactProtocolFactory : public TProtocolFactory { + public: + TCompactProtocolFactory() : + string_limit_(0), + container_limit_(0) {} + + TCompactProtocolFactory(int32_t string_limit, int32_t container_limit) : + string_limit_(string_limit), + container_limit_(container_limit) {} + + virtual ~TCompactProtocolFactory() {} + + void setStringSizeLimit(int32_t string_limit) { + string_limit_ = string_limit; + } + + void setContainerSizeLimit(int32_t container_limit) { + container_limit_ = container_limit; + } + + boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) { + return boost::shared_ptr<TProtocol>(new TCompactProtocol(trans, string_limit_, container_limit_)); + } + + private: + int32_t string_limit_; + int32_t container_limit_; + +}; + +}}} // apache::thrift::protocol + +#endif diff --git a/lib/cpp/src/protocol/TDebugProtocol.cpp b/lib/cpp/src/protocol/TDebugProtocol.cpp new file mode 100644 index 000000000..40aa36bad --- /dev/null +++ b/lib/cpp/src/protocol/TDebugProtocol.cpp @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TDebugProtocol.h" + +#include <cassert> +#include <cctype> +#include <cstdio> +#include <stdexcept> +#include <boost/static_assert.hpp> +#include <boost/lexical_cast.hpp> + +using std::string; + + +static string byte_to_hex(const uint8_t byte) { + char buf[3]; + int ret = std::sprintf(buf, "%02x", (int)byte); + assert(ret == 2); + assert(buf[2] == '\0'); + return buf; +} + + +namespace apache { namespace thrift { namespace protocol { + +string TDebugProtocol::fieldTypeName(TType type) { + switch (type) { + case T_STOP : return "stop" ; + case T_VOID : return "void" ; + case T_BOOL : return "bool" ; + case T_BYTE : return "byte" ; + case T_I16 : return "i16" ; + case T_I32 : return "i32" ; + case T_U64 : return "u64" ; + case T_I64 : return "i64" ; + case T_DOUBLE : return "double" ; + case T_STRING : return "string" ; + case T_STRUCT : return "struct" ; + case T_MAP : return "map" ; + case T_SET : return "set" ; + case T_LIST : return "list" ; + case T_UTF8 : return "utf8" ; + case T_UTF16 : return "utf16" ; + default: return "unknown"; + } +} + +void TDebugProtocol::indentUp() { + indent_str_ += string(indent_inc, ' '); +} + +void TDebugProtocol::indentDown() { + if (indent_str_.length() < (string::size_type)indent_inc) { + throw TProtocolException(TProtocolException::INVALID_DATA); + } + indent_str_.erase(indent_str_.length() - indent_inc); +} + +uint32_t 
TDebugProtocol::writePlain(const string& str) { + trans_->write((uint8_t*)str.data(), str.length()); + return str.length(); +} + +uint32_t TDebugProtocol::writeIndented(const string& str) { + trans_->write((uint8_t*)indent_str_.data(), indent_str_.length()); + trans_->write((uint8_t*)str.data(), str.length()); + return indent_str_.length() + str.length(); +} + +uint32_t TDebugProtocol::startItem() { + uint32_t size; + + switch (write_state_.back()) { + case UNINIT: + // XXX figure out what to do here. + //throw TProtocolException(TProtocolException::INVALID_DATA); + //return writeIndented(str); + return 0; + case STRUCT: + return 0; + case SET: + return writeIndented(""); + case MAP_KEY: + return writeIndented(""); + case MAP_VALUE: + return writePlain(" -> "); + case LIST: + size = writeIndented( + "[" + boost::lexical_cast<string>(list_idx_.back()) + "] = "); + list_idx_.back()++; + return size; + default: + throw std::logic_error("Invalid enum value."); + } +} + +uint32_t TDebugProtocol::endItem() { + //uint32_t size; + + switch (write_state_.back()) { + case UNINIT: + // XXX figure out what to do here. 
+ //throw TProtocolException(TProtocolException::INVALID_DATA); + //return writeIndented(str); + return 0; + case STRUCT: + return writePlain(",\n"); + case SET: + return writePlain(",\n"); + case MAP_KEY: + write_state_.back() = MAP_VALUE; + return 0; + case MAP_VALUE: + write_state_.back() = MAP_KEY; + return writePlain(",\n"); + case LIST: + return writePlain(",\n"); + default: + throw std::logic_error("Invalid enum value."); + } +} + +uint32_t TDebugProtocol::writeItem(const std::string& str) { + uint32_t size = 0; + size += startItem(); + size += writePlain(str); + size += endItem(); + return size; +} + +uint32_t TDebugProtocol::writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + string mtype; + switch (messageType) { + case T_CALL : mtype = "call" ; break; + case T_REPLY : mtype = "reply" ; break; + case T_EXCEPTION : mtype = "exn" ; break; + } + + uint32_t size = writeIndented("(" + mtype + ") " + name + "("); + indentUp(); + return size; +} + +uint32_t TDebugProtocol::writeMessageEnd() { + indentDown(); + return writeIndented(")\n"); +} + +uint32_t TDebugProtocol::writeStructBegin(const char* name) { + uint32_t size = 0; + size += startItem(); + size += writePlain(string(name) + " {\n"); + indentUp(); + write_state_.push_back(STRUCT); + return size; +} + +uint32_t TDebugProtocol::writeStructEnd() { + indentDown(); + write_state_.pop_back(); + uint32_t size = 0; + size += writeIndented("}"); + size += endItem(); + return size; +} + +uint32_t TDebugProtocol::writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + // sprintf(id_str, "%02d", fieldId); + string id_str = boost::lexical_cast<string>(fieldId); + if (id_str.length() == 1) id_str = '0' + id_str; + + return writeIndented( + id_str + ": " + + name + " (" + + fieldTypeName(fieldType) + ") = "); +} + +uint32_t TDebugProtocol::writeFieldEnd() { + assert(write_state_.back() == STRUCT); + return 0; +} + +uint32_t 
TDebugProtocol::writeFieldStop() { + return 0; + //writeIndented("***STOP***\n"); +} + +uint32_t TDebugProtocol::writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + // TODO(dreiss): Optimize short maps? + uint32_t bsize = 0; + bsize += startItem(); + bsize += writePlain( + "map<" + fieldTypeName(keyType) + "," + fieldTypeName(valType) + ">" + "[" + boost::lexical_cast<string>(size) + "] {\n"); + indentUp(); + write_state_.push_back(MAP_KEY); + return bsize; +} + +uint32_t TDebugProtocol::writeMapEnd() { + indentDown(); + write_state_.pop_back(); + uint32_t size = 0; + size += writeIndented("}"); + size += endItem(); + return size; +} + +uint32_t TDebugProtocol::writeListBegin(const TType elemType, + const uint32_t size) { + // TODO(dreiss): Optimize short arrays. + uint32_t bsize = 0; + bsize += startItem(); + bsize += writePlain( + "list<" + fieldTypeName(elemType) + ">" + "[" + boost::lexical_cast<string>(size) + "] {\n"); + indentUp(); + write_state_.push_back(LIST); + list_idx_.push_back(0); + return bsize; +} + +uint32_t TDebugProtocol::writeListEnd() { + indentDown(); + write_state_.pop_back(); + list_idx_.pop_back(); + uint32_t size = 0; + size += writeIndented("}"); + size += endItem(); + return size; +} + +uint32_t TDebugProtocol::writeSetBegin(const TType elemType, + const uint32_t size) { + // TODO(dreiss): Optimize short sets. + uint32_t bsize = 0; + bsize += startItem(); + bsize += writePlain( + "set<" + fieldTypeName(elemType) + ">" + "[" + boost::lexical_cast<string>(size) + "] {\n"); + indentUp(); + write_state_.push_back(SET); + return bsize; +} + +uint32_t TDebugProtocol::writeSetEnd() { + indentDown(); + write_state_.pop_back(); + uint32_t size = 0; + size += writeIndented("}"); + size += endItem(); + return size; +} + +uint32_t TDebugProtocol::writeBool(const bool value) { + return writeItem(value ? 
"true" : "false"); +} + +uint32_t TDebugProtocol::writeByte(const int8_t byte) { + return writeItem("0x" + byte_to_hex(byte)); +} + +uint32_t TDebugProtocol::writeI16(const int16_t i16) { + return writeItem(boost::lexical_cast<string>(i16)); +} + +uint32_t TDebugProtocol::writeI32(const int32_t i32) { + return writeItem(boost::lexical_cast<string>(i32)); +} + +uint32_t TDebugProtocol::writeI64(const int64_t i64) { + return writeItem(boost::lexical_cast<string>(i64)); +} + +uint32_t TDebugProtocol::writeDouble(const double dub) { + return writeItem(boost::lexical_cast<string>(dub)); +} + + +uint32_t TDebugProtocol::writeString(const string& str) { + // XXX Raw/UTF-8? + + string to_show = str; + if (to_show.length() > (string::size_type)string_limit_) { + to_show = str.substr(0, string_prefix_size_); + to_show += "[...](" + boost::lexical_cast<string>(str.length()) + ")"; + } + + string output = "\""; + + for (string::const_iterator it = to_show.begin(); it != to_show.end(); ++it) { + if (*it == '\\') { + output += "\\\\"; + } else if (*it == '"') { + output += "\\\""; + } else if (std::isprint(*it)) { + output += *it; + } else { + switch (*it) { + case '\a': output += "\\a"; break; + case '\b': output += "\\b"; break; + case '\f': output += "\\f"; break; + case '\n': output += "\\n"; break; + case '\r': output += "\\r"; break; + case '\t': output += "\\t"; break; + case '\v': output += "\\v"; break; + default: + output += "\\x"; + output += byte_to_hex(*it); + } + } + } + + output += '\"'; + return writeItem(output); +} + +uint32_t TDebugProtocol::writeBinary(const string& str) { + // XXX Hex? 
+ return TDebugProtocol::writeString(str); +} + +}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TDebugProtocol.h b/lib/cpp/src/protocol/TDebugProtocol.h new file mode 100644 index 000000000..ab69e0ca5 --- /dev/null +++ b/lib/cpp/src/protocol/TDebugProtocol.h @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TDEBUGPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TDEBUGPROTOCOL_H_ 1 + +#include "TProtocol.h" +#include "TOneWayProtocol.h" + +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace protocol { + +/* + +!!! EXPERIMENTAL CODE !!! + +This protocol is very much a work in progress. +It doesn't handle many cases properly. +It throws exceptions in many cases. +It probably segfaults in many cases. +Bug reports and feature requests are welcome. +Complaints are not. :R + +*/ + + +/** + * Protocol that prints the payload in a nice human-readable format. + * Reading from this protocol is not supported. 
+ * + */ +class TDebugProtocol : public TWriteOnlyProtocol { + private: + enum write_state_t + { UNINIT + , STRUCT + , LIST + , SET + , MAP_KEY + , MAP_VALUE + }; + + public: + TDebugProtocol(boost::shared_ptr<TTransport> trans) + : TWriteOnlyProtocol(trans, "TDebugProtocol") + , string_limit_(DEFAULT_STRING_LIMIT) + , string_prefix_size_(DEFAULT_STRING_PREFIX_SIZE) + { + write_state_.push_back(UNINIT); + } + + static const int32_t DEFAULT_STRING_LIMIT = 256; + static const int32_t DEFAULT_STRING_PREFIX_SIZE = 16; + + void setStringSizeLimit(int32_t string_limit) { + string_limit_ = string_limit; + } + + void setStringPrefixSize(int32_t string_prefix_size) { + string_prefix_size_ = string_prefix_size; + } + + + virtual uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid); + + virtual uint32_t writeMessageEnd(); + + + uint32_t writeStructBegin(const char* name); + + uint32_t writeStructEnd(); + + uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId); + + uint32_t writeFieldEnd(); + + uint32_t writeFieldStop(); + + uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size); + + uint32_t writeMapEnd(); + + uint32_t writeListBegin(const TType elemType, + const uint32_t size); + + uint32_t writeListEnd(); + + uint32_t writeSetBegin(const TType elemType, + const uint32_t size); + + uint32_t writeSetEnd(); + + uint32_t writeBool(const bool value); + + uint32_t writeByte(const int8_t byte); + + uint32_t writeI16(const int16_t i16); + + uint32_t writeI32(const int32_t i32); + + uint32_t writeI64(const int64_t i64); + + uint32_t writeDouble(const double dub); + + uint32_t writeString(const std::string& str); + + uint32_t writeBinary(const std::string& str); + + + private: + void indentUp(); + void indentDown(); + uint32_t writePlain(const std::string& str); + uint32_t writeIndented(const std::string& str); + uint32_t startItem(); + uint32_t endItem(); 
+ uint32_t writeItem(const std::string& str); + + static std::string fieldTypeName(TType type); + + int32_t string_limit_; + int32_t string_prefix_size_; + + std::string indent_str_; + static const int indent_inc = 2; + + std::vector<write_state_t> write_state_; + std::vector<int> list_idx_; +}; + +/** + * Constructs debug protocol handlers + */ +class TDebugProtocolFactory : public TProtocolFactory { + public: + TDebugProtocolFactory() {} + virtual ~TDebugProtocolFactory() {} + + boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) { + return boost::shared_ptr<TProtocol>(new TDebugProtocol(trans)); + } + +}; + +}}} // apache::thrift::protocol + + +// TODO(dreiss): Move (part of) ThriftDebugString into a .cpp file and remove this. +#include <transport/TBufferTransports.h> + +namespace apache { namespace thrift { + +template<typename ThriftStruct> +std::string ThriftDebugString(const ThriftStruct& ts) { + using namespace apache::thrift::transport; + using namespace apache::thrift::protocol; + TMemoryBuffer* buffer = new TMemoryBuffer; + boost::shared_ptr<TTransport> trans(buffer); + TDebugProtocol protocol(trans); + + ts.write(&protocol); + + uint8_t* buf; + uint32_t size; + buffer->getBuffer(&buf, &size); + return std::string((char*)buf, (unsigned int)size); +} + +// TODO(dreiss): This is badly broken. Don't use it unless you are me. +#if 0 +template<typename Object> +std::string DebugString(const std::vector<Object>& vec) { + using namespace apache::thrift::transport; + using namespace apache::thrift::protocol; + TMemoryBuffer* buffer = new TMemoryBuffer; + boost::shared_ptr<TTransport> trans(buffer); + TDebugProtocol protocol(trans); + + // I am gross! + protocol.writeStructBegin("SomeRandomVector"); + + // TODO: Fix this with a trait. 
+ protocol.writeListBegin((TType)99, vec.size()); + typename std::vector<Object>::const_iterator it; + for (it = vec.begin(); it != vec.end(); ++it) { + it->write(&protocol); + } + protocol.writeListEnd(); + + uint8_t* buf; + uint32_t size; + buffer->getBuffer(&buf, &size); + return std::string((char*)buf, (unsigned int)size); +} +#endif // 0 + +}} // apache::thrift + + +#endif // #ifndef _THRIFT_PROTOCOL_TDEBUGPROTOCOL_H_ + + diff --git a/lib/cpp/src/protocol/TDenseProtocol.cpp b/lib/cpp/src/protocol/TDenseProtocol.cpp new file mode 100644 index 000000000..8e76dc479 --- /dev/null +++ b/lib/cpp/src/protocol/TDenseProtocol.cpp @@ -0,0 +1,762 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + +IMPLEMENTATION DETAILS + +TDenseProtocol was designed to have a smaller serialized form than +TBinaryProtocol. This is accomplished using two techniques. The first is +variable-length integer encoding. We use the same technique that the Standard +MIDI File format uses for "variable-length quantities" +(http://en.wikipedia.org/wiki/Variable-length_quantity). +All integers (including i16, but not byte) are first cast to uint64_t, +then written out as variable-length quantities. 
This has the unfortunate side +effect that all negative numbers require 10 bytes, but negative numbers tend +to be far less common than positive ones. + +The second technique is eliminating the field ids used by TBinaryProtocol. This +decision required support from the Thrift compiler and also sacrifices some of +the backward and forward compatibility of TBinaryProtocol. + +We considered implementing this technique by generating separate readers and +writers for the dense protocol (this is how Pillar, Thrift's predecessor, +worked), but this idea had a few problems: +- Our abstractions go out the window. +- We would have to maintain a second code generator. +- Preserving compatibility with old versions of the structures would be a + nightmare. + +Therefore, we chose an alternate implementation that stored the description of +the data neither in the data itself (like TBinaryProtocol) nor in the +serialization code (like Pillar), but instead in a separate data structure, +called a TypeSpec. TypeSpecs are generated by the Thrift compiler +(specifically in the t_cpp_generator), and their structure should be +documented there (TODO(dreiss): s/should be/is/). + +We maintain a stack of TypeSpecs within the protocol so it knows where the +generated code is in the reading/writing process. For example, if we are +writing an i32 contained in a struct bar, contained in a struct foo, then the +stack would look like: TOP , i32 , struct bar , struct foo , BOTTOM. +We maintain the following invariant: whenever we are about to read/write an object +(structBegin, containerBegin, or a scalar), the TypeSpec on the top of the +stack must match the type being read/written. The main reason that this +invariant must be maintained is that if we ever start reading a structure, we +must have its exact TypeSpec in order to pass the right tags to the +deserializer. 
+ +We use the following strategies for maintaining this invariant: + +- For structures, we have a separate stack of indexes, one for each structure + on the TypeSpec stack. These are indexes into the list of fields in the + structure's TypeSpec. When we {read,write}FieldBegin, we push on the + TypeSpec for the field. +- When we begin writing a list or set, we push on the TypeSpec for the + element type. +- For maps, we have a separate stack of booleans, one for each map on the + TypeSpec stack. The boolean is true if we are writing the key for that + map, and false if we are writing the value. Maps are the trickiest case + because the generated code does not call any protocol method between + the key and the value. As a result, we potentially have to switch + between map key state and map value state after reading/writing any object. +- This job is handled by the stateTransition method. It is called after + reading/writing every object. It pops the current TypeSpec off the stack, + then optionally pushes a new one on, depending on what the next TypeSpec is. + If it is a struct, the job is left to the next writeFieldBegin. If it is a + set or list, the just-popped typespec is pushed back on. If it is a map, + the top of the key/value stack is toggled, and the appropriate TypeSpec + is pushed. + +Optional fields are a little tricky also. We write a zero byte if they are +absent and prefix them with an 0x01 byte if they are present +*/ + +#define __STDC_LIMIT_MACROS +#include <stdint.h> +#include "TDenseProtocol.h" +#include "TReflectionLocal.h" + +// Leaving this on for now. Disabling it will turn off asserts, which should +// give a performance boost. When we have *really* thorough test cases, +// we should drop this. +#define DEBUG_TDENSEPROTOCOL + +// NOTE: Assertions should *only* be used to detect bugs in code, +// either in TDenseProtocol itself, or in code using it. +// (For example, using the wrong TypeSpec.) 
+// Invalid data should NEVER cause an assertion failure, +// no matter how grossly corrupted, nor how ingeniously crafted. +#ifdef DEBUG_TDENSEPROTOCOL +#undef NDEBUG +#else +#define NDEBUG +#endif +#include <cassert> + +using std::string; + +#ifdef __GNUC__ +#define UNLIKELY(val) (__builtin_expect((val), 0)) +#else +#define UNLIKELY(val) (val) +#endif + +namespace apache { namespace thrift { namespace protocol { + +const int TDenseProtocol::FP_PREFIX_LEN = + apache::thrift::reflection::local::FP_PREFIX_LEN; + +// Top TypeSpec. TypeSpec of the structure being encoded. +#define TTS (ts_stack_.back()) // type = TypeSpec* +// InDeX. Index into TTS of the current/next field to encode. +#define IDX (idx_stack_.back()) // type = int +// Field TypeSpec. TypeSpec of the current/next field to encode. +#define FTS (TTS->tstruct.specs[IDX]) // type = TypeSpec* +// Field MeTa. Metadata of the current/next field to encode. +#define FMT (TTS->tstruct.metas[IDX]) // type = FieldMeta +// SubType 1/2. TypeSpec of the first/second subtype of this container. +#define ST1 (TTS->tcontainer.subtype1) +#define ST2 (TTS->tcontainer.subtype2) + + +/** + * Checks that @c ttype is indeed the ttype that we should be writing, + * according to our typespec. Aborts if the test fails and debugging in on. + */ +inline void TDenseProtocol::checkTType(const TType ttype) { + assert(!ts_stack_.empty()); + assert(TTS->ttype == ttype); +} + +/** + * Makes sure that the TypeSpec stack is correct for the next object. + * See top-of-file comments. + */ +inline void TDenseProtocol::stateTransition() { + TypeSpec* old_tts = ts_stack_.back(); + ts_stack_.pop_back(); + + // If this is the end of the top-level write, we should have just popped + // the TypeSpec passed to the constructor. 
+ if (ts_stack_.empty()) { + // Compare, do not assign: the TypeSpec just popped must be the one held + // in type_spec_. + assert(old_tts == type_spec_); + return; + } + + switch (TTS->ttype) { + + case T_STRUCT: + assert(old_tts == FTS); + break; + + case T_LIST: + case T_SET: + assert(old_tts == ST1); + ts_stack_.push_back(old_tts); + break; + + case T_MAP: + assert(old_tts == (mkv_stack_.back() ? ST1 : ST2)); + mkv_stack_.back() = !mkv_stack_.back(); + ts_stack_.push_back(mkv_stack_.back() ? ST1 : ST2); + break; + + default: + assert(!"Invalid TType in stateTransition."); + break; + + } +} + + +/* + * Variable-length quantity functions. + */ + +inline uint32_t TDenseProtocol::vlqRead(uint64_t& vlq) { + uint32_t used = 0; + uint64_t val = 0; + uint8_t buf[10]; // 64 bits / (7 bits/byte) = 10 bytes. + uint32_t buf_size = sizeof(buf); + const uint8_t* borrowed = trans_->borrow(buf, &buf_size); + + // Fast path. TODO(dreiss): Make it faster. + if (borrowed != NULL) { + while (true) { + uint8_t byte = borrowed[used]; + used++; + val = (val << 7) | (byte & 0x7f); + if (!(byte & 0x80)) { + vlq = val; + trans_->consume(used); + return used; + } + // Have to check for invalid data so we don't crash. + if (UNLIKELY(used == sizeof(buf))) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes."); + } + } + } + + // Slow path. + else { + while (true) { + uint8_t byte; + used += trans_->readAll(&byte, 1); + val = (val << 7) | (byte & 0x7f); + if (!(byte & 0x80)) { + vlq = val; + return used; + } + // Might as well check for invalid data on the slow path too. + if (UNLIKELY(used >= sizeof(buf))) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes."); + } + } + } +} + +inline uint32_t TDenseProtocol::vlqWrite(uint64_t vlq) { + uint8_t buf[10]; // 64 bits / (7 bits/byte) = 10 bytes. + int32_t pos = sizeof(buf) - 1; + + // Write the thing from back to front. 
+ buf[pos] = vlq & 0x7f; + vlq >>= 7; + pos--; + + while (vlq > 0) { + assert(pos >= 0); + buf[pos] = (vlq | 0x80); + vlq >>= 7; + pos--; + } + + // Back up one step before writing. + pos++; + + trans_->write(buf+pos, sizeof(buf) - pos); + return sizeof(buf) - pos; +} + + + +/* + * Writing functions. + */ + +uint32_t TDenseProtocol::writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + throw TApplicationException("TDenseProtocol doesn't work with messages (yet)."); + + int32_t version = (VERSION_2) | ((int32_t)messageType); + uint32_t wsize = 0; + wsize += subWriteI32(version); + wsize += subWriteString(name); + wsize += subWriteI32(seqid); + return wsize; +} + +uint32_t TDenseProtocol::writeMessageEnd() { + return 0; +} + +/** + * Begin writing a struct. For the top-level struct (empty TypeSpec stack) + * this pushes type_spec_ and writes FP_PREFIX_LEN bytes of its fingerprint + * prefix; in every case it pushes a fresh field index. Returns the number + * of bytes written. + */ +uint32_t TDenseProtocol::writeStructBegin(const char* name) { + uint32_t xfer = 0; + + // The TypeSpec stack should be empty if this is the top-level read/write. + // If it is, we push the TypeSpec passed to the constructor. + if (ts_stack_.empty()) { + assert(standalone_); + + if (type_spec_ == NULL) { + resetState(); + throw TApplicationException("TDenseProtocol: No type specified."); + } else { + assert(type_spec_->ttype == T_STRUCT); + ts_stack_.push_back(type_spec_); + // Write out a prefix of the structure fingerprint. + trans_->write(type_spec_->fp_prefix, FP_PREFIX_LEN); + xfer += FP_PREFIX_LEN; + } + } + + // We need a new field index for this structure. + idx_stack_.push_back(0); + // Report the fingerprint-prefix bytes counted above instead of dropping them. + return xfer; +} + +uint32_t TDenseProtocol::writeStructEnd() { + idx_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + uint32_t xfer = 0; + + // Skip over optional fields. + while (FMT.tag != fieldId) { + // TODO(dreiss): Old meta here. + assert(FTS->ttype != T_STOP); + assert(FMT.is_optional); + // Write a zero byte so the reader can skip it. 
+ xfer += subWriteBool(false); + // And advance to the next field. + IDX++; + } + + // TODO(dreiss): give a better exception. + assert(FTS->ttype == fieldType); + + if (FMT.is_optional) { + subWriteBool(true); + xfer += 1; + } + + // writeFieldStop shares all lot of logic up to this point. + // Instead of replicating it all, we just call this method from that one + // and use a gross special case here. + if (UNLIKELY(FTS->ttype != T_STOP)) { + // For normal fields, push the TypeSpec that we're about to use. + ts_stack_.push_back(FTS); + } + return xfer; +} + +uint32_t TDenseProtocol::writeFieldEnd() { + // Just move on to the next field. + IDX++; + return 0; +} + +uint32_t TDenseProtocol::writeFieldStop() { + return TDenseProtocol::writeFieldBegin("", T_STOP, 0); +} + +uint32_t TDenseProtocol::writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + checkTType(T_MAP); + + assert(keyType == ST1->ttype); + assert(valType == ST2->ttype); + + ts_stack_.push_back(ST1); + mkv_stack_.push_back(true); + + return subWriteI32((int32_t)size); +} + +uint32_t TDenseProtocol::writeMapEnd() { + // Pop off the value type, as well as our entry in the map key/value stack. + // stateTransition takes care of popping off our TypeSpec. + ts_stack_.pop_back(); + mkv_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::writeListBegin(const TType elemType, + const uint32_t size) { + checkTType(T_LIST); + + assert(elemType == ST1->ttype); + ts_stack_.push_back(ST1); + return subWriteI32((int32_t)size); +} + +uint32_t TDenseProtocol::writeListEnd() { + // Pop off the element type. stateTransition takes care of popping off ours. 
+ ts_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::writeSetBegin(const TType elemType, + const uint32_t size) { + checkTType(T_SET); + + assert(elemType == ST1->ttype); + ts_stack_.push_back(ST1); + return subWriteI32((int32_t)size); +} + +uint32_t TDenseProtocol::writeSetEnd() { + // Pop off the element type. stateTransition takes care of popping off ours. + ts_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::writeBool(const bool value) { + checkTType(T_BOOL); + stateTransition(); + return TBinaryProtocol::writeBool(value); +} + +uint32_t TDenseProtocol::writeByte(const int8_t byte) { + checkTType(T_BYTE); + stateTransition(); + return TBinaryProtocol::writeByte(byte); +} + +uint32_t TDenseProtocol::writeI16(const int16_t i16) { + checkTType(T_I16); + stateTransition(); + return vlqWrite(i16); +} + +uint32_t TDenseProtocol::writeI32(const int32_t i32) { + checkTType(T_I32); + stateTransition(); + return vlqWrite(i32); +} + +uint32_t TDenseProtocol::writeI64(const int64_t i64) { + checkTType(T_I64); + stateTransition(); + return vlqWrite(i64); +} + +uint32_t TDenseProtocol::writeDouble(const double dub) { + checkTType(T_DOUBLE); + stateTransition(); + return TBinaryProtocol::writeDouble(dub); +} + +uint32_t TDenseProtocol::writeString(const std::string& str) { + checkTType(T_STRING); + stateTransition(); + return subWriteString(str); +} + +uint32_t TDenseProtocol::writeBinary(const std::string& str) { + return TDenseProtocol::writeString(str); +} + +inline uint32_t TDenseProtocol::subWriteI32(const int32_t i32) { + return vlqWrite(i32); +} + +uint32_t TDenseProtocol::subWriteString(const std::string& str) { + uint32_t size = str.size(); + uint32_t xfer = subWriteI32((int32_t)size); + if (size > 0) { + trans_->write((uint8_t*)str.data(), size); + } + return xfer + size; +} + + + +/* + * Reading functions + * + * These have a lot of the same logic as the writing functions, so if + * something is 
confusing, look for comments in the corresponding writer. + */ + +uint32_t TDenseProtocol::readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + throw TApplicationException("TDenseProtocol doesn't work with messages (yet)."); + + uint32_t xfer = 0; + int32_t sz; + xfer += subReadI32(sz); + + if (sz < 0) { + // Check for correct version number + int32_t version = sz & VERSION_MASK; + if (version != VERSION_2) { + throw TProtocolException(TProtocolException::BAD_VERSION, "Bad version identifier"); + } + messageType = (TMessageType)(sz & 0x000000ff); + xfer += subReadString(name); + xfer += subReadI32(seqid); + } else { + throw TProtocolException(TProtocolException::BAD_VERSION, "No version identifier... old protocol client in strict mode?"); + } + return xfer; +} + +uint32_t TDenseProtocol::readMessageEnd() { + return 0; +} + +uint32_t TDenseProtocol::readStructBegin(string& name) { + uint32_t xfer = 0; + + if (ts_stack_.empty()) { + assert(standalone_); + + if (type_spec_ == NULL) { + resetState(); + throw TApplicationException("TDenseProtocol: No type specified."); + } else { + assert(type_spec_->ttype == T_STRUCT); + ts_stack_.push_back(type_spec_); + + // Check the fingerprint prefix. + uint8_t buf[FP_PREFIX_LEN]; + xfer += trans_->read(buf, FP_PREFIX_LEN); + if (std::memcmp(buf, type_spec_->fp_prefix, FP_PREFIX_LEN) != 0) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, + "Fingerprint in data does not match type_spec."); + } + } + } + + // We need a new field index for this structure. + idx_stack_.push_back(0); + return 0; +} + +uint32_t TDenseProtocol::readStructEnd() { + idx_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::readFieldBegin(string& name, + TType& fieldType, + int16_t& fieldId) { + uint32_t xfer = 0; + + // For optional fields, check to see if they are there. 
+ while (FMT.is_optional) { + bool is_present; + xfer += subReadBool(is_present); + if (is_present) { + break; + } + IDX++; + } + + // Once we hit a mandatory field, or an optional field that is present, + // we know that FMT and FTS point to the appropriate field. + + fieldId = FMT.tag; + fieldType = FTS->ttype; + + // Normally, we push the TypeSpec that we are about to read, + // but no reading is done for T_STOP. + if (FTS->ttype != T_STOP) { + ts_stack_.push_back(FTS); + } + return xfer; +} + +uint32_t TDenseProtocol::readFieldEnd() { + IDX++; + return 0; +} + +uint32_t TDenseProtocol::readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + checkTType(T_MAP); + + uint32_t xfer = 0; + int32_t sizei; + xfer += subReadI32(sizei); + if (sizei < 0) { + resetState(); + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && sizei > container_limit_) { + resetState(); + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + + keyType = ST1->ttype; + valType = ST2->ttype; + + ts_stack_.push_back(ST1); + mkv_stack_.push_back(true); + + return xfer; +} + +uint32_t TDenseProtocol::readMapEnd() { + ts_stack_.pop_back(); + mkv_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::readListBegin(TType& elemType, + uint32_t& size) { + checkTType(T_LIST); + + uint32_t xfer = 0; + int32_t sizei; + xfer += subReadI32(sizei); + if (sizei < 0) { + resetState(); + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && sizei > container_limit_) { + resetState(); + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + + elemType = ST1->ttype; + + ts_stack_.push_back(ST1); + + return xfer; +} + +uint32_t TDenseProtocol::readListEnd() { + ts_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::readSetBegin(TType& elemType, + uint32_t& size) { + checkTType(T_SET); + + 
uint32_t xfer = 0; + int32_t sizei; + xfer += subReadI32(sizei); + if (sizei < 0) { + resetState(); + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && sizei > container_limit_) { + resetState(); + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + + elemType = ST1->ttype; + + ts_stack_.push_back(ST1); + + return xfer; +} + +uint32_t TDenseProtocol::readSetEnd() { + ts_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::readBool(bool& value) { + checkTType(T_BOOL); + stateTransition(); + return TBinaryProtocol::readBool(value); +} + +uint32_t TDenseProtocol::readByte(int8_t& byte) { + checkTType(T_BYTE); + stateTransition(); + return TBinaryProtocol::readByte(byte); +} + +uint32_t TDenseProtocol::readI16(int16_t& i16) { + checkTType(T_I16); + stateTransition(); + uint64_t u64; + uint32_t rv = vlqRead(u64); + int64_t val = (int64_t)u64; + if (UNLIKELY(val > INT16_MAX || val < INT16_MIN)) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, + "i16 out of range."); + } + i16 = (int16_t)val; + return rv; +} + +uint32_t TDenseProtocol::readI32(int32_t& i32) { + checkTType(T_I32); + stateTransition(); + uint64_t u64; + uint32_t rv = vlqRead(u64); + int64_t val = (int64_t)u64; + if (UNLIKELY(val > INT32_MAX || val < INT32_MIN)) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, + "i32 out of range."); + } + i32 = (int32_t)val; + return rv; +} + +uint32_t TDenseProtocol::readI64(int64_t& i64) { + checkTType(T_I64); + stateTransition(); + uint64_t u64; + uint32_t rv = vlqRead(u64); + int64_t val = (int64_t)u64; + if (UNLIKELY(val > INT64_MAX || val < INT64_MIN)) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, + "i64 out of range."); + } + i64 = (int64_t)val; + return rv; +} + +uint32_t TDenseProtocol::readDouble(double& dub) { + checkTType(T_DOUBLE); + stateTransition(); + 
return TBinaryProtocol::readDouble(dub); +} + +uint32_t TDenseProtocol::readString(std::string& str) { + checkTType(T_STRING); + stateTransition(); + return subReadString(str); +} + +uint32_t TDenseProtocol::readBinary(std::string& str) { + return TDenseProtocol::readString(str); +} + +uint32_t TDenseProtocol::subReadI32(int32_t& i32) { + uint64_t u64; + uint32_t rv = vlqRead(u64); + int64_t val = (int64_t)u64; + if (UNLIKELY(val > INT32_MAX || val < INT32_MIN)) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, + "i32 out of range."); + } + i32 = (int32_t)val; + return rv; +} + +uint32_t TDenseProtocol::subReadString(std::string& str) { + uint32_t xfer; + int32_t size; + xfer = subReadI32(size); + return xfer + readStringBody(str, size); +} + +}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TDenseProtocol.h b/lib/cpp/src/protocol/TDenseProtocol.h new file mode 100644 index 000000000..7655a479a --- /dev/null +++ b/lib/cpp/src/protocol/TDenseProtocol.h @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
/**
 * !!!WARNING!!!
 * This class is still highly experimental.  Incompatible changes
 * WILL be made to it without notice.  DO NOT USE IT YET unless
 * you are coordinating your testing with the author.
 *
 * The dense protocol is designed to use as little space as possible.
 *
 * There are two types of dense protocol instances.  Standalone instances
 * are not used for RPC and just encode and decode structures of
 * a predetermined type.  Non-standalone instances are used for RPC.
 * Currently, only standalone instances exist.
 *
 * To use a standalone dense protocol object, you must set the type_spec
 * property (either in the constructor, or with setTypeSpec) to the local
 * reflection TypeSpec of the structures you will write to (or read from) the
 * protocol instance.
 *
 * BEST PRACTICES:
 * - Never use optional for primitives or containers.
 * - Only use optional for structures if they are very big and very rarely set.
 * - All integers are variable-length, so you can use i64 without bloating.
 * - NEVER EVER change the struct definitions IN ANY WAY without either
 *   changing your cache keys or talking to dreiss.
 *
 * TODO(dreiss): New class write with old meta.
 *
 * We override all of TBinaryProtocol's methods.
 * We inherit so that we can explicitly call TBP's primitive-writing
 * methods within our versions.
 *
 */
class TDenseProtocol : public TBinaryProtocol {
 protected:
  // Mask applied to the first word of a message to extract the version.
  static const int32_t VERSION_MASK = 0xffff0000;
  // VERSION_1 (0x80010000)  is taken by TBinaryProtocol.
  static const int32_t VERSION_2 = 0x80020000;

 public:
  typedef apache::thrift::reflection::local::TypeSpec TypeSpec;
  // Number of fingerprint bytes written before (and checked when reading)
  // a top-level structure.  Defined outside this class.
  static const int FP_PREFIX_LEN;

  /**
   * @param trans      The transport to use.
   * @param type_spec  The TypeSpec of the structures using this protocol.
   */
  TDenseProtocol(boost::shared_ptr<TTransport> trans,
                 TypeSpec* type_spec = NULL) :
    TBinaryProtocol(trans),
    type_spec_(type_spec),
    standalone_(true)
  {}

  // Set the TypeSpec used for the next top-level structure.
  // The pointer is stored as-is.
  void setTypeSpec(TypeSpec* type_spec) {
    type_spec_ = type_spec;
  }
  // Return the currently configured TypeSpec (may be NULL).
  TypeSpec* getTypeSpec() {
    return type_spec_;
  }


  /*
   * Writing functions.
   *
   * Each returns the number of bytes it reports as written.
   */

  virtual uint32_t writeMessageBegin(const std::string& name,
                                     const TMessageType messageType,
                                     const int32_t seqid);

  virtual uint32_t writeMessageEnd();


  virtual uint32_t writeStructBegin(const char* name);

  virtual uint32_t writeStructEnd();

  virtual uint32_t writeFieldBegin(const char* name,
                                   const TType fieldType,
                                   const int16_t fieldId);

  virtual uint32_t writeFieldEnd();

  virtual uint32_t writeFieldStop();

  virtual uint32_t writeMapBegin(const TType keyType,
                                 const TType valType,
                                 const uint32_t size);

  virtual uint32_t writeMapEnd();

  virtual uint32_t writeListBegin(const TType elemType,
                                  const uint32_t size);

  virtual uint32_t writeListEnd();

  virtual uint32_t writeSetBegin(const TType elemType,
                                 const uint32_t size);

  virtual uint32_t writeSetEnd();

  virtual uint32_t writeBool(const bool value);

  virtual uint32_t writeByte(const int8_t byte);

  virtual uint32_t writeI16(const int16_t i16);

  virtual uint32_t writeI32(const int32_t i32);

  virtual uint32_t writeI64(const int64_t i64);

  virtual uint32_t writeDouble(const double dub);

  virtual uint32_t writeString(const std::string& str);

  virtual uint32_t writeBinary(const std::string& str);


  /*
   * Helper writing functions (don't do state transitions).
   */
  inline uint32_t subWriteI32(const int32_t i32);

  inline uint32_t subWriteString(const std::string& str);

  // Writes a bool via TBinaryProtocol's encoding.
  uint32_t subWriteBool(const bool value) {
    return TBinaryProtocol::writeBool(value);
  }


  /*
   * Reading functions
   *
   * Each returns the number of bytes it reports as read.
   */

  uint32_t readMessageBegin(std::string& name,
                            TMessageType& messageType,
                            int32_t& seqid);

  uint32_t readMessageEnd();

  uint32_t readStructBegin(std::string& name);

  uint32_t readStructEnd();

  uint32_t readFieldBegin(std::string& name,
                          TType& fieldType,
                          int16_t& fieldId);

  uint32_t readFieldEnd();

  uint32_t readMapBegin(TType& keyType,
                        TType& valType,
                        uint32_t& size);

  uint32_t readMapEnd();

  uint32_t readListBegin(TType& elemType,
                         uint32_t& size);

  uint32_t readListEnd();

  uint32_t readSetBegin(TType& elemType,
                        uint32_t& size);

  uint32_t readSetEnd();

  uint32_t readBool(bool& value);

  uint32_t readByte(int8_t& byte);

  uint32_t readI16(int16_t& i16);

  uint32_t readI32(int32_t& i32);

  uint32_t readI64(int64_t& i64);

  uint32_t readDouble(double& dub);

  uint32_t readString(std::string& str);

  uint32_t readBinary(std::string& str);

  /*
   * Helper reading functions (don't do state transitions).
   */
  inline uint32_t subReadI32(int32_t& i32);

  inline uint32_t subReadString(std::string& str);

  // Reads a bool via TBinaryProtocol's encoding.
  uint32_t subReadBool(bool& value) {
    return TBinaryProtocol::readBool(value);
  }


 private:

  // Implementation functions, documented in the .cpp.
  inline void checkTType(const TType ttype);
  inline void stateTransition();

  // Read and write variable-length integers.
  // Uses the same technique as the MIDI file format.
  // Both return the number of encoded bytes.
  inline uint32_t vlqRead(uint64_t& vlq);
  inline uint32_t vlqWrite(uint64_t vlq);

  // Called before throwing an exception to make the object reusable.
  void resetState() {
    ts_stack_.clear();
    idx_stack_.clear();
    mkv_stack_.clear();
  }

  // TypeSpec of the top-level structure to write,
  // for standalone protocol objects.
  TypeSpec* type_spec_;

  std::vector<TypeSpec*> ts_stack_;   // TypeSpec stack.
  std::vector<int>       idx_stack_;  // InDeX stack.
  std::vector<bool>      mkv_stack_;  // Map Key/Value stack.
                                      // True = key, False = value.

  // True iff this is a standalone instance (no RPC).
  bool standalone_;
};
+ */ + +#include "TJSONProtocol.h" + +#include <math.h> +#include <boost/lexical_cast.hpp> +#include "TBase64Utils.h" +#include <transport/TTransportException.h> + +using namespace apache::thrift::transport; + +namespace apache { namespace thrift { namespace protocol { + + +// Static data + +static const uint8_t kJSONObjectStart = '{'; +static const uint8_t kJSONObjectEnd = '}'; +static const uint8_t kJSONArrayStart = '['; +static const uint8_t kJSONArrayEnd = ']'; +static const uint8_t kJSONNewline = '\n'; +static const uint8_t kJSONPairSeparator = ':'; +static const uint8_t kJSONElemSeparator = ','; +static const uint8_t kJSONBackslash = '\\'; +static const uint8_t kJSONStringDelimiter = '"'; +static const uint8_t kJSONZeroChar = '0'; +static const uint8_t kJSONEscapeChar = 'u'; + +static const std::string kJSONEscapePrefix("\\u00"); + +static const uint32_t kThriftVersion1 = 1; + +static const std::string kThriftNan("NaN"); +static const std::string kThriftInfinity("Infinity"); +static const std::string kThriftNegativeInfinity("-Infinity"); + +static const std::string kTypeNameBool("tf"); +static const std::string kTypeNameByte("i8"); +static const std::string kTypeNameI16("i16"); +static const std::string kTypeNameI32("i32"); +static const std::string kTypeNameI64("i64"); +static const std::string kTypeNameDouble("dbl"); +static const std::string kTypeNameStruct("rec"); +static const std::string kTypeNameString("str"); +static const std::string kTypeNameMap("map"); +static const std::string kTypeNameList("lst"); +static const std::string kTypeNameSet("set"); + +static const std::string &getTypeNameForTypeID(TType typeID) { + switch (typeID) { + case T_BOOL: + return kTypeNameBool; + case T_BYTE: + return kTypeNameByte; + case T_I16: + return kTypeNameI16; + case T_I32: + return kTypeNameI32; + case T_I64: + return kTypeNameI64; + case T_DOUBLE: + return kTypeNameDouble; + case T_STRING: + return kTypeNameString; + case T_STRUCT: + return kTypeNameStruct; + case 
T_MAP: + return kTypeNameMap; + case T_SET: + return kTypeNameSet; + case T_LIST: + return kTypeNameList; + default: + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + "Unrecognized type"); + } +} + +static TType getTypeIDForTypeName(const std::string &name) { + TType result = T_STOP; // Sentinel value + if (name.length() > 1) { + switch (name[0]) { + case 'd': + result = T_DOUBLE; + break; + case 'i': + switch (name[1]) { + case '8': + result = T_BYTE; + break; + case '1': + result = T_I16; + break; + case '3': + result = T_I32; + break; + case '6': + result = T_I64; + break; + } + break; + case 'l': + result = T_LIST; + break; + case 'm': + result = T_MAP; + break; + case 'r': + result = T_STRUCT; + break; + case 's': + if (name[1] == 't') { + result = T_STRING; + } + else if (name[1] == 'e') { + result = T_SET; + } + break; + case 't': + result = T_BOOL; + break; + } + } + if (result == T_STOP) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + "Unrecognized type"); + } + return result; +} + + +// This table describes the handling for the first 0x30 characters +// 0 : escape using "\u00xx" notation +// 1 : just output index +// <other> : escape using "\<other>" notation +static const uint8_t kJSONCharTable[0x30] = { +// 0 1 2 3 4 5 6 7 8 9 A B C D E F + 0, 0, 0, 0, 0, 0, 0, 0,'b','t','n', 0,'f','r', 0, 0, // 0 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 1 + 1, 1,'"', 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 2 +}; + + +// This string's characters must match up with the elements in kEscapeCharVals. 
+// I don't have '/' on this list even though it appears on www.json.org -- +// it is not in the RFC +const static std::string kEscapeChars("\"\\bfnrt"); + +// The elements of this array must match up with the sequence of characters in +// kEscapeChars +const static uint8_t kEscapeCharVals[7] = { + '"', '\\', '\b', '\f', '\n', '\r', '\t', +}; + + +// Static helper functions + +// Read 1 character from the transport trans and verify that it is the +// expected character ch. +// Throw a protocol exception if it is not. +static uint32_t readSyntaxChar(TJSONProtocol::LookaheadReader &reader, + uint8_t ch) { + uint8_t ch2 = reader.read(); + if (ch2 != ch) { + throw TProtocolException(TProtocolException::INVALID_DATA, + "Expected \'" + std::string((char *)&ch, 1) + + "\'; got \'" + std::string((char *)&ch2, 1) + + "\'."); + } + return 1; +} + +// Return the integer value of a hex character ch. +// Throw a protocol exception if the character is not [0-9a-f]. +static uint8_t hexVal(uint8_t ch) { + if ((ch >= '0') && (ch <= '9')) { + return ch - '0'; + } + else if ((ch >= 'a') && (ch <= 'f')) { + return ch - 'a'; + } + else { + throw TProtocolException(TProtocolException::INVALID_DATA, + "Expected hex val ([0-9a-f]); got \'" + + std::string((char *)&ch, 1) + "\'."); + } +} + +// Return the hex character representing the integer val. The value is masked +// to make sure it is in the correct range. 
// Return the lowercase hex character ('0'-'9', 'a'-'f') representing the
// integer val.  Only the low four bits of val are used.
static uint8_t hexChar(uint8_t val) {
  // Table lookup avoids the previous off-by-ten error, where 10..15 were
  // mapped to 'k'..'p' (val + 'a') instead of 'a'..'f'.
  static const char kHexDigits[] = "0123456789abcdef";
  return static_cast<uint8_t>(kHexDigits[val & 0x0F]);
}
kJSONPairSeparator : kJSONElemSeparator); + colon_ = !colon_; + return readSyntaxChar(reader, ch); + } + } + + // Numbers must be turned into strings if they are the key part of a pair + virtual bool escapeNum() { + return colon_; + } + + private: + + bool first_; + bool colon_; +}; + +// Context class for lists +class JSONListContext : public TJSONContext { + +public: + + JSONListContext() : + first_(true) { + } + + uint32_t write(TTransport &trans) { + if (first_) { + first_ = false; + return 0; + } + else { + trans.write(&kJSONElemSeparator, 1); + return 1; + } + } + + uint32_t read(TJSONProtocol::LookaheadReader &reader) { + if (first_) { + first_ = false; + return 0; + } + else { + return readSyntaxChar(reader, kJSONElemSeparator); + } + } + + private: + bool first_; +}; + + +TJSONProtocol::TJSONProtocol(boost::shared_ptr<TTransport> ptrans) : + TProtocol(ptrans), + context_(new TJSONContext()), + reader_(*ptrans) { +} + +TJSONProtocol::~TJSONProtocol() {} + +void TJSONProtocol::pushContext(boost::shared_ptr<TJSONContext> c) { + contexts_.push(context_); + context_ = c; +} + +void TJSONProtocol::popContext() { + context_ = contexts_.top(); + contexts_.pop(); +} + +// Write the character ch as a JSON escape sequence ("\u00xx") +uint32_t TJSONProtocol::writeJSONEscapeChar(uint8_t ch) { + trans_->write((const uint8_t *)kJSONEscapePrefix.c_str(), + kJSONEscapePrefix.length()); + uint8_t outCh = hexChar(ch >> 4); + trans_->write(&outCh, 1); + outCh = hexChar(ch); + trans_->write(&outCh, 1); + return 6; +} + +// Write the character ch as part of a JSON string, escaping as appropriate. 
+uint32_t TJSONProtocol::writeJSONChar(uint8_t ch) { + if (ch >= 0x30) { + if (ch == kJSONBackslash) { // Only special character >= 0x30 is '\' + trans_->write(&kJSONBackslash, 1); + trans_->write(&kJSONBackslash, 1); + return 2; + } + else { + trans_->write(&ch, 1); + return 1; + } + } + else { + uint8_t outCh = kJSONCharTable[ch]; + // Check if regular character, backslash escaped, or JSON escaped + if (outCh == 1) { + trans_->write(&ch, 1); + return 1; + } + else if (outCh > 1) { + trans_->write(&kJSONBackslash, 1); + trans_->write(&outCh, 1); + return 2; + } + else { + return writeJSONEscapeChar(ch); + } + } +} + +// Write out the contents of the string str as a JSON string, escaping +// characters as appropriate. +uint32_t TJSONProtocol::writeJSONString(const std::string &str) { + uint32_t result = context_->write(*trans_); + result += 2; // For quotes + trans_->write(&kJSONStringDelimiter, 1); + std::string::const_iterator iter(str.begin()); + std::string::const_iterator end(str.end()); + while (iter != end) { + result += writeJSONChar(*iter++); + } + trans_->write(&kJSONStringDelimiter, 1); + return result; +} + +// Write out the contents of the string as JSON string, base64-encoding +// the string's contents, and escaping as appropriate +uint32_t TJSONProtocol::writeJSONBase64(const std::string &str) { + uint32_t result = context_->write(*trans_); + result += 2; // For quotes + trans_->write(&kJSONStringDelimiter, 1); + uint8_t b[4]; + const uint8_t *bytes = (const uint8_t *)str.c_str(); + uint32_t len = str.length(); + while (len >= 3) { + // Encode 3 bytes at a time + base64_encode(bytes, 3, b); + trans_->write(b, 4); + result += 4; + bytes += 3; + len -=3; + } + if (len) { // Handle remainder + base64_encode(bytes, len, b); + trans_->write(b, len + 1); + result += len + 1; + } + trans_->write(&kJSONStringDelimiter, 1); + return result; +} + +// Convert the given integer type to a JSON number, or a string +// if the context requires it (eg: key in a map 
pair). +template <typename NumberType> +uint32_t TJSONProtocol::writeJSONInteger(NumberType num) { + uint32_t result = context_->write(*trans_); + std::string val(boost::lexical_cast<std::string>(num)); + bool escapeNum = context_->escapeNum(); + if (escapeNum) { + trans_->write(&kJSONStringDelimiter, 1); + result += 1; + } + trans_->write((const uint8_t *)val.c_str(), val.length()); + result += val.length(); + if (escapeNum) { + trans_->write(&kJSONStringDelimiter, 1); + result += 1; + } + return result; +} + +// Convert the given double to a JSON string, which is either the number, +// "NaN" or "Infinity" or "-Infinity". +uint32_t TJSONProtocol::writeJSONDouble(double num) { + uint32_t result = context_->write(*trans_); + std::string val(boost::lexical_cast<std::string>(num)); + + // Normalize output of boost::lexical_cast for NaNs and Infinities + bool special = false; + switch (val[0]) { + case 'N': + case 'n': + val = kThriftNan; + special = true; + break; + case 'I': + case 'i': + val = kThriftInfinity; + special = true; + break; + case '-': + if ((val[1] == 'I') || (val[1] == 'i')) { + val = kThriftNegativeInfinity; + special = true; + } + break; + } + + bool escapeNum = special || context_->escapeNum(); + if (escapeNum) { + trans_->write(&kJSONStringDelimiter, 1); + result += 1; + } + trans_->write((const uint8_t *)val.c_str(), val.length()); + result += val.length(); + if (escapeNum) { + trans_->write(&kJSONStringDelimiter, 1); + result += 1; + } + return result; +} + +uint32_t TJSONProtocol::writeJSONObjectStart() { + uint32_t result = context_->write(*trans_); + trans_->write(&kJSONObjectStart, 1); + pushContext(boost::shared_ptr<TJSONContext>(new JSONPairContext())); + return result + 1; +} + +uint32_t TJSONProtocol::writeJSONObjectEnd() { + popContext(); + trans_->write(&kJSONObjectEnd, 1); + return 1; +} + +uint32_t TJSONProtocol::writeJSONArrayStart() { + uint32_t result = context_->write(*trans_); + trans_->write(&kJSONArrayStart, 1); + 
pushContext(boost::shared_ptr<TJSONContext>(new JSONListContext())); + return result + 1; +} + +uint32_t TJSONProtocol::writeJSONArrayEnd() { + popContext(); + trans_->write(&kJSONArrayEnd, 1); + return 1; +} + +uint32_t TJSONProtocol::writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + uint32_t result = writeJSONArrayStart(); + result += writeJSONInteger(kThriftVersion1); + result += writeJSONString(name); + result += writeJSONInteger(messageType); + result += writeJSONInteger(seqid); + return result; +} + +uint32_t TJSONProtocol::writeMessageEnd() { + return writeJSONArrayEnd(); +} + +uint32_t TJSONProtocol::writeStructBegin(const char* name) { + return writeJSONObjectStart(); +} + +uint32_t TJSONProtocol::writeStructEnd() { + return writeJSONObjectEnd(); +} + +uint32_t TJSONProtocol::writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + uint32_t result = writeJSONInteger(fieldId); + result += writeJSONObjectStart(); + result += writeJSONString(getTypeNameForTypeID(fieldType)); + return result; +} + +uint32_t TJSONProtocol::writeFieldEnd() { + return writeJSONObjectEnd(); +} + +uint32_t TJSONProtocol::writeFieldStop() { + return 0; +} + +uint32_t TJSONProtocol::writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + uint32_t result = writeJSONArrayStart(); + result += writeJSONString(getTypeNameForTypeID(keyType)); + result += writeJSONString(getTypeNameForTypeID(valType)); + result += writeJSONInteger((int64_t)size); + result += writeJSONObjectStart(); + return result; +} + +uint32_t TJSONProtocol::writeMapEnd() { + return writeJSONObjectEnd() + writeJSONArrayEnd(); +} + +uint32_t TJSONProtocol::writeListBegin(const TType elemType, + const uint32_t size) { + uint32_t result = writeJSONArrayStart(); + result += writeJSONString(getTypeNameForTypeID(elemType)); + result += writeJSONInteger((int64_t)size); + return result; +} + +uint32_t 
TJSONProtocol::writeListEnd() { + return writeJSONArrayEnd(); +} + +uint32_t TJSONProtocol::writeSetBegin(const TType elemType, + const uint32_t size) { + uint32_t result = writeJSONArrayStart(); + result += writeJSONString(getTypeNameForTypeID(elemType)); + result += writeJSONInteger((int64_t)size); + return result; +} + +uint32_t TJSONProtocol::writeSetEnd() { + return writeJSONArrayEnd(); +} + +uint32_t TJSONProtocol::writeBool(const bool value) { + return writeJSONInteger(value); +} + +uint32_t TJSONProtocol::writeByte(const int8_t byte) { + // writeByte() must be handled specially becuase boost::lexical cast sees + // int8_t as a text type instead of an integer type + return writeJSONInteger((int16_t)byte); +} + +uint32_t TJSONProtocol::writeI16(const int16_t i16) { + return writeJSONInteger(i16); +} + +uint32_t TJSONProtocol::writeI32(const int32_t i32) { + return writeJSONInteger(i32); +} + +uint32_t TJSONProtocol::writeI64(const int64_t i64) { + return writeJSONInteger(i64); +} + +uint32_t TJSONProtocol::writeDouble(const double dub) { + return writeJSONDouble(dub); +} + +uint32_t TJSONProtocol::writeString(const std::string& str) { + return writeJSONString(str); +} + +uint32_t TJSONProtocol::writeBinary(const std::string& str) { + return writeJSONBase64(str); +} + + /** + * Reading functions + */ + +// Reads 1 byte and verifies that it matches ch. +uint32_t TJSONProtocol::readJSONSyntaxChar(uint8_t ch) { + return readSyntaxChar(reader_, ch); +} + +// Decodes the four hex parts of a JSON escaped string character and returns +// the character via out. The first two characters must be "00". 
+uint32_t TJSONProtocol::readJSONEscapeChar(uint8_t *out) { + uint8_t b[2]; + readJSONSyntaxChar(kJSONZeroChar); + readJSONSyntaxChar(kJSONZeroChar); + b[0] = reader_.read(); + b[1] = reader_.read(); + *out = (hexVal(b[0]) << 4) + hexVal(b[1]); + return 4; +} + +// Decodes a JSON string, including unescaping, and returns the string via str +uint32_t TJSONProtocol::readJSONString(std::string &str, bool skipContext) { + uint32_t result = (skipContext ? 0 : context_->read(reader_)); + result += readJSONSyntaxChar(kJSONStringDelimiter); + uint8_t ch; + str.clear(); + while (true) { + ch = reader_.read(); + ++result; + if (ch == kJSONStringDelimiter) { + break; + } + if (ch == kJSONBackslash) { + ch = reader_.read(); + ++result; + if (ch == kJSONEscapeChar) { + result += readJSONEscapeChar(&ch); + } + else { + size_t pos = kEscapeChars.find(ch); + if (pos == std::string::npos) { + throw TProtocolException(TProtocolException::INVALID_DATA, + "Expected control char, got '" + + std::string((const char *)&ch, 1) + "'."); + } + ch = kEscapeCharVals[pos]; + } + } + str += ch; + } + return result; +} + +// Reads a block of base64 characters, decoding it, and returns via str +uint32_t TJSONProtocol::readJSONBase64(std::string &str) { + std::string tmp; + uint32_t result = readJSONString(tmp); + uint8_t *b = (uint8_t *)tmp.c_str(); + uint32_t len = tmp.length(); + str.clear(); + while (len >= 4) { + base64_decode(b, 4); + str.append((const char *)b, 3); + b += 4; + len -= 4; + } + // Don't decode if we hit the end or got a single leftover byte (invalid + // base64 but legal for skip of regular string type) + if (len > 1) { + base64_decode(b, len); + str.append((const char *)b, len - 1); + } + return result; +} + +// Reads a sequence of characters, stopping at the first one that is not +// a valid JSON numeric character. 
+uint32_t TJSONProtocol::readJSONNumericChars(std::string &str) { + uint32_t result = 0; + str.clear(); + while (true) { + uint8_t ch = reader_.peek(); + if (!isJSONNumeric(ch)) { + break; + } + reader_.read(); + str += ch; + ++result; + } + return result; +} + +// Reads a sequence of characters and assembles them into a number, +// returning them via num +template <typename NumberType> +uint32_t TJSONProtocol::readJSONInteger(NumberType &num) { + uint32_t result = context_->read(reader_); + if (context_->escapeNum()) { + result += readJSONSyntaxChar(kJSONStringDelimiter); + } + std::string str; + result += readJSONNumericChars(str); + try { + num = boost::lexical_cast<NumberType>(str); + } + catch (boost::bad_lexical_cast e) { + throw new TProtocolException(TProtocolException::INVALID_DATA, + "Expected numeric value; got \"" + str + + "\""); + } + if (context_->escapeNum()) { + result += readJSONSyntaxChar(kJSONStringDelimiter); + } + return result; +} + +// Reads a JSON number or string and interprets it as a double. 
+uint32_t TJSONProtocol::readJSONDouble(double &num) { + uint32_t result = context_->read(reader_); + std::string str; + if (reader_.peek() == kJSONStringDelimiter) { + result += readJSONString(str, true); + // Check for NaN, Infinity and -Infinity + if (str == kThriftNan) { + num = HUGE_VAL/HUGE_VAL; // generates NaN + } + else if (str == kThriftInfinity) { + num = HUGE_VAL; + } + else if (str == kThriftNegativeInfinity) { + num = -HUGE_VAL; + } + else { + if (!context_->escapeNum()) { + // Throw exception -- we should not be in a string in this case + throw new TProtocolException(TProtocolException::INVALID_DATA, + "Numeric data unexpectedly quoted"); + } + try { + num = boost::lexical_cast<double>(str); + } + catch (boost::bad_lexical_cast e) { + throw new TProtocolException(TProtocolException::INVALID_DATA, + "Expected numeric value; got \"" + str + + "\""); + } + } + } + else { + if (context_->escapeNum()) { + // This will throw - we should have had a quote if escapeNum == true + readJSONSyntaxChar(kJSONStringDelimiter); + } + result += readJSONNumericChars(str); + try { + num = boost::lexical_cast<double>(str); + } + catch (boost::bad_lexical_cast e) { + throw new TProtocolException(TProtocolException::INVALID_DATA, + "Expected numeric value; got \"" + str + + "\""); + } + } + return result; +} + +uint32_t TJSONProtocol::readJSONObjectStart() { + uint32_t result = context_->read(reader_); + result += readJSONSyntaxChar(kJSONObjectStart); + pushContext(boost::shared_ptr<TJSONContext>(new JSONPairContext())); + return result; +} + +uint32_t TJSONProtocol::readJSONObjectEnd() { + uint32_t result = readJSONSyntaxChar(kJSONObjectEnd); + popContext(); + return result; +} + +uint32_t TJSONProtocol::readJSONArrayStart() { + uint32_t result = context_->read(reader_); + result += readJSONSyntaxChar(kJSONArrayStart); + pushContext(boost::shared_ptr<TJSONContext>(new JSONListContext())); + return result; +} + +uint32_t TJSONProtocol::readJSONArrayEnd() { + uint32_t 
result = readJSONSyntaxChar(kJSONArrayEnd); + popContext(); + return result; +} + +uint32_t TJSONProtocol::readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + uint32_t result = readJSONArrayStart(); + uint64_t tmpVal = 0; + result += readJSONInteger(tmpVal); + if (tmpVal != kThriftVersion1) { + throw TProtocolException(TProtocolException::BAD_VERSION, + "Message contained bad version."); + } + result += readJSONString(name); + result += readJSONInteger(tmpVal); + messageType = (TMessageType)tmpVal; + result += readJSONInteger(tmpVal); + seqid = tmpVal; + return result; +} + +uint32_t TJSONProtocol::readMessageEnd() { + return readJSONArrayEnd(); +} + +uint32_t TJSONProtocol::readStructBegin(std::string& name) { + return readJSONObjectStart(); +} + +uint32_t TJSONProtocol::readStructEnd() { + return readJSONObjectEnd(); +} + +uint32_t TJSONProtocol::readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId) { + uint32_t result = 0; + // Check if we hit the end of the list + uint8_t ch = reader_.peek(); + if (ch == kJSONObjectEnd) { + fieldType = apache::thrift::protocol::T_STOP; + } + else { + uint64_t tmpVal = 0; + std::string tmpStr; + result += readJSONInteger(tmpVal); + fieldId = tmpVal; + result += readJSONObjectStart(); + result += readJSONString(tmpStr); + fieldType = getTypeIDForTypeName(tmpStr); + } + return result; +} + +uint32_t TJSONProtocol::readFieldEnd() { + return readJSONObjectEnd(); +} + +uint32_t TJSONProtocol::readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + uint64_t tmpVal = 0; + std::string tmpStr; + uint32_t result = readJSONArrayStart(); + result += readJSONString(tmpStr); + keyType = getTypeIDForTypeName(tmpStr); + result += readJSONString(tmpStr); + valType = getTypeIDForTypeName(tmpStr); + result += readJSONInteger(tmpVal); + size = tmpVal; + result += readJSONObjectStart(); + return result; +} + +uint32_t TJSONProtocol::readMapEnd() { + return readJSONObjectEnd() + 
readJSONArrayEnd(); +} + +uint32_t TJSONProtocol::readListBegin(TType& elemType, + uint32_t& size) { + uint64_t tmpVal = 0; + std::string tmpStr; + uint32_t result = readJSONArrayStart(); + result += readJSONString(tmpStr); + elemType = getTypeIDForTypeName(tmpStr); + result += readJSONInteger(tmpVal); + size = tmpVal; + return result; +} + +uint32_t TJSONProtocol::readListEnd() { + return readJSONArrayEnd(); +} + +uint32_t TJSONProtocol::readSetBegin(TType& elemType, + uint32_t& size) { + uint64_t tmpVal = 0; + std::string tmpStr; + uint32_t result = readJSONArrayStart(); + result += readJSONString(tmpStr); + elemType = getTypeIDForTypeName(tmpStr); + result += readJSONInteger(tmpVal); + size = tmpVal; + return result; +} + +uint32_t TJSONProtocol::readSetEnd() { + return readJSONArrayEnd(); +} + +uint32_t TJSONProtocol::readBool(bool& value) { + return readJSONInteger(value); +} + +// readByte() must be handled properly becuase boost::lexical cast sees int8_t +// as a text type instead of an integer type +uint32_t TJSONProtocol::readByte(int8_t& byte) { + int16_t tmp = (int16_t) byte; + uint32_t result = readJSONInteger(tmp); + assert(tmp < 256); + byte = (int8_t)tmp; + return result; +} + +uint32_t TJSONProtocol::readI16(int16_t& i16) { + return readJSONInteger(i16); +} + +uint32_t TJSONProtocol::readI32(int32_t& i32) { + return readJSONInteger(i32); +} + +uint32_t TJSONProtocol::readI64(int64_t& i64) { + return readJSONInteger(i64); +} + +uint32_t TJSONProtocol::readDouble(double& dub) { + return readJSONDouble(dub); +} + +uint32_t TJSONProtocol::readString(std::string &str) { + return readJSONString(str); +} + +uint32_t TJSONProtocol::readBinary(std::string &str) { + return readJSONBase64(str); +} + +}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TJSONProtocol.h b/lib/cpp/src/protocol/TJSONProtocol.h new file mode 100644 index 000000000..2df499ac0 --- /dev/null +++ b/lib/cpp/src/protocol/TJSONProtocol.h @@ -0,0 +1,340 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TJSONPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TJSONPROTOCOL_H_ 1 + +#include "TProtocol.h" + +#include <stack> + +namespace apache { namespace thrift { namespace protocol { + +// Forward declaration +class TJSONContext; + +/** + * JSON protocol for Thrift. + * + * Implements a protocol which uses JSON as the wire-format. + * + * Thrift types are represented as described below: + * + * 1. Every Thrift integer type is represented as a JSON number. + * + * 2. Thrift doubles are represented as JSON numbers. Some special values are + * represented as strings: + * a. "NaN" for not-a-number values + * b. "Infinity" for postive infinity + * c. "-Infinity" for negative infinity + * + * 3. Thrift string values are emitted as JSON strings, with appropriate + * escaping. + * + * 4. Thrift binary values are encoded into Base64 and emitted as JSON strings. + * The readBinary() method is written such that it will properly skip if + * called on a Thrift string (although it will decode garbage data). + * + * 5. Thrift structs are represented as JSON objects, with the field ID as the + * key, and the field value represented as a JSON object with a single + * key-value pair. 
The key is a short string identifier for that type, + * followed by the value. The valid type identifiers are: "tf" for bool, + * "i8" for byte, "i16" for 16-bit integer, "i32" for 32-bit integer, "i64" + * for 64-bit integer, "dbl" for double-precision loating point, "str" for + * string (including binary), "rec" for struct ("records"), "map" for map, + * "lst" for list, "set" for set. + * + * 6. Thrift lists and sets are represented as JSON arrays, with the first + * element of the JSON array being the string identifier for the Thrift + * element type and the second element of the JSON array being the count of + * the Thrift elements. The Thrift elements then follow. + * + * 7. Thrift maps are represented as JSON arrays, with the first two elements + * of the JSON array being the string identifiers for the Thrift key type + * and value type, followed by the count of the Thrift pairs, followed by a + * JSON object containing the key-value pairs. Note that JSON keys can only + * be strings, which means that the key type of the Thrift map should be + * restricted to numeric or string types -- in the case of numerics, they + * are serialized as strings. + * + * 8. Thrift messages are represented as JSON arrays, with the protocol + * version #, the message name, the message type, and the sequence ID as + * the first 4 elements. + * + * More discussion of the double handling is probably warranted. The aim of + * the current implementation is to match as closely as possible the behavior + * of Java's Double.toString(), which has no precision loss. Implementors in + * other languages should strive to achieve that where possible. I have not + * yet verified whether boost:lexical_cast, which is doing that work for me in + * C++, loses any precision, but I am leaving this as a future improvement. I + * may try to provide a C component for this, so that other languages could + * bind to the same underlying implementation for maximum consistency. 
+ * + * Note further that JavaScript itself is not capable of representing + * floating point infinities -- presumably when we have a JavaScript Thrift + * client, this would mean that infinities get converted to not-a-number in + * transmission. I don't know of any work-around for this issue. + * + */ +class TJSONProtocol : public TProtocol { + public: + + TJSONProtocol(boost::shared_ptr<TTransport> ptrans); + + ~TJSONProtocol(); + + private: + + void pushContext(boost::shared_ptr<TJSONContext> c); + + void popContext(); + + uint32_t writeJSONEscapeChar(uint8_t ch); + + uint32_t writeJSONChar(uint8_t ch); + + uint32_t writeJSONString(const std::string &str); + + uint32_t writeJSONBase64(const std::string &str); + + template <typename NumberType> + uint32_t writeJSONInteger(NumberType num); + + uint32_t writeJSONDouble(double num); + + uint32_t writeJSONObjectStart() ; + + uint32_t writeJSONObjectEnd(); + + uint32_t writeJSONArrayStart(); + + uint32_t writeJSONArrayEnd(); + + uint32_t readJSONSyntaxChar(uint8_t ch); + + uint32_t readJSONEscapeChar(uint8_t *out); + + uint32_t readJSONString(std::string &str, bool skipContext = false); + + uint32_t readJSONBase64(std::string &str); + + uint32_t readJSONNumericChars(std::string &str); + + template <typename NumberType> + uint32_t readJSONInteger(NumberType &num); + + uint32_t readJSONDouble(double &num); + + uint32_t readJSONObjectStart(); + + uint32_t readJSONObjectEnd(); + + uint32_t readJSONArrayStart(); + + uint32_t readJSONArrayEnd(); + + public: + + /** + * Writing functions. 
+ */ + + uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid); + + uint32_t writeMessageEnd(); + + uint32_t writeStructBegin(const char* name); + + uint32_t writeStructEnd(); + + uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId); + + uint32_t writeFieldEnd(); + + uint32_t writeFieldStop(); + + uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size); + + uint32_t writeMapEnd(); + + uint32_t writeListBegin(const TType elemType, + const uint32_t size); + + uint32_t writeListEnd(); + + uint32_t writeSetBegin(const TType elemType, + const uint32_t size); + + uint32_t writeSetEnd(); + + uint32_t writeBool(const bool value); + + uint32_t writeByte(const int8_t byte); + + uint32_t writeI16(const int16_t i16); + + uint32_t writeI32(const int32_t i32); + + uint32_t writeI64(const int64_t i64); + + uint32_t writeDouble(const double dub); + + uint32_t writeString(const std::string& str); + + uint32_t writeBinary(const std::string& str); + + /** + * Reading functions + */ + + uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid); + + uint32_t readMessageEnd(); + + uint32_t readStructBegin(std::string& name); + + uint32_t readStructEnd(); + + uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId); + + uint32_t readFieldEnd(); + + uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size); + + uint32_t readMapEnd(); + + uint32_t readListBegin(TType& elemType, + uint32_t& size); + + uint32_t readListEnd(); + + uint32_t readSetBegin(TType& elemType, + uint32_t& size); + + uint32_t readSetEnd(); + + uint32_t readBool(bool& value); + + uint32_t readByte(int8_t& byte); + + uint32_t readI16(int16_t& i16); + + uint32_t readI32(int32_t& i32); + + uint32_t readI64(int64_t& i64); + + uint32_t readDouble(double& dub); + + uint32_t readString(std::string& str); + + uint32_t 
readBinary(std::string& str); + + class LookaheadReader { + + public: + + LookaheadReader(TTransport &trans) : + trans_(&trans), + hasData_(false) { + } + + uint8_t read() { + if (hasData_) { + hasData_ = false; + } + else { + trans_->readAll(&data_, 1); + } + return data_; + } + + uint8_t peek() { + if (!hasData_) { + trans_->readAll(&data_, 1); + } + hasData_ = true; + return data_; + } + + private: + TTransport *trans_; + bool hasData_; + uint8_t data_; + }; + + private: + + std::stack<boost::shared_ptr<TJSONContext> > contexts_; + boost::shared_ptr<TJSONContext> context_; + LookaheadReader reader_; +}; + +/** + * Constructs input and output protocol objects given transports. + */ +class TJSONProtocolFactory : public TProtocolFactory { + public: + TJSONProtocolFactory() {} + + virtual ~TJSONProtocolFactory() {} + + boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) { + return boost::shared_ptr<TProtocol>(new TJSONProtocol(trans)); + } +}; + +}}} // apache::thrift::protocol + + +// TODO(dreiss): Move part of ThriftJSONString into a .cpp file and remove this. 
+#include <transport/TBufferTransports.h> + +namespace apache { namespace thrift { + +template<typename ThriftStruct> + std::string ThriftJSONString(const ThriftStruct& ts) { + using namespace apache::thrift::transport; + using namespace apache::thrift::protocol; + TMemoryBuffer* buffer = new TMemoryBuffer; + boost::shared_ptr<TTransport> trans(buffer); + TJSONProtocol protocol(trans); + + ts.write(&protocol); + + uint8_t* buf; + uint32_t size; + buffer->getBuffer(&buf, &size); + return std::string((char*)buf, (unsigned int)size); +} + +}} // apache::thrift + +#endif // #define _THRIFT_PROTOCOL_TJSONPROTOCOL_H_ 1 diff --git a/lib/cpp/src/protocol/TOneWayProtocol.h b/lib/cpp/src/protocol/TOneWayProtocol.h new file mode 100644 index 000000000..6f08fe1d7 --- /dev/null +++ b/lib/cpp/src/protocol/TOneWayProtocol.h @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TONEWAYPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TONEWAYPROTOCOL_H_ 1 + +#include "TProtocol.h" + +namespace apache { namespace thrift { namespace protocol { + +/** + * Abstract class for implementing a protocol that can only be written, + * not read. 
+ * + */ +class TWriteOnlyProtocol : public TProtocol { + public: + /** + * @param subclass_name The name of the concrete subclass. + */ + TWriteOnlyProtocol(boost::shared_ptr<TTransport> trans, + const std::string& subclass_name) + : TProtocol(trans) + , subclass_(subclass_name) + {} + + // All writing functions remain abstract. + + /** + * Reading functions all throw an exception. + */ + + uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readMessageEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readStructBegin(std::string& name) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readStructEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readFieldEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readMapEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readListBegin(TType& elemType, + uint32_t& size) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readListEnd() { + 
throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readSetBegin(TType& elemType, + uint32_t& size) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readSetEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readBool(bool& value) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readByte(int8_t& byte) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readI16(int16_t& i16) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readI32(int32_t& i32) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readI64(int64_t& i64) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readDouble(double& dub) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readString(std::string& str) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readBinary(std::string& str) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + private: + std::string subclass_; +}; + + +/** + * Abstract class for implementing a protocol that can only be read, + * not written. + * + */ +class TReadOnlyProtocol : public TProtocol { + public: + /** + * @param subclass_name The name of the concrete subclass. 
+ */ + TReadOnlyProtocol(boost::shared_ptr<TTransport> trans, + const std::string& subclass_name) + : TProtocol(trans) + , subclass_(subclass_name) + {} + + // All reading functions remain abstract. + + /** + * Writing functions all throw an exception. + */ + + uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeMessageEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + + uint32_t writeStructBegin(const char* name) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeStructEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeFieldEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeFieldStop() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeMapEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeListBegin(const TType elemType, + const uint32_t size) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ 
+ " does not support writing (yet)."); + } + + uint32_t writeListEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeSetBegin(const TType elemType, + const uint32_t size) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeSetEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeBool(const bool value) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeByte(const int8_t byte) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeI16(const int16_t i16) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeI32(const int32_t i32) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeI64(const int64_t i64) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeDouble(const double dub) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeString(const std::string& str) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeBinary(const std::string& str) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + private: + std::string subclass_; +}; + +}}} // apache::thrift::protocol + +#endif // #ifndef _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_ diff --git 
a/lib/cpp/src/protocol/TProtocol.h b/lib/cpp/src/protocol/TProtocol.h new file mode 100644 index 000000000..40258277d --- /dev/null +++ b/lib/cpp/src/protocol/TProtocol.h @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TPROTOCOL_H_ 1 + +#include <transport/TTransport.h> +#include <protocol/TProtocolException.h> + +#include <boost/shared_ptr.hpp> +#include <boost/static_assert.hpp> + +#include <netinet/in.h> +#include <sys/types.h> +#include <string> +#include <map> + + +// Use this to get around strict aliasing rules. +// For example, uint64_t i = bitwise_cast<uint64_t>(returns_double()); +// The most obvious implementation is to just cast a pointer, +// but that doesn't work. +// For a pretty in-depth explanation of the problem, see +// http://www.cellperformance.com/mike_acton/2006/06/ (...) +// understanding_strict_aliasing.html +template <typename To, typename From> +static inline To bitwise_cast(From from) { + BOOST_STATIC_ASSERT(sizeof(From) == sizeof(To)); + + // BAD!!! These are all broken with -O2. + //return *reinterpret_cast<To*>(&from); // BAD!!! + //return *static_cast<To*>(static_cast<void*>(&from)); // BAD!!! 
+ //return *(To*)(void*)&from; // BAD!!! + + // Super clean and paritally blessed by section 3.9 of the standard. + //unsigned char c[sizeof(from)]; + //memcpy(c, &from, sizeof(from)); + //To to; + //memcpy(&to, c, sizeof(c)); + //return to; + + // Slightly more questionable. + // Same code emitted by GCC. + //To to; + //memcpy(&to, &from, sizeof(from)); + //return to; + + // Technically undefined, but almost universally supported, + // and the most efficient implementation. + union { + From f; + To t; + } u; + u.f = from; + return u.t; +} + + +namespace apache { namespace thrift { namespace protocol { + +using apache::thrift::transport::TTransport; + +#ifdef HAVE_ENDIAN_H +#include <endian.h> +#endif + +#ifndef __BYTE_ORDER +# if defined(BYTE_ORDER) && defined(LITTLE_ENDIAN) && defined(BIG_ENDIAN) +# define __BYTE_ORDER BYTE_ORDER +# define __LITTLE_ENDIAN LITTLE_ENDIAN +# define __BIG_ENDIAN BIG_ENDIAN +# else +# error "Cannot determine endianness" +# endif +#endif + +#if __BYTE_ORDER == __BIG_ENDIAN +# define ntohll(n) (n) +# define htonll(n) (n) +# if defined(__GNUC__) && defined(__GLIBC__) +# include <byteswap.h> +# define htolell(n) bswap_64(n) +# define letohll(n) bswap_64(n) +# else /* GNUC & GLIBC */ +# define bswap_64(n) \ + ( (((n) & 0xff00000000000000ull) >> 56) \ + | (((n) & 0x00ff000000000000ull) >> 40) \ + | (((n) & 0x0000ff0000000000ull) >> 24) \ + | (((n) & 0x000000ff00000000ull) >> 8) \ + | (((n) & 0x00000000ff000000ull) << 8) \ + | (((n) & 0x0000000000ff0000ull) << 24) \ + | (((n) & 0x000000000000ff00ull) << 40) \ + | (((n) & 0x00000000000000ffull) << 56) ) +# define ntolell(n) bswap_64(n) +# define letonll(n) bswap_64(n) +# endif /* GNUC & GLIBC */ +#elif __BYTE_ORDER == __LITTLE_ENDIAN +# define htolell(n) (n) +# define letohll(n) (n) +# if defined(__GNUC__) && defined(__GLIBC__) +# include <byteswap.h> +# define ntohll(n) bswap_64(n) +# define htonll(n) bswap_64(n) +# else /* GNUC & GLIBC */ +# define ntohll(n) ( (((unsigned long 
long)ntohl(n)) << 32) + ntohl(n >> 32) ) +# define htonll(n) ( (((unsigned long long)htonl(n)) << 32) + htonl(n >> 32) ) +# endif /* GNUC & GLIBC */ +#else /* __BYTE_ORDER */ +# error "Can't define htonll or ntohll!" +#endif + +/** + * Enumerated definition of the types that the Thrift protocol supports. + * Take special note of the T_END type which is used specifically to mark + * the end of a sequence of fields. + */ +enum TType { + T_STOP = 0, + T_VOID = 1, + T_BOOL = 2, + T_BYTE = 3, + T_I08 = 3, + T_I16 = 6, + T_I32 = 8, + T_U64 = 9, + T_I64 = 10, + T_DOUBLE = 4, + T_STRING = 11, + T_UTF7 = 11, + T_STRUCT = 12, + T_MAP = 13, + T_SET = 14, + T_LIST = 15, + T_UTF8 = 16, + T_UTF16 = 17 +}; + +/** + * Enumerated definition of the message types that the Thrift protocol + * supports. + */ +enum TMessageType { + T_CALL = 1, + T_REPLY = 2, + T_EXCEPTION = 3, + T_ONEWAY = 4 +}; + +/** + * Abstract class for a thrift protocol driver. These are all the methods that + * a protocol must implement. Essentially, there must be some way of reading + * and writing all the base types, plus a mechanism for writing out structs + * with indexed fields. + * + * TProtocol objects should not be shared across multiple encoding contexts, + * as they may need to maintain internal state in some protocols (i.e. XML). + * Note that is is acceptable for the TProtocol module to do its own internal + * buffered reads/writes to the underlying TTransport where appropriate (i.e. + * when parsing an input XML stream, reading should be batched rather than + * looking ahead character by character for a close tag). + * + */ +class TProtocol { + public: + virtual ~TProtocol() {} + + /** + * Writing functions. 
+ */ + + virtual uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) = 0; + + virtual uint32_t writeMessageEnd() = 0; + + + virtual uint32_t writeStructBegin(const char* name) = 0; + + virtual uint32_t writeStructEnd() = 0; + + virtual uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) = 0; + + virtual uint32_t writeFieldEnd() = 0; + + virtual uint32_t writeFieldStop() = 0; + + virtual uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) = 0; + + virtual uint32_t writeMapEnd() = 0; + + virtual uint32_t writeListBegin(const TType elemType, + const uint32_t size) = 0; + + virtual uint32_t writeListEnd() = 0; + + virtual uint32_t writeSetBegin(const TType elemType, + const uint32_t size) = 0; + + virtual uint32_t writeSetEnd() = 0; + + virtual uint32_t writeBool(const bool value) = 0; + + virtual uint32_t writeByte(const int8_t byte) = 0; + + virtual uint32_t writeI16(const int16_t i16) = 0; + + virtual uint32_t writeI32(const int32_t i32) = 0; + + virtual uint32_t writeI64(const int64_t i64) = 0; + + virtual uint32_t writeDouble(const double dub) = 0; + + virtual uint32_t writeString(const std::string& str) = 0; + + virtual uint32_t writeBinary(const std::string& str) = 0; + + /** + * Reading functions + */ + + virtual uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) = 0; + + virtual uint32_t readMessageEnd() = 0; + + virtual uint32_t readStructBegin(std::string& name) = 0; + + virtual uint32_t readStructEnd() = 0; + + virtual uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId) = 0; + + virtual uint32_t readFieldEnd() = 0; + + virtual uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) = 0; + + virtual uint32_t readMapEnd() = 0; + + virtual uint32_t readListBegin(TType& elemType, + uint32_t& size) = 0; + + virtual uint32_t readListEnd() = 
0; + + virtual uint32_t readSetBegin(TType& elemType, + uint32_t& size) = 0; + + virtual uint32_t readSetEnd() = 0; + + virtual uint32_t readBool(bool& value) = 0; + + virtual uint32_t readByte(int8_t& byte) = 0; + + virtual uint32_t readI16(int16_t& i16) = 0; + + virtual uint32_t readI32(int32_t& i32) = 0; + + virtual uint32_t readI64(int64_t& i64) = 0; + + virtual uint32_t readDouble(double& dub) = 0; + + virtual uint32_t readString(std::string& str) = 0; + + virtual uint32_t readBinary(std::string& str) = 0; + + uint32_t readBool(std::vector<bool>::reference ref) { + bool value; + uint32_t rv = readBool(value); + ref = value; + return rv; + } + + /** + * Method to arbitrarily skip over data. + */ + uint32_t skip(TType type) { + switch (type) { + case T_BOOL: + { + bool boolv; + return readBool(boolv); + } + case T_BYTE: + { + int8_t bytev; + return readByte(bytev); + } + case T_I16: + { + int16_t i16; + return readI16(i16); + } + case T_I32: + { + int32_t i32; + return readI32(i32); + } + case T_I64: + { + int64_t i64; + return readI64(i64); + } + case T_DOUBLE: + { + double dub; + return readDouble(dub); + } + case T_STRING: + { + std::string str; + return readBinary(str); + } + case T_STRUCT: + { + uint32_t result = 0; + std::string name; + int16_t fid; + TType ftype; + result += readStructBegin(name); + while (true) { + result += readFieldBegin(name, ftype, fid); + if (ftype == T_STOP) { + break; + } + result += skip(ftype); + result += readFieldEnd(); + } + result += readStructEnd(); + return result; + } + case T_MAP: + { + uint32_t result = 0; + TType keyType; + TType valType; + uint32_t i, size; + result += readMapBegin(keyType, valType, size); + for (i = 0; i < size; i++) { + result += skip(keyType); + result += skip(valType); + } + result += readMapEnd(); + return result; + } + case T_SET: + { + uint32_t result = 0; + TType elemType; + uint32_t i, size; + result += readSetBegin(elemType, size); + for (i = 0; i < size; i++) { + result += skip(elemType); + 
} + result += readSetEnd(); + return result; + } + case T_LIST: + { + uint32_t result = 0; + TType elemType; + uint32_t i, size; + result += readListBegin(elemType, size); + for (i = 0; i < size; i++) { + result += skip(elemType); + } + result += readListEnd(); + return result; + } + default: + return 0; + } + } + + inline boost::shared_ptr<TTransport> getTransport() { + return ptrans_; + } + + // TODO: remove these two calls, they are for backwards + // compatibility + inline boost::shared_ptr<TTransport> getInputTransport() { + return ptrans_; + } + inline boost::shared_ptr<TTransport> getOutputTransport() { + return ptrans_; + } + + protected: + TProtocol(boost::shared_ptr<TTransport> ptrans): + ptrans_(ptrans) { + trans_ = ptrans.get(); + } + + boost::shared_ptr<TTransport> ptrans_; + TTransport* trans_; + + private: + TProtocol() {} +}; + +/** + * Constructs input and output protocol objects given transports. + */ +class TProtocolFactory { + public: + TProtocolFactory() {} + + virtual ~TProtocolFactory() {} + + virtual boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) = 0; +}; + +}}} // apache::thrift::protocol + +#endif // #define _THRIFT_PROTOCOL_TPROTOCOL_H_ 1 diff --git a/lib/cpp/src/protocol/TProtocolException.h b/lib/cpp/src/protocol/TProtocolException.h new file mode 100644 index 000000000..33011b379 --- /dev/null +++ b/lib/cpp/src/protocol/TProtocolException.h @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TPROTOCOLEXCEPTION_H_ +#define _THRIFT_PROTOCOL_TPROTOCOLEXCEPTION_H_ 1 + +#include <string> + +namespace apache { namespace thrift { namespace protocol { + +/** + * Class to encapsulate all the possible types of protocol errors that may + * occur in various protocol systems. This provides a sort of generic + * wrapper around the shitty UNIX E_ error codes that lets a common code + * base of error handling to be used for various types of protocols, i.e. + * pipes etc. + * + */ +class TProtocolException : public apache::thrift::TException { + public: + + /** + * Error codes for the various types of exceptions. + */ + enum TProtocolExceptionType + { UNKNOWN = 0 + , INVALID_DATA = 1 + , NEGATIVE_SIZE = 2 + , SIZE_LIMIT = 3 + , BAD_VERSION = 4 + , NOT_IMPLEMENTED = 5 + }; + + TProtocolException() : + apache::thrift::TException(), + type_(UNKNOWN) {} + + TProtocolException(TProtocolExceptionType type) : + apache::thrift::TException(), + type_(type) {} + + TProtocolException(const std::string& message) : + apache::thrift::TException(message), + type_(UNKNOWN) {} + + TProtocolException(TProtocolExceptionType type, const std::string& message) : + apache::thrift::TException(message), + type_(type) {} + + virtual ~TProtocolException() throw() {} + + /** + * Returns an error code that provides information about the type of error + * that has occurred. 
+ * + * @return Error code + */ + TProtocolExceptionType getType() { + return type_; + } + + virtual const char* what() const throw() { + if (message_.empty()) { + switch (type_) { + case UNKNOWN : return "TProtocolException: Unknown protocol exception"; + case INVALID_DATA : return "TProtocolException: Invalid data"; + case NEGATIVE_SIZE : return "TProtocolException: Negative size"; + case SIZE_LIMIT : return "TProtocolException: Exceeded size limit"; + case BAD_VERSION : return "TProtocolException: Invalid version"; + case NOT_IMPLEMENTED : return "TProtocolException: Not implemented"; + default : return "TProtocolException: (Invalid exception type)"; + } + } else { + return message_.c_str(); + } + } + + protected: + /** + * Error code + */ + TProtocolExceptionType type_; + +}; + +}}} // apache::thrift::protocol + +#endif // #ifndef _THRIFT_PROTOCOL_TPROTOCOLEXCEPTION_H_ diff --git a/lib/cpp/src/protocol/TProtocolTap.h b/lib/cpp/src/protocol/TProtocolTap.h new file mode 100644 index 000000000..5580216a3 --- /dev/null +++ b/lib/cpp/src/protocol/TProtocolTap.h @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef _THRIFT_PROTOCOL_TPROTOCOLTAP_H_ +#define _THRIFT_PROTOCOL_TPROTOCOLTAP_H_ 1 + +#include <protocol/TOneWayProtocol.h> + +namespace apache { namespace thrift { namespace protocol { + +using apache::thrift::transport::TTransport; + +/** + * Puts a wiretap on a protocol object. Any reads to this class are passed + * through to an enclosed protocol object, but also mirrored as write to a + * second protocol object. + * + */ +class TProtocolTap : public TReadOnlyProtocol { + public: + TProtocolTap(boost::shared_ptr<TProtocol> source, + boost::shared_ptr<TProtocol> sink) + : TReadOnlyProtocol(source->getTransport(), "TProtocolTap") + , source_(source) + , sink_(sink) + {} + + virtual uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + uint32_t rv = source_->readMessageBegin(name, messageType, seqid); + sink_->writeMessageBegin(name, messageType, seqid); + return rv; + } + + virtual uint32_t readMessageEnd() { + uint32_t rv = source_->readMessageEnd(); + sink_->writeMessageEnd(); + return rv; + } + + virtual uint32_t readStructBegin(std::string& name) { + uint32_t rv = source_->readStructBegin(name); + sink_->writeStructBegin(name.c_str()); + return rv; + } + + virtual uint32_t readStructEnd() { + uint32_t rv = source_->readStructEnd(); + sink_->writeStructEnd(); + return rv; + } + + virtual uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId) { + uint32_t rv = source_->readFieldBegin(name, fieldType, fieldId); + if (fieldType == T_STOP) { + sink_->writeFieldStop(); + } else { + sink_->writeFieldBegin(name.c_str(), fieldType, fieldId); + } + return rv; + } + + + virtual uint32_t readFieldEnd() { + uint32_t rv = source_->readFieldEnd(); + sink_->writeFieldEnd(); + return rv; + } + + virtual uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + uint32_t rv = source_->readMapBegin(keyType, valType, size); + sink_->writeMapBegin(keyType, valType, size); + return 
rv; + } + + + virtual uint32_t readMapEnd() { + uint32_t rv = source_->readMapEnd(); + sink_->writeMapEnd(); + return rv; + } + + virtual uint32_t readListBegin(TType& elemType, + uint32_t& size) { + uint32_t rv = source_->readListBegin(elemType, size); + sink_->writeListBegin(elemType, size); + return rv; + } + + + virtual uint32_t readListEnd() { + uint32_t rv = source_->readListEnd(); + sink_->writeListEnd(); + return rv; + } + + virtual uint32_t readSetBegin(TType& elemType, + uint32_t& size) { + uint32_t rv = source_->readSetBegin(elemType, size); + sink_->writeSetBegin(elemType, size); + return rv; + } + + + virtual uint32_t readSetEnd() { + uint32_t rv = source_->readSetEnd(); + sink_->writeSetEnd(); + return rv; + } + + virtual uint32_t readBool(bool& value) { + uint32_t rv = source_->readBool(value); + sink_->writeBool(value); + return rv; + } + + virtual uint32_t readByte(int8_t& byte) { + uint32_t rv = source_->readByte(byte); + sink_->writeByte(byte); + return rv; + } + + virtual uint32_t readI16(int16_t& i16) { + uint32_t rv = source_->readI16(i16); + sink_->writeI16(i16); + return rv; + } + + virtual uint32_t readI32(int32_t& i32) { + uint32_t rv = source_->readI32(i32); + sink_->writeI32(i32); + return rv; + } + + virtual uint32_t readI64(int64_t& i64) { + uint32_t rv = source_->readI64(i64); + sink_->writeI64(i64); + return rv; + } + + virtual uint32_t readDouble(double& dub) { + uint32_t rv = source_->readDouble(dub); + sink_->writeDouble(dub); + return rv; + } + + virtual uint32_t readString(std::string& str) { + uint32_t rv = source_->readString(str); + sink_->writeString(str); + return rv; + } + + virtual uint32_t readBinary(std::string& str) { + uint32_t rv = source_->readBinary(str); + sink_->writeBinary(str); + return rv; + } + + private: + boost::shared_ptr<TProtocol> source_; + boost::shared_ptr<TProtocol> sink_; +}; + +}}} // apache::thrift::protocol + +#endif // #define _THRIFT_PROTOCOL_TPROTOCOLTAP_H_ 1 diff --git 
a/lib/cpp/src/server/TNonblockingServer.cpp b/lib/cpp/src/server/TNonblockingServer.cpp new file mode 100644 index 000000000..45f635cbe --- /dev/null +++ b/lib/cpp/src/server/TNonblockingServer.cpp @@ -0,0 +1,750 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "TNonblockingServer.h" +#include <concurrency/Exception.h> + +#include <iostream> +#include <sys/socket.h> +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <netdb.h> +#include <fcntl.h> +#include <errno.h> +#include <assert.h> + +namespace apache { namespace thrift { namespace server { + +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; +using namespace apache::thrift::concurrency; +using namespace std; + +class TConnection::Task: public Runnable { + public: + Task(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TProtocol> input, + boost::shared_ptr<TProtocol> output, + int taskHandle) : + processor_(processor), + input_(input), + output_(output), + taskHandle_(taskHandle) {} + + void run() { + try { + while (processor_->process(input_, output_)) { + if (!input_->getTransport()->peek()) { + break; + } + } + } catch (TTransportException& ttx) { + cerr << "TNonblockingServer client died: " << ttx.what() << endl; + } catch (TException& x) { + cerr << "TNonblockingServer exception: " << x.what() << endl; + } catch (...) { + cerr << "TNonblockingServer uncaught exception." 
<< endl; + } + + // Signal completion back to the libevent thread via a socketpair + int8_t b = 0; + if (-1 == send(taskHandle_, &b, sizeof(int8_t), 0)) { + GlobalOutput.perror("TNonblockingServer::Task: send ", errno); + } + if (-1 == ::close(taskHandle_)) { + GlobalOutput.perror("TNonblockingServer::Task: close, possible resource leak ", errno); + } + } + + private: + boost::shared_ptr<TProcessor> processor_; + boost::shared_ptr<TProtocol> input_; + boost::shared_ptr<TProtocol> output_; + int taskHandle_; +}; + +void TConnection::init(int socket, short eventFlags, TNonblockingServer* s) { + socket_ = socket; + server_ = s; + appState_ = APP_INIT; + eventFlags_ = 0; + + readBufferPos_ = 0; + readWant_ = 0; + + writeBuffer_ = NULL; + writeBufferSize_ = 0; + writeBufferPos_ = 0; + + socketState_ = SOCKET_RECV; + appState_ = APP_INIT; + + taskHandle_ = -1; + + // Set flags, which also registers the event + setFlags(eventFlags); + + // get input/transports + factoryInputTransport_ = s->getInputTransportFactory()->getTransport(inputTransport_); + factoryOutputTransport_ = s->getOutputTransportFactory()->getTransport(outputTransport_); + + // Create protocol + inputProtocol_ = s->getInputProtocolFactory()->getProtocol(factoryInputTransport_); + outputProtocol_ = s->getOutputProtocolFactory()->getProtocol(factoryOutputTransport_); +} + +void TConnection::workSocket() { + int flags=0, got=0, left=0, sent=0; + uint32_t fetch = 0; + + switch (socketState_) { + case SOCKET_RECV: + // It is an error to be in this state if we already have all the data + assert(readBufferPos_ < readWant_); + + // Double the buffer size until it is big enough + if (readWant_ > readBufferSize_) { + while (readWant_ > readBufferSize_) { + readBufferSize_ *= 2; + } + readBuffer_ = (uint8_t*)std::realloc(readBuffer_, readBufferSize_); + if (readBuffer_ == NULL) { + GlobalOutput("TConnection::workSocket() realloc"); + close(); + return; + } + } + + // Read from the socket + fetch = readWant_ - 
readBufferPos_; + got = recv(socket_, readBuffer_ + readBufferPos_, fetch, 0); + + if (got > 0) { + // Move along in the buffer + readBufferPos_ += got; + + // Check that we did not overdo it + assert(readBufferPos_ <= readWant_); + + // We are done reading, move onto the next state + if (readBufferPos_ == readWant_) { + transition(); + } + return; + } else if (got == -1) { + // Blocking errors are okay, just move on + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return; + } + + if (errno != ECONNRESET) { + GlobalOutput.perror("TConnection::workSocket() recv -1 ", errno); + } + } + + // Whenever we get down here it means a remote disconnect + close(); + + return; + + case SOCKET_SEND: + // Should never have position past size + assert(writeBufferPos_ <= writeBufferSize_); + + // If there is no data to send, then let us move on + if (writeBufferPos_ == writeBufferSize_) { + GlobalOutput("WARNING: Send state with no data to send\n"); + transition(); + return; + } + + flags = 0; + #ifdef MSG_NOSIGNAL + // Note the use of MSG_NOSIGNAL to suppress SIGPIPE errors, instead we + // check for the EPIPE return condition and close the socket in that case + flags |= MSG_NOSIGNAL; + #endif // ifdef MSG_NOSIGNAL + + left = writeBufferSize_ - writeBufferPos_; + sent = send(socket_, writeBuffer_ + writeBufferPos_, left, flags); + + if (sent <= 0) { + // Blocking errors are okay, just move on + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return; + } + if (errno != EPIPE) { + GlobalOutput.perror("TConnection::workSocket() send -1 ", errno); + } + close(); + return; + } + + writeBufferPos_ += sent; + + // Did we overdo it? + assert(writeBufferPos_ <= writeBufferSize_); + + // We are done! + if (writeBufferPos_ == writeBufferSize_) { + transition(); + } + + return; + + default: + GlobalOutput.printf("Shit Got Ill. Socket State %d", socketState_); + assert(0); + } +} + +/** + * This is called when the application transitions from one state into + * another. 
This means that it has finished writing the data that it needed + * to, or finished receiving the data that it needed to. + */ +void TConnection::transition() { + + int sz = 0; + + // Switch upon the state that we are currently in and move to a new state + switch (appState_) { + + case APP_READ_REQUEST: + // We are done reading the request, package the read buffer into transport + // and get back some data from the dispatch function + // If we've used these transport buffers enough times, reset them to avoid bloating + + inputTransport_->resetBuffer(readBuffer_, readBufferPos_); + ++numReadsSinceReset_; + if (numWritesSinceReset_ < 512) { + outputTransport_->resetBuffer(); + } else { + // reset the capacity of the output transport if we used it enough times that it might be bloated + try { + outputTransport_->resetBuffer(true); + numWritesSinceReset_ = 0; + } catch (TTransportException &ttx) { + GlobalOutput.printf("TTransportException: TMemoryBuffer::resetBuffer() %s", ttx.what()); + close(); + return; + } + } + + // Prepend four bytes of blank space to the buffer so we can + // write the frame size there later. 
+ outputTransport_->getWritePtr(4); + outputTransport_->wroteBytes(4); + + if (server_->isThreadPoolProcessing()) { + // We are setting up a Task to do this work and we will wait on it + int sv[2]; + if (-1 == socketpair(AF_LOCAL, SOCK_STREAM, 0, sv)) { + GlobalOutput.perror("TConnection::socketpair() failed ", errno); + // Now we will fall through to the APP_WAIT_TASK block with no response + } else { + // Create task and dispatch to the thread manager + boost::shared_ptr<Runnable> task = + boost::shared_ptr<Runnable>(new Task(server_->getProcessor(), + inputProtocol_, + outputProtocol_, + sv[1])); + // The application is now waiting on the task to finish + appState_ = APP_WAIT_TASK; + + // Create an event to be notified when the task finishes + event_set(&taskEvent_, + taskHandle_ = sv[0], + EV_READ, + TConnection::taskHandler, + this); + + // Attach to the base + event_base_set(server_->getEventBase(), &taskEvent_); + + // Add the event and start up the server + if (-1 == event_add(&taskEvent_, 0)) { + GlobalOutput("TNonblockingServer::serve(): coult not event_add"); + return; + } + try { + server_->addTask(task); + } catch (IllegalStateException & ise) { + // The ThreadManager is not ready to handle any more tasks (it's probably shutting down). + GlobalOutput.printf("IllegalStateException: Server::process() %s", ise.what()); + close(); + } + + // Set this connection idle so that libevent doesn't process more + // data on it while we're still waiting for the threadmanager to + // finish this task + setIdle(); + return; + } + } else { + try { + // Invoke the processor + server_->getProcessor()->process(inputProtocol_, outputProtocol_); + } catch (TTransportException &ttx) { + GlobalOutput.printf("TTransportException: Server::process() %s", ttx.what()); + close(); + return; + } catch (TException &x) { + GlobalOutput.printf("TException: Server::process() %s", x.what()); + close(); + return; + } catch (...) 
{ + GlobalOutput.printf("Server::process() unknown exception"); + close(); + return; + } + } + + // Intentionally fall through here, the call to process has written into + // the writeBuffer_ + + case APP_WAIT_TASK: + // We have now finished processing a task and the result has been written + // into the outputTransport_, so we grab its contents and place them into + // the writeBuffer_ for actual writing by the libevent thread + + // Get the result of the operation + outputTransport_->getBuffer(&writeBuffer_, &writeBufferSize_); + + // If the function call generated return data, then move into the send + // state and get going + // 4 bytes were reserved for frame size + if (writeBufferSize_ > 4) { + + // Move into write state + writeBufferPos_ = 0; + socketState_ = SOCKET_SEND; + + // Put the frame size into the write buffer + int32_t frameSize = (int32_t)htonl(writeBufferSize_ - 4); + memcpy(writeBuffer_, &frameSize, 4); + + // Socket into write mode + appState_ = APP_SEND_RESULT; + setWrite(); + + // Try to work the socket immediately + // workSocket(); + + return; + } + + // In this case, the request was oneway and we should fall through + // right back into the read frame header state + goto LABEL_APP_INIT; + + case APP_SEND_RESULT: + + ++numWritesSinceReset_; + + // N.B.: We also intentionally fall through here into the INIT state! 
+ + LABEL_APP_INIT: + case APP_INIT: + + // reset the input buffer if we used it enough times that it might be bloated + if (numReadsSinceReset_ > 512) + { + void * new_buffer = std::realloc(readBuffer_, 1024); + if (new_buffer == NULL) { + GlobalOutput("TConnection::transition() realloc"); + close(); + return; + } + readBuffer_ = (uint8_t*) new_buffer; + readBufferSize_ = 1024; + numReadsSinceReset_ = 0; + } + + // Clear write buffer variables + writeBuffer_ = NULL; + writeBufferPos_ = 0; + writeBufferSize_ = 0; + + // Set up read buffer for getting 4 bytes + readBufferPos_ = 0; + readWant_ = 4; + + // Into read4 state we go + socketState_ = SOCKET_RECV; + appState_ = APP_READ_FRAME_SIZE; + + // Register read event + setRead(); + + // Try to work the socket right away + // workSocket(); + + return; + + case APP_READ_FRAME_SIZE: + // We just read the request length, deserialize it + sz = *(int32_t*)readBuffer_; + sz = (int32_t)ntohl(sz); + + if (sz <= 0) { + GlobalOutput.printf("TConnection:transition() Negative frame size %d, remote side not using TFramedTransport?", sz); + close(); + return; + } + + // Reset the read buffer + readWant_ = (uint32_t)sz; + readBufferPos_= 0; + + // Move into read request state + appState_ = APP_READ_REQUEST; + + // Work the socket right away + // workSocket(); + + return; + + default: + GlobalOutput.printf("Totally Fucked. 
Application State %d", appState_); + assert(0); + } +} + +void TConnection::setFlags(short eventFlags) { + // Catch the do nothing case + if (eventFlags_ == eventFlags) { + return; + } + + // Delete a previously existing event + if (eventFlags_ != 0) { + if (event_del(&event_) == -1) { + GlobalOutput("TConnection::setFlags event_del"); + return; + } + } + + // Update in memory structure + eventFlags_ = eventFlags; + + // Do not call event_set if there are no flags + if (!eventFlags_) { + return; + } + + /** + * event_set: + * + * Prepares the event structure &event to be used in future calls to + * event_add() and event_del(). The event will be prepared to call the + * eventHandler using the 'sock' file descriptor to monitor events. + * + * The events can be either EV_READ, EV_WRITE, or both, indicating + * that an application can read or write from the file respectively without + * blocking. + * + * The eventHandler will be called with the file descriptor that triggered + * the event and the type of event which will be one of: EV_TIMEOUT, + * EV_SIGNAL, EV_READ, EV_WRITE. + * + * The additional flag EV_PERSIST makes an event_add() persistent until + * event_del() has been called. + * + * Once initialized, the &event struct can be used repeatedly with + * event_add() and event_del() and does not need to be reinitialized unless + * the eventHandler and/or the argument to it are to be changed. However, + * when an ev structure has been added to libevent using event_add() the + * structure must persist until the event occurs (assuming EV_PERSIST + * is not set) or is removed using event_del(). You may not reuse the same + * ev structure for multiple monitored descriptors; each descriptor needs + * its own ev. 
+ */ + event_set(&event_, socket_, eventFlags_, TConnection::eventHandler, this); + event_base_set(server_->getEventBase(), &event_); + + // Add the event + if (event_add(&event_, 0) == -1) { + GlobalOutput("TConnection::setFlags(): could not event_add"); + } +} + +/** + * Closes a connection + */ +void TConnection::close() { + // Delete the registered libevent + if (event_del(&event_) == -1) { + GlobalOutput("TConnection::close() event_del"); + } + + // Close the socket + if (socket_ > 0) { + ::close(socket_); + } + socket_ = 0; + + // close any factory produced transports + factoryInputTransport_->close(); + factoryOutputTransport_->close(); + + // Give this object back to the server that owns it + server_->returnConnection(this); +} + +void TConnection::checkIdleBufferMemLimit(uint32_t limit) { + if (readBufferSize_ > limit) { + readBufferSize_ = limit; + readBuffer_ = (uint8_t*)std::realloc(readBuffer_, readBufferSize_); + if (readBuffer_ == NULL) { + GlobalOutput("TConnection::checkIdleBufferMemLimit() realloc"); + close(); + } + } +} + +/** + * Creates a new connection either by reusing an object off the stack or + * by allocating a new one entirely + */ +TConnection* TNonblockingServer::createConnection(int socket, short flags) { + // Check the stack + if (connectionStack_.empty()) { + return new TConnection(socket, flags, this); + } else { + TConnection* result = connectionStack_.top(); + connectionStack_.pop(); + result->init(socket, flags, this); + return result; + } +} + +/** + * Returns a connection to the stack + */ +void TNonblockingServer::returnConnection(TConnection* connection) { + if (connectionStackLimit_ && + (connectionStack_.size() >= connectionStackLimit_)) { + delete connection; + } else { + connection->checkIdleBufferMemLimit(idleBufferMemLimit_); + connectionStack_.push(connection); + } +} + +/** + * Server socket had something happen. 
We accept all waiting client + * connections on fd and assign TConnection objects to handle those requests. + */ +void TNonblockingServer::handleEvent(int fd, short which) { + // Make sure that libevent didn't fuck up the socket handles + assert(fd == serverSocket_); + + // Server socket accepted a new connection + socklen_t addrLen; + struct sockaddr addr; + addrLen = sizeof(addr); + + // Going to accept a new client socket + int clientSocket; + + // Accept as many new clients as possible, even though libevent signaled only + // one, this helps us to avoid having to go back into the libevent engine so + // many times + while ((clientSocket = accept(fd, &addr, &addrLen)) != -1) { + + // Explicitly set this socket to NONBLOCK mode + int flags; + if ((flags = fcntl(clientSocket, F_GETFL, 0)) < 0 || + fcntl(clientSocket, F_SETFL, flags | O_NONBLOCK) < 0) { + GlobalOutput.perror("thriftServerEventHandler: set O_NONBLOCK (fcntl) ", errno); + close(clientSocket); + return; + } + + // Create a new TConnection for this client socket. + TConnection* clientConnection = + createConnection(clientSocket, EV_READ | EV_PERSIST); + + // Fail fast if we could not create a TConnection object + if (clientConnection == NULL) { + GlobalOutput.printf("thriftServerEventHandler: failed TConnection factory"); + close(clientSocket); + return; + } + + // Put this client connection into the proper state + clientConnection->transition(); + } + + // Done looping accept, now we have to make sure the error is due to + // blocking. Any other error is a problem + if (errno != EAGAIN && errno != EWOULDBLOCK) { + GlobalOutput.perror("thriftServerEventHandler: accept() ", errno); + } +} + +/** + * Creates a socket to listen on and binds it to the local port. 
+ */ +void TNonblockingServer::listenSocket() { + int s; + struct addrinfo hints, *res, *res0; + int error; + + char port[sizeof("65536") + 1]; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = PF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; + sprintf(port, "%d", port_); + + // Wildcard address + error = getaddrinfo(NULL, port, &hints, &res0); + if (error) { + string errStr = "TNonblockingServer::serve() getaddrinfo " + string(gai_strerror(error)); + GlobalOutput(errStr.c_str()); + return; + } + + // Pick the ipv6 address first since ipv4 addresses can be mapped + // into ipv6 space. + for (res = res0; res; res = res->ai_next) { + if (res->ai_family == AF_INET6 || res->ai_next == NULL) + break; + } + + // Create the server socket + s = socket(res->ai_family, res->ai_socktype, res->ai_protocol); + if (s == -1) { + freeaddrinfo(res0); + throw TException("TNonblockingServer::serve() socket() -1"); + } + + #ifdef IPV6_V6ONLY + int zero = 0; + if (-1 == setsockopt(s, IPPROTO_IPV6, IPV6_V6ONLY, &zero, sizeof(zero))) { + GlobalOutput("TServerSocket::listen() IPV6_V6ONLY"); + } + #endif // #ifdef IPV6_V6ONLY + + + int one = 1; + + // Set reuseaddr to avoid 2MSL delay on server restart + setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)); + + if (bind(s, res->ai_addr, res->ai_addrlen) == -1) { + close(s); + freeaddrinfo(res0); + throw TException("TNonblockingServer::serve() bind"); + } + + // Done with the addr info + freeaddrinfo(res0); + + // Set up this file descriptor for listening + listenSocket(s); +} + +/** + * Takes a socket created by listenSocket() and sets various options on it + * to prepare for use in the server. 
+ */ +void TNonblockingServer::listenSocket(int s) { + // Set socket to nonblocking mode + int flags; + if ((flags = fcntl(s, F_GETFL, 0)) < 0 || + fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0) { + close(s); + throw TException("TNonblockingServer::serve() O_NONBLOCK"); + } + + int one = 1; + struct linger ling = {0, 0}; + + // Keepalive to ensure full result flushing + setsockopt(s, SOL_SOCKET, SO_KEEPALIVE, &one, sizeof(one)); + + // Turn linger off to avoid hung sockets + setsockopt(s, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling)); + + // Set TCP nodelay if available, MAC OS X Hack + // See http://lists.danga.com/pipermail/memcached/2005-March/001240.html + #ifndef TCP_NOPUSH + setsockopt(s, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)); + #endif + + if (listen(s, LISTEN_BACKLOG) == -1) { + close(s); + throw TException("TNonblockingServer::serve() listen"); + } + + // Cool, this socket is good to go, set it as the serverSocket_ + serverSocket_ = s; +} + +/** + * Register the core libevent events onto the proper base. + */ +void TNonblockingServer::registerEvents(event_base* base) { + assert(serverSocket_ != -1); + assert(!eventBase_); + eventBase_ = base; + + // Print some libevent stats + GlobalOutput.printf("libevent %s method %s", + event_get_version(), + event_get_method()); + + // Register the server event + event_set(&serverEvent_, + serverSocket_, + EV_READ | EV_PERSIST, + TNonblockingServer::eventHandler, + this); + event_base_set(eventBase_, &serverEvent_); + + // Add the event and start up the server + if (-1 == event_add(&serverEvent_, 0)) { + throw TException("TNonblockingServer::serve(): coult not event_add"); + } +} + +/** + * Main workhorse function, starts up the server listening on a port and + * loops over the libevent handler. 
+ */ +void TNonblockingServer::serve() { + // Init socket + listenSocket(); + + // Initialize libevent core + registerEvents(static_cast<event_base*>(event_init())); + + // Run the preServe event + if (eventHandler_ != NULL) { + eventHandler_->preServe(); + } + + // Run libevent engine, never returns, invokes calls to eventHandler + event_base_loop(eventBase_, 0); +} + +}}} // apache::thrift::server diff --git a/lib/cpp/src/server/TNonblockingServer.h b/lib/cpp/src/server/TNonblockingServer.h new file mode 100644 index 000000000..1684b64a0 --- /dev/null +++ b/lib/cpp/src/server/TNonblockingServer.h @@ -0,0 +1,434 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef _THRIFT_SERVER_TNONBLOCKINGSERVER_H_ +#define _THRIFT_SERVER_TNONBLOCKINGSERVER_H_ 1 + +#include <Thrift.h> +#include <server/TServer.h> +#include <transport/TBufferTransports.h> +#include <concurrency/ThreadManager.h> +#include <stack> +#include <string> +#include <errno.h> +#include <cstdlib> +#include <event.h> + +namespace apache { namespace thrift { namespace server { + +using apache::thrift::transport::TMemoryBuffer; +using apache::thrift::protocol::TProtocol; +using apache::thrift::concurrency::Runnable; +using apache::thrift::concurrency::ThreadManager; + +// Forward declaration of class +class TConnection; + +/** + * This is a non-blocking server in C++ for high performance that operates a + * single IO thread. It assumes that all incoming requests are framed with a + * 4 byte length indicator and writes out responses using the same framing. + * + * It does not use the TServerTransport framework, but rather has socket + * operations hardcoded for use with select. + * + */ +class TNonblockingServer : public TServer { + private: + + // Listen backlog + static const int LISTEN_BACKLOG = 1024; + + // Default limit on size of idle connection pool + static const size_t CONNECTION_STACK_LIMIT = 1024; + + // Maximum size of buffer allocated to idle connection + static const uint32_t IDLE_BUFFER_MEM_LIMIT = 8192; + + // Server socket file descriptor + int serverSocket_; + + // Port server runs on + int port_; + + // For processing via thread pool, may be NULL + boost::shared_ptr<ThreadManager> threadManager_; + + // Is thread pool processing? + bool threadPoolProcessing_; + + // The event base for libevent + event_base* eventBase_; + + // Event struct, for use with eventBase_ + struct event serverEvent_; + + // Number of TConnection object we've created + size_t numTConnections_; + + // Limit for how many TConnection objects to cache + size_t connectionStackLimit_; + + /** + * Max read buffer size for an idle connection. 
When we place an idle + * TConnection into connectionStack_, we insure that its read buffer is + * reduced to this size to insure that idle connections don't hog memory. + */ + uint32_t idleBufferMemLimit_; + + /** + * This is a stack of all the objects that have been created but that + * are NOT currently in use. When we close a connection, we place it on this + * stack so that the object can be reused later, rather than freeing the + * memory and reallocating a new object later. + */ + std::stack<TConnection*> connectionStack_; + + void handleEvent(int fd, short which); + + public: + TNonblockingServer(boost::shared_ptr<TProcessor> processor, + int port) : + TServer(processor), + serverSocket_(-1), + port_(port), + threadPoolProcessing_(false), + eventBase_(NULL), + numTConnections_(0), + connectionStackLimit_(CONNECTION_STACK_LIMIT), + idleBufferMemLimit_(IDLE_BUFFER_MEM_LIMIT) {} + + TNonblockingServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TProtocolFactory> protocolFactory, + int port, + boost::shared_ptr<ThreadManager> threadManager = boost::shared_ptr<ThreadManager>()) : + TServer(processor), + serverSocket_(-1), + port_(port), + threadManager_(threadManager), + eventBase_(NULL), + numTConnections_(0), + connectionStackLimit_(CONNECTION_STACK_LIMIT), + idleBufferMemLimit_(IDLE_BUFFER_MEM_LIMIT) { + setInputTransportFactory(boost::shared_ptr<TTransportFactory>(new TTransportFactory())); + setOutputTransportFactory(boost::shared_ptr<TTransportFactory>(new TTransportFactory())); + setInputProtocolFactory(protocolFactory); + setOutputProtocolFactory(protocolFactory); + setThreadManager(threadManager); + } + + TNonblockingServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TTransportFactory> inputTransportFactory, + boost::shared_ptr<TTransportFactory> outputTransportFactory, + boost::shared_ptr<TProtocolFactory> inputProtocolFactory, + boost::shared_ptr<TProtocolFactory> outputProtocolFactory, + int port, + 
boost::shared_ptr<ThreadManager> threadManager = boost::shared_ptr<ThreadManager>()) : + TServer(processor), + serverSocket_(0), + port_(port), + threadManager_(threadManager), + eventBase_(NULL), + numTConnections_(0), + connectionStackLimit_(CONNECTION_STACK_LIMIT), + idleBufferMemLimit_(IDLE_BUFFER_MEM_LIMIT) { + setInputTransportFactory(inputTransportFactory); + setOutputTransportFactory(outputTransportFactory); + setInputProtocolFactory(inputProtocolFactory); + setOutputProtocolFactory(outputProtocolFactory); + setThreadManager(threadManager); + } + + ~TNonblockingServer() {} + + void setThreadManager(boost::shared_ptr<ThreadManager> threadManager) { + threadManager_ = threadManager; + threadPoolProcessing_ = (threadManager != NULL); + } + + boost::shared_ptr<ThreadManager> getThreadManager() { + return threadManager_; + } + + /** + * Get the maximum number of unused TConnection we will hold in reserve. + * + * @return the current limit on TConnection pool size. + */ + size_t getConnectionStackLimit() const { + return connectionStackLimit_; + } + + /** + * Set the maximum number of unused TConnection we will hold in reserve. + * + * @param sz the new limit for TConnection pool size. + */ + void setConnectionStackLimit(size_t sz) { + connectionStackLimit_ = sz; + } + + bool isThreadPoolProcessing() const { + return threadPoolProcessing_; + } + + void addTask(boost::shared_ptr<Runnable> task) { + threadManager_->add(task); + } + + event_base* getEventBase() const { + return eventBase_; + } + + void incrementNumConnections() { + ++numTConnections_; + } + + void decrementNumConnections() { + --numTConnections_; + } + + size_t getNumConnections() { + return numTConnections_; + } + + size_t getNumIdleConnections() { + return connectionStack_.size(); + } + + /** + * Get the maximum limit of memory allocated to idle TConnection objects. + * + * @return # bytes beyond which we will shrink buffers when idle. 
+ */ + size_t getIdleBufferMemLimit() const { + return idleBufferMemLimit_; + } + + /** + * Set the maximum limit of memory allocated to idle TConnection objects. + * If a TConnection object goes idle with more than this much memory + * allocated to its buffer, we shrink it to this value. + * + * @param limit of bytes beyond which we will shrink buffers when idle. + */ + void setIdleBufferMemLimit(size_t limit) { + idleBufferMemLimit_ = limit; + } + + TConnection* createConnection(int socket, short flags); + + void returnConnection(TConnection* connection); + + static void eventHandler(int fd, short which, void* v) { + ((TNonblockingServer*)v)->handleEvent(fd, which); + } + + void listenSocket(); + + void listenSocket(int fd); + + void registerEvents(event_base* base); + + void serve(); +}; + +/** + * Two states for sockets, recv and send mode + */ +enum TSocketState { + SOCKET_RECV, + SOCKET_SEND +}; + +/** + * Four states for the nonblocking servr: + * 1) initialize + * 2) read 4 byte frame size + * 3) read frame of data + * 4) send back data (if any) + */ +enum TAppState { + APP_INIT, + APP_READ_FRAME_SIZE, + APP_READ_REQUEST, + APP_WAIT_TASK, + APP_SEND_RESULT +}; + +/** + * Represents a connection that is handled via libevent. This connection + * essentially encapsulates a socket that has some associated libevent state. 
+ */ +class TConnection { + private: + + class Task; + + // Server handle + TNonblockingServer* server_; + + // Socket handle + int socket_; + + // Libevent object + struct event event_; + + // Libevent flags + short eventFlags_; + + // Socket mode + TSocketState socketState_; + + // Application state + TAppState appState_; + + // How much data needed to read + uint32_t readWant_; + + // Where in the read buffer are we + uint32_t readBufferPos_; + + // Read buffer + uint8_t* readBuffer_; + + // Read buffer size + uint32_t readBufferSize_; + + // Write buffer + uint8_t* writeBuffer_; + + // Write buffer size + uint32_t writeBufferSize_; + + // How far through writing are we? + uint32_t writeBufferPos_; + + // How many times have we read since our last buffer reset? + uint32_t numReadsSinceReset_; + + // How many times have we written since our last buffer reset? + uint32_t numWritesSinceReset_; + + // Task handle + int taskHandle_; + + // Task event + struct event taskEvent_; + + // Transport to read from + boost::shared_ptr<TMemoryBuffer> inputTransport_; + + // Transport that processor writes to + boost::shared_ptr<TMemoryBuffer> outputTransport_; + + // extra transport generated by transport factory (e.g. 
BufferedRouterTransport) + boost::shared_ptr<TTransport> factoryInputTransport_; + boost::shared_ptr<TTransport> factoryOutputTransport_; + + // Protocol decoder + boost::shared_ptr<TProtocol> inputProtocol_; + + // Protocol encoder + boost::shared_ptr<TProtocol> outputProtocol_; + + // Go into read mode + void setRead() { + setFlags(EV_READ | EV_PERSIST); + } + + // Go into write mode + void setWrite() { + setFlags(EV_WRITE | EV_PERSIST); + } + + // Set socket idle + void setIdle() { + setFlags(0); + } + + // Set event flags + void setFlags(short eventFlags); + + // Libevent handlers + void workSocket(); + + // Close this client and reset + void close(); + + public: + + // Constructor + TConnection(int socket, short eventFlags, TNonblockingServer *s) { + readBuffer_ = (uint8_t*)std::malloc(1024); + if (readBuffer_ == NULL) { + throw new apache::thrift::TException("Out of memory."); + } + readBufferSize_ = 1024; + + numReadsSinceReset_ = 0; + numWritesSinceReset_ = 0; + + // Allocate input and output tranpsorts + // these only need to be allocated once per TConnection (they don't need to be + // reallocated on init() call) + inputTransport_ = boost::shared_ptr<TMemoryBuffer>(new TMemoryBuffer(readBuffer_, readBufferSize_)); + outputTransport_ = boost::shared_ptr<TMemoryBuffer>(new TMemoryBuffer()); + + init(socket, eventFlags, s); + server_->incrementNumConnections(); + } + + ~TConnection() { + server_->decrementNumConnections(); + } + + /** + * Check read buffer against a given limit and shrink it if exceeded. + * + * @param limit we limit buffer size to. 
+ */ + void checkIdleBufferMemLimit(uint32_t limit); + + // Initialize + void init(int socket, short eventFlags, TNonblockingServer *s); + + // Transition into a new state + void transition(); + + // Handler wrapper + static void eventHandler(int fd, short /* which */, void* v) { + assert(fd == ((TConnection*)v)->socket_); + ((TConnection*)v)->workSocket(); + } + + // Handler wrapper for task block + static void taskHandler(int fd, short /* which */, void* v) { + assert(fd == ((TConnection*)v)->taskHandle_); + if (-1 == ::close(((TConnection*)v)->taskHandle_)) { + GlobalOutput.perror("TConnection::taskHandler close handle failed, resource leak ", errno); + } + ((TConnection*)v)->transition(); + } + +}; + +}}} // apache::thrift::server + +#endif // #ifndef _THRIFT_SERVER_TSIMPLESERVER_H_ diff --git a/lib/cpp/src/server/TServer.cpp b/lib/cpp/src/server/TServer.cpp new file mode 100644 index 000000000..6b692ab02 --- /dev/null +++ b/lib/cpp/src/server/TServer.cpp @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include <sys/time.h> +#include <sys/resource.h> +#include <unistd.h> + +namespace apache { namespace thrift { namespace server { + +int increase_max_fds(int max_fds=(1<<24)) { + struct rlimit fdmaxrl; + + for(fdmaxrl.rlim_cur = max_fds, fdmaxrl.rlim_max = max_fds; + max_fds && (setrlimit(RLIMIT_NOFILE, &fdmaxrl) < 0); + fdmaxrl.rlim_cur = max_fds, fdmaxrl.rlim_max = max_fds) { + max_fds /= 2; + } + + return fdmaxrl.rlim_cur; +} + +}}} // apache::thrift::server diff --git a/lib/cpp/src/server/TServer.h b/lib/cpp/src/server/TServer.h new file mode 100644 index 000000000..5c4c588d4 --- /dev/null +++ b/lib/cpp/src/server/TServer.h @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef _THRIFT_SERVER_TSERVER_H_ +#define _THRIFT_SERVER_TSERVER_H_ 1 + +#include <TProcessor.h> +#include <transport/TServerTransport.h> +#include <protocol/TBinaryProtocol.h> +#include <concurrency/Thread.h> + +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace server { + +using apache::thrift::TProcessor; +using apache::thrift::protocol::TBinaryProtocolFactory; +using apache::thrift::protocol::TProtocol; +using apache::thrift::protocol::TProtocolFactory; +using apache::thrift::transport::TServerTransport; +using apache::thrift::transport::TTransport; +using apache::thrift::transport::TTransportFactory; + +/** + * Virtual interface class that can handle events from the server core. To + * use this you should subclass it and implement the methods that you care + * about. Your subclass can also store local data that you may care about, + * such as additional "arguments" to these methods (stored in the object + * instance's state). + */ +class TServerEventHandler { + public: + + virtual ~TServerEventHandler() {} + + /** + * Called before the server begins. + */ + virtual void preServe() {} + + /** + * Called when a new client has connected and is about to being processing. + */ + virtual void clientBegin(boost::shared_ptr<TProtocol> /* input */, + boost::shared_ptr<TProtocol> /* output */) {} + + /** + * Called when a client has finished making requests. + */ + virtual void clientEnd(boost::shared_ptr<TProtocol> /* input */, + boost::shared_ptr<TProtocol> /* output */) {} + + protected: + + /** + * Prevent direct instantiation. + */ + TServerEventHandler() {} + +}; + +/** + * Thrift server. 
+ * + */ +class TServer : public concurrency::Runnable { + public: + + virtual ~TServer() {} + + virtual void serve() = 0; + + virtual void stop() {} + + // Allows running the server as a Runnable thread + virtual void run() { + serve(); + } + + boost::shared_ptr<TProcessor> getProcessor() { + return processor_; + } + + boost::shared_ptr<TServerTransport> getServerTransport() { + return serverTransport_; + } + + boost::shared_ptr<TTransportFactory> getInputTransportFactory() { + return inputTransportFactory_; + } + + boost::shared_ptr<TTransportFactory> getOutputTransportFactory() { + return outputTransportFactory_; + } + + boost::shared_ptr<TProtocolFactory> getInputProtocolFactory() { + return inputProtocolFactory_; + } + + boost::shared_ptr<TProtocolFactory> getOutputProtocolFactory() { + return outputProtocolFactory_; + } + + boost::shared_ptr<TServerEventHandler> getEventHandler() { + return eventHandler_; + } + +protected: + TServer(boost::shared_ptr<TProcessor> processor): + processor_(processor) { + setInputTransportFactory(boost::shared_ptr<TTransportFactory>(new TTransportFactory())); + setOutputTransportFactory(boost::shared_ptr<TTransportFactory>(new TTransportFactory())); + setInputProtocolFactory(boost::shared_ptr<TProtocolFactory>(new TBinaryProtocolFactory())); + setOutputProtocolFactory(boost::shared_ptr<TProtocolFactory>(new TBinaryProtocolFactory())); + } + + TServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport): + processor_(processor), + serverTransport_(serverTransport) { + setInputTransportFactory(boost::shared_ptr<TTransportFactory>(new TTransportFactory())); + setOutputTransportFactory(boost::shared_ptr<TTransportFactory>(new TTransportFactory())); + setInputProtocolFactory(boost::shared_ptr<TProtocolFactory>(new TBinaryProtocolFactory())); + setOutputProtocolFactory(boost::shared_ptr<TProtocolFactory>(new TBinaryProtocolFactory())); + } + + TServer(boost::shared_ptr<TProcessor> processor, + 
boost::shared_ptr<TServerTransport> serverTransport, + boost::shared_ptr<TTransportFactory> transportFactory, + boost::shared_ptr<TProtocolFactory> protocolFactory): + processor_(processor), + serverTransport_(serverTransport), + inputTransportFactory_(transportFactory), + outputTransportFactory_(transportFactory), + inputProtocolFactory_(protocolFactory), + outputProtocolFactory_(protocolFactory) {} + + TServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport, + boost::shared_ptr<TTransportFactory> inputTransportFactory, + boost::shared_ptr<TTransportFactory> outputTransportFactory, + boost::shared_ptr<TProtocolFactory> inputProtocolFactory, + boost::shared_ptr<TProtocolFactory> outputProtocolFactory): + processor_(processor), + serverTransport_(serverTransport), + inputTransportFactory_(inputTransportFactory), + outputTransportFactory_(outputTransportFactory), + inputProtocolFactory_(inputProtocolFactory), + outputProtocolFactory_(outputProtocolFactory) {} + + + // Class variables + boost::shared_ptr<TProcessor> processor_; + boost::shared_ptr<TServerTransport> serverTransport_; + + boost::shared_ptr<TTransportFactory> inputTransportFactory_; + boost::shared_ptr<TTransportFactory> outputTransportFactory_; + + boost::shared_ptr<TProtocolFactory> inputProtocolFactory_; + boost::shared_ptr<TProtocolFactory> outputProtocolFactory_; + + boost::shared_ptr<TServerEventHandler> eventHandler_; + +public: + void setInputTransportFactory(boost::shared_ptr<TTransportFactory> inputTransportFactory) { + inputTransportFactory_ = inputTransportFactory; + } + + void setOutputTransportFactory(boost::shared_ptr<TTransportFactory> outputTransportFactory) { + outputTransportFactory_ = outputTransportFactory; + } + + void setInputProtocolFactory(boost::shared_ptr<TProtocolFactory> inputProtocolFactory) { + inputProtocolFactory_ = inputProtocolFactory; + } + + void setOutputProtocolFactory(boost::shared_ptr<TProtocolFactory> 
outputProtocolFactory) { + outputProtocolFactory_ = outputProtocolFactory; + } + + void setServerEventHandler(boost::shared_ptr<TServerEventHandler> eventHandler) { + eventHandler_ = eventHandler; + } + +}; + +/** + * Helper function to increase the max file descriptors limit + * for the current process and all of its children. + * By default, tries to increase it to as much as 2^24. + */ + int increase_max_fds(int max_fds=(1<<24)); + + +}}} // apache::thrift::server + +#endif // #ifndef _THRIFT_SERVER_TSERVER_H_ diff --git a/lib/cpp/src/server/TSimpleServer.cpp b/lib/cpp/src/server/TSimpleServer.cpp new file mode 100644 index 000000000..394ce21e2 --- /dev/null +++ b/lib/cpp/src/server/TSimpleServer.cpp @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "server/TSimpleServer.h" +#include "transport/TTransportException.h" +#include <string> +#include <iostream> + +namespace apache { namespace thrift { namespace server { + +using namespace std; +using namespace apache::thrift; +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; +using boost::shared_ptr; + +/** + * A simple single-threaded application server. Perfect for unit tests! 
+ * + */ +void TSimpleServer::serve() { + + shared_ptr<TTransport> client; + shared_ptr<TTransport> inputTransport; + shared_ptr<TTransport> outputTransport; + shared_ptr<TProtocol> inputProtocol; + shared_ptr<TProtocol> outputProtocol; + + try { + // Start the server listening + serverTransport_->listen(); + } catch (TTransportException& ttx) { + cerr << "TSimpleServer::run() listen(): " << ttx.what() << endl; + return; + } + + // Run the preServe event + if (eventHandler_ != NULL) { + eventHandler_->preServe(); + } + + // Fetch client from server + while (!stop_) { + try { + client = serverTransport_->accept(); + inputTransport = inputTransportFactory_->getTransport(client); + outputTransport = outputTransportFactory_->getTransport(client); + inputProtocol = inputProtocolFactory_->getProtocol(inputTransport); + outputProtocol = outputProtocolFactory_->getProtocol(outputTransport); + if (eventHandler_ != NULL) { + eventHandler_->clientBegin(inputProtocol, outputProtocol); + } + try { + while (processor_->process(inputProtocol, outputProtocol)) { + // Peek ahead, is the remote side closed? 
+ if (!inputTransport->peek()) { + break; + } + } + } catch (TTransportException& ttx) { + cerr << "TSimpleServer client died: " << ttx.what() << endl; + } catch (TException& tx) { + cerr << "TSimpleServer exception: " << tx.what() << endl; + } + if (eventHandler_ != NULL) { + eventHandler_->clientEnd(inputProtocol, outputProtocol); + } + inputTransport->close(); + outputTransport->close(); + client->close(); + } catch (TTransportException& ttx) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + cerr << "TServerTransport died on accept: " << ttx.what() << endl; + continue; + } catch (TException& tx) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + cerr << "Some kind of accept exception: " << tx.what() << endl; + continue; + } catch (string s) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + cerr << "TThreadPoolServer: Unknown exception: " << s << endl; + break; + } + } + + if (stop_) { + try { + serverTransport_->close(); + } catch (TTransportException &ttx) { + cerr << "TServerTransport failed on close: " << ttx.what() << endl; + } + stop_ = false; + } +} + +}}} // apache::thrift::server diff --git a/lib/cpp/src/server/TSimpleServer.h b/lib/cpp/src/server/TSimpleServer.h new file mode 100644 index 000000000..c4fc91c78 --- /dev/null +++ b/lib/cpp/src/server/TSimpleServer.h @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_SERVER_TSIMPLESERVER_H_ +#define _THRIFT_SERVER_TSIMPLESERVER_H_ 1 + +#include "server/TServer.h" +#include "transport/TServerTransport.h" + +namespace apache { namespace thrift { namespace server { + +/** + * This is the most basic simple server. It is single-threaded and runs a + * continuous loop of accepting a single connection, processing requests on + * that connection until it closes, and then repeating. It is a good example + * of how to extend the TServer interface. 
+ * + */ +class TSimpleServer : public TServer { + public: + TSimpleServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport, + boost::shared_ptr<TTransportFactory> transportFactory, + boost::shared_ptr<TProtocolFactory> protocolFactory) : + TServer(processor, serverTransport, transportFactory, protocolFactory), + stop_(false) {} + + TSimpleServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport, + boost::shared_ptr<TTransportFactory> inputTransportFactory, + boost::shared_ptr<TTransportFactory> outputTransportFactory, + boost::shared_ptr<TProtocolFactory> inputProtocolFactory, + boost::shared_ptr<TProtocolFactory> outputProtocolFactory): + TServer(processor, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory), + stop_(false) {} + + ~TSimpleServer() {} + + void serve(); + + void stop() { + stop_ = true; + } + + protected: + bool stop_; + +}; + +}}} // apache::thrift::server + +#endif // #ifndef _THRIFT_SERVER_TSIMPLESERVER_H_ diff --git a/lib/cpp/src/server/TThreadPoolServer.cpp b/lib/cpp/src/server/TThreadPoolServer.cpp new file mode 100644 index 000000000..0894cfa5f --- /dev/null +++ b/lib/cpp/src/server/TThreadPoolServer.cpp @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "server/TThreadPoolServer.h" +#include "transport/TTransportException.h" +#include "concurrency/Thread.h" +#include "concurrency/ThreadManager.h" +#include <string> +#include <iostream> + +namespace apache { namespace thrift { namespace server { + +using boost::shared_ptr; +using namespace std; +using namespace apache::thrift; +using namespace apache::thrift::concurrency; +using namespace apache::thrift::protocol;; +using namespace apache::thrift::transport; + +class TThreadPoolServer::Task : public Runnable { + +public: + + Task(TThreadPoolServer &server, + shared_ptr<TProcessor> processor, + shared_ptr<TProtocol> input, + shared_ptr<TProtocol> output) : + server_(server), + processor_(processor), + input_(input), + output_(output) { + } + + ~Task() {} + + void run() { + boost::shared_ptr<TServerEventHandler> eventHandler = + server_.getEventHandler(); + if (eventHandler != NULL) { + eventHandler->clientBegin(input_, output_); + } + try { + while (processor_->process(input_, output_)) { + if (!input_->getTransport()->peek()) { + break; + } + } + } catch (TTransportException& ttx) { + // This is reasonably expected, client didn't send a full request so just + // ignore him + // string errStr = string("TThreadPoolServer client died: ") + ttx.what(); + // GlobalOutput(errStr.c_str()); + } catch (TException& x) { + string errStr = string("TThreadPoolServer exception: ") + x.what(); + GlobalOutput(errStr.c_str()); + } catch (std::exception &x) { + string errStr = string("TThreadPoolServer, std::exception: ") + x.what(); + 
GlobalOutput(errStr.c_str()); + } + + if (eventHandler != NULL) { + eventHandler->clientEnd(input_, output_); + } + + try { + input_->getTransport()->close(); + } catch (TTransportException& ttx) { + string errStr = string("TThreadPoolServer input close failed: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + try { + output_->getTransport()->close(); + } catch (TTransportException& ttx) { + string errStr = string("TThreadPoolServer output close failed: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + + } + + private: + TServer& server_; + shared_ptr<TProcessor> processor_; + shared_ptr<TProtocol> input_; + shared_ptr<TProtocol> output_; + +}; + +TThreadPoolServer::TThreadPoolServer(shared_ptr<TProcessor> processor, + shared_ptr<TServerTransport> serverTransport, + shared_ptr<TTransportFactory> transportFactory, + shared_ptr<TProtocolFactory> protocolFactory, + shared_ptr<ThreadManager> threadManager) : + TServer(processor, serverTransport, transportFactory, protocolFactory), + threadManager_(threadManager), + stop_(false), timeout_(0) {} + +TThreadPoolServer::TThreadPoolServer(shared_ptr<TProcessor> processor, + shared_ptr<TServerTransport> serverTransport, + shared_ptr<TTransportFactory> inputTransportFactory, + shared_ptr<TTransportFactory> outputTransportFactory, + shared_ptr<TProtocolFactory> inputProtocolFactory, + shared_ptr<TProtocolFactory> outputProtocolFactory, + shared_ptr<ThreadManager> threadManager) : + TServer(processor, serverTransport, inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory), + threadManager_(threadManager), + stop_(false), timeout_(0) {} + + +TThreadPoolServer::~TThreadPoolServer() {} + +void TThreadPoolServer::serve() { + shared_ptr<TTransport> client; + shared_ptr<TTransport> inputTransport; + shared_ptr<TTransport> outputTransport; + shared_ptr<TProtocol> inputProtocol; + shared_ptr<TProtocol> outputProtocol; + + try { + // Start the server listening + serverTransport_->listen(); 
+ } catch (TTransportException& ttx) { + string errStr = string("TThreadPoolServer::run() listen(): ") + ttx.what(); + GlobalOutput(errStr.c_str()); + return; + } + + // Run the preServe event + if (eventHandler_ != NULL) { + eventHandler_->preServe(); + } + + while (!stop_) { + try { + client.reset(); + inputTransport.reset(); + outputTransport.reset(); + inputProtocol.reset(); + outputProtocol.reset(); + + // Fetch client from server + client = serverTransport_->accept(); + + // Make IO transports + inputTransport = inputTransportFactory_->getTransport(client); + outputTransport = outputTransportFactory_->getTransport(client); + inputProtocol = inputProtocolFactory_->getProtocol(inputTransport); + outputProtocol = outputProtocolFactory_->getProtocol(outputTransport); + + // Add to threadmanager pool + threadManager_->add(shared_ptr<TThreadPoolServer::Task>(new TThreadPoolServer::Task(*this, processor_, inputProtocol, outputProtocol)), timeout_); + + } catch (TTransportException& ttx) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + if (!stop_ || ttx.getType() != TTransportException::INTERRUPTED) { + string errStr = string("TThreadPoolServer: TServerTransport died on accept: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + continue; + } catch (TException& tx) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + string errStr = string("TThreadPoolServer: Caught TException: ") + tx.what(); + GlobalOutput(errStr.c_str()); + continue; + } catch (string s) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + string errStr = "TThreadPoolServer: Unknown exception: " + s; + GlobalOutput(errStr.c_str()); + break; + } + } + + // 
If stopped manually, join the existing threads + if (stop_) { + try { + serverTransport_->close(); + threadManager_->join(); + } catch (TException &tx) { + string errStr = string("TThreadPoolServer: Exception shutting down: ") + tx.what(); + GlobalOutput(errStr.c_str()); + } + stop_ = false; + } + +} + +int64_t TThreadPoolServer::getTimeout() const { + return timeout_; +} + +void TThreadPoolServer::setTimeout(int64_t value) { + timeout_ = value; +} + +}}} // apache::thrift::server diff --git a/lib/cpp/src/server/TThreadPoolServer.h b/lib/cpp/src/server/TThreadPoolServer.h new file mode 100644 index 000000000..7b7e90647 --- /dev/null +++ b/lib/cpp/src/server/TThreadPoolServer.h @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef _THRIFT_SERVER_TTHREADPOOLSERVER_H_ +#define _THRIFT_SERVER_TTHREADPOOLSERVER_H_ 1 + +#include <concurrency/ThreadManager.h> +#include <server/TServer.h> +#include <transport/TServerTransport.h> + +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace server { + +using apache::thrift::concurrency::ThreadManager; +using apache::thrift::protocol::TProtocolFactory; +using apache::thrift::transport::TServerTransport; +using apache::thrift::transport::TTransportFactory; + +class TThreadPoolServer : public TServer { + public: + class Task; + + TThreadPoolServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport, + boost::shared_ptr<TTransportFactory> transportFactory, + boost::shared_ptr<TProtocolFactory> protocolFactory, + boost::shared_ptr<ThreadManager> threadManager); + + TThreadPoolServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport, + boost::shared_ptr<TTransportFactory> inputTransportFactory, + boost::shared_ptr<TTransportFactory> outputTransportFactory, + boost::shared_ptr<TProtocolFactory> inputProtocolFactory, + boost::shared_ptr<TProtocolFactory> outputProtocolFactory, + boost::shared_ptr<ThreadManager> threadManager); + + virtual ~TThreadPoolServer(); + + virtual void serve(); + + virtual int64_t getTimeout() const; + + virtual void setTimeout(int64_t value); + + virtual void stop() { + stop_ = true; + serverTransport_->interrupt(); + } + + protected: + + boost::shared_ptr<ThreadManager> threadManager_; + + volatile bool stop_; + + volatile int64_t timeout_; + +}; + +}}} // apache::thrift::server + +#endif // #ifndef _THRIFT_SERVER_TTHREADPOOLSERVER_H_ diff --git a/lib/cpp/src/server/TThreadedServer.cpp b/lib/cpp/src/server/TThreadedServer.cpp new file mode 100644 index 000000000..cc30f8ff7 --- /dev/null +++ b/lib/cpp/src/server/TThreadedServer.cpp @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "server/TThreadedServer.h" +#include "transport/TTransportException.h" +#include "concurrency/PosixThreadFactory.h" + +#include <string> +#include <iostream> +#include <pthread.h> +#include <unistd.h> + +namespace apache { namespace thrift { namespace server { + +using boost::shared_ptr; +using namespace std; +using namespace apache::thrift; +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; +using namespace apache::thrift::concurrency; + +class TThreadedServer::Task: public Runnable { + +public: + + Task(TThreadedServer& server, + shared_ptr<TProcessor> processor, + shared_ptr<TProtocol> input, + shared_ptr<TProtocol> output) : + server_(server), + processor_(processor), + input_(input), + output_(output) { + } + + ~Task() {} + + void run() { + boost::shared_ptr<TServerEventHandler> eventHandler = + server_.getEventHandler(); + if (eventHandler != NULL) { + eventHandler->clientBegin(input_, output_); + } + try { + while (processor_->process(input_, output_)) { + if (!input_->getTransport()->peek()) { + break; + } + } + } catch (TTransportException& ttx) { + string errStr = string("TThreadedServer client died: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } catch 
(TException& x) { + string errStr = string("TThreadedServer exception: ") + x.what(); + GlobalOutput(errStr.c_str()); + } catch (...) { + GlobalOutput("TThreadedServer uncaught exception."); + } + if (eventHandler != NULL) { + eventHandler->clientEnd(input_, output_); + } + + try { + input_->getTransport()->close(); + } catch (TTransportException& ttx) { + string errStr = string("TThreadedServer input close failed: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + try { + output_->getTransport()->close(); + } catch (TTransportException& ttx) { + string errStr = string("TThreadedServer output close failed: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + + // Remove this task from parent bookkeeping + { + Synchronized s(server_.tasksMonitor_); + server_.tasks_.erase(this); + if (server_.tasks_.empty()) { + server_.tasksMonitor_.notify(); + } + } + + } + + private: + TThreadedServer& server_; + friend class TThreadedServer; + + shared_ptr<TProcessor> processor_; + shared_ptr<TProtocol> input_; + shared_ptr<TProtocol> output_; +}; + + +TThreadedServer::TThreadedServer(shared_ptr<TProcessor> processor, + shared_ptr<TServerTransport> serverTransport, + shared_ptr<TTransportFactory> transportFactory, + shared_ptr<TProtocolFactory> protocolFactory): + TServer(processor, serverTransport, transportFactory, protocolFactory), + stop_(false) { + threadFactory_ = shared_ptr<PosixThreadFactory>(new PosixThreadFactory()); +} + +TThreadedServer::TThreadedServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport, + boost::shared_ptr<TTransportFactory> transportFactory, + boost::shared_ptr<TProtocolFactory> protocolFactory, + boost::shared_ptr<ThreadFactory> threadFactory): + TServer(processor, serverTransport, transportFactory, protocolFactory), + threadFactory_(threadFactory), + stop_(false) { +} + +TThreadedServer::~TThreadedServer() {} + +void TThreadedServer::serve() { + + shared_ptr<TTransport> client; + 
shared_ptr<TTransport> inputTransport; + shared_ptr<TTransport> outputTransport; + shared_ptr<TProtocol> inputProtocol; + shared_ptr<TProtocol> outputProtocol; + + try { + // Start the server listening + serverTransport_->listen(); + } catch (TTransportException& ttx) { + string errStr = string("TThreadedServer::run() listen(): ") +ttx.what(); + GlobalOutput(errStr.c_str()); + return; + } + + // Run the preServe event + if (eventHandler_ != NULL) { + eventHandler_->preServe(); + } + + while (!stop_) { + try { + client.reset(); + inputTransport.reset(); + outputTransport.reset(); + inputProtocol.reset(); + outputProtocol.reset(); + + // Fetch client from server + client = serverTransport_->accept(); + + // Make IO transports + inputTransport = inputTransportFactory_->getTransport(client); + outputTransport = outputTransportFactory_->getTransport(client); + inputProtocol = inputProtocolFactory_->getProtocol(inputTransport); + outputProtocol = outputProtocolFactory_->getProtocol(outputTransport); + + TThreadedServer::Task* task = new TThreadedServer::Task(*this, + processor_, + inputProtocol, + outputProtocol); + + // Create a task + shared_ptr<Runnable> runnable = + shared_ptr<Runnable>(task); + + // Create a thread for this task + shared_ptr<Thread> thread = + shared_ptr<Thread>(threadFactory_->newThread(runnable)); + + // Insert thread into the set of threads + { + Synchronized s(tasksMonitor_); + tasks_.insert(task); + } + + // Start the thread! 
+ thread->start(); + + } catch (TTransportException& ttx) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + if (!stop_ || ttx.getType() != TTransportException::INTERRUPTED) { + string errStr = string("TThreadedServer: TServerTransport died on accept: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + continue; + } catch (TException& tx) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + string errStr = string("TThreadedServer: Caught TException: ") + tx.what(); + GlobalOutput(errStr.c_str()); + continue; + } catch (string s) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + string errStr = "TThreadedServer: Unknown exception: " + s; + GlobalOutput(errStr.c_str()); + break; + } + } + + // If stopped manually, make sure to close server transport + if (stop_) { + try { + serverTransport_->close(); + } catch (TException &tx) { + string errStr = string("TThreadedServer: Exception shutting down: ") + tx.what(); + GlobalOutput(errStr.c_str()); + } + try { + Synchronized s(tasksMonitor_); + while (!tasks_.empty()) { + tasksMonitor_.wait(); + } + } catch (TException &tx) { + string errStr = string("TThreadedServer: Exception joining workers: ") + tx.what(); + GlobalOutput(errStr.c_str()); + } + stop_ = false; + } + +} + +}}} // apache::thrift::server diff --git a/lib/cpp/src/server/TThreadedServer.h b/lib/cpp/src/server/TThreadedServer.h new file mode 100644 index 000000000..4d0811aaa --- /dev/null +++ b/lib/cpp/src/server/TThreadedServer.h @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_SERVER_TTHREADEDSERVER_H_ +#define _THRIFT_SERVER_TTHREADEDSERVER_H_ 1 + +#include <server/TServer.h> +#include <transport/TServerTransport.h> +#include <concurrency/Monitor.h> +#include <concurrency/Thread.h> + +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace server { + +using apache::thrift::TProcessor; +using apache::thrift::transport::TServerTransport; +using apache::thrift::transport::TTransportFactory; +using apache::thrift::concurrency::Monitor; +using apache::thrift::concurrency::ThreadFactory; + +class TThreadedServer : public TServer { + + public: + class Task; + + TThreadedServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport, + boost::shared_ptr<TTransportFactory> transportFactory, + boost::shared_ptr<TProtocolFactory> protocolFactory); + + TThreadedServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport, + boost::shared_ptr<TTransportFactory> transportFactory, + boost::shared_ptr<TProtocolFactory> protocolFactory, + boost::shared_ptr<ThreadFactory> threadFactory); + + virtual ~TThreadedServer(); + + virtual void serve(); + + void stop() { + stop_ = true; + serverTransport_->interrupt(); + } + + protected: + 
boost::shared_ptr<ThreadFactory> threadFactory_; + volatile bool stop_; + + Monitor tasksMonitor_; + std::set<Task*> tasks_; + +}; + +}}} // apache::thrift::server + +#endif // #ifndef _THRIFT_SERVER_TTHREADEDSERVER_H_ diff --git a/lib/cpp/src/transport/TBufferTransports.cpp b/lib/cpp/src/transport/TBufferTransports.cpp new file mode 100644 index 000000000..7a7e5e928 --- /dev/null +++ b/lib/cpp/src/transport/TBufferTransports.cpp @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <cassert> +#include <algorithm> + +#include <transport/TBufferTransports.h> + +using std::string; + +namespace apache { namespace thrift { namespace transport { + + +uint32_t TBufferedTransport::readSlow(uint8_t* buf, uint32_t len) { + uint32_t want = len; + uint32_t have = rBound_ - rBase_; + + // We should only take the slow path if we can't satisfy the read + // with the data already in the buffer. + assert(have < want); + + // Copy out whatever we have. + if (have > 0) { + memcpy(buf, rBase_, have); + want -= have; + buf += have; + } + // Get more from underlying transport up to buffer size. + // Note that this makes a lot of sense if len < rBufSize_ + // and almost no sense otherwise. 
TODO(dreiss): Fix that + // case (possibly including some readv hotness). + setReadBuffer(rBuf_.get(), transport_->read(rBuf_.get(), rBufSize_)); + + // Hand over whatever we have. + uint32_t give = std::min(want, static_cast<uint32_t>(rBound_ - rBase_)); + memcpy(buf, rBase_, give); + rBase_ += give; + want -= give; + + return (len - want); +} + +void TBufferedTransport::writeSlow(const uint8_t* buf, uint32_t len) { + uint32_t have_bytes = wBase_ - wBuf_.get(); + uint32_t space = wBound_ - wBase_; + // We should only take the slow path if we can't accomodate the write + // with the free space already in the buffer. + assert(wBound_ - wBase_ < static_cast<ptrdiff_t>(len)); + + // Now here's the tricky question: should we copy data from buf into our + // internal buffer and write it from there, or should we just write out + // the current internal buffer in one syscall and write out buf in another. + // If our currently buffered data plus buf is at least double our buffer + // size, we will have to do two syscalls no matter what (except in the + // degenerate case when our buffer is empty), so there is no use copying. + // Otherwise, there is sort of a sliding scale. If we have N-1 bytes + // buffered and need to write 2, it would be crazy to do two syscalls. + // On the other hand, if we have 2 bytes buffered and are writing 2N-3, + // we can save a syscall in the short term by loading up our buffer, writing + // it out, and copying the rest of the bytes into our buffer. Of course, + // if we get another 2-byte write, we haven't saved any syscalls at all, + // and have just copied nearly 2N bytes for nothing. Finding a perfect + // policy would require predicting the size of future writes, so we're just + // going to always eschew syscalls if we have less than 2N bytes to write. + + // The case where we have to do two syscalls. + // This case also covers the case where the buffer is empty, + // but it is clearer (I think) to think of it as two separate cases. 
+ if ((have_bytes + len >= 2*wBufSize_) || (have_bytes == 0)) { + // TODO(dreiss): writev + if (have_bytes > 0) { + transport_->write(wBuf_.get(), have_bytes); + } + transport_->write(buf, len); + wBase_ = wBuf_.get(); + return; + } + + // Fill up our internal buffer for a write. + memcpy(wBase_, buf, space); + buf += space; + len -= space; + transport_->write(wBuf_.get(), wBufSize_); + + // Copy the rest into our buffer. + assert(len < wBufSize_); + memcpy(wBuf_.get(), buf, len); + wBase_ = wBuf_.get() + len; + return; +} + +const uint8_t* TBufferedTransport::borrowSlow(uint8_t* buf, uint32_t* len) { + // If the request is bigger than our buffer, we are hosed. + if (*len > rBufSize_) { + return NULL; + } + + // The number of bytes of data we have already. + uint32_t have = rBound_ - rBase_; + // The number of additional bytes we need from the underlying transport. + int32_t need = *len - have; + // The space from the start of the buffer to the end of our data. + uint32_t offset = rBound_ - rBuf_.get(); + assert(need > 0); + + // If we have less than half our buffer space available, shift the data + // we have down to the start. If the borrow is big compared to our buffer, + // this could be kind of a waste, but if the borrow is small, it frees up + // space at the end of our buffer to do a bigger single read from the + // underlying transport. Also, if our needs extend past the end of the + // buffer, we have to do a copy no matter what. + if ((offset > rBufSize_/2) || (offset + need > rBufSize_)) { + memmove(rBuf_.get(), rBase_, have); + setReadBuffer(rBuf_.get(), have); + } + + // First try to fill up the buffer. + uint32_t got = transport_->read(rBound_, rBufSize_ - have); + rBound_ += got; + need -= got; + + // If that fails, readAll until we get what we need. 
+ if (need > 0) { + rBound_ += transport_->readAll(rBound_, need); + } + + *len = rBound_ - rBase_; + return rBase_; +} + +void TBufferedTransport::flush() { + // Write out any data waiting in the write buffer. + uint32_t have_bytes = wBase_ - wBuf_.get(); + if (have_bytes > 0) { + // Note that we reset wBase_ prior to the underlying write + // to ensure we're in a sane state (i.e. internal buffer cleaned) + // if the underlying write throws up an exception + wBase_ = wBuf_.get(); + transport_->write(wBuf_.get(), have_bytes); + } + + // Flush the underlying transport. + transport_->flush(); +} + + +uint32_t TFramedTransport::readSlow(uint8_t* buf, uint32_t len) { + uint32_t want = len; + uint32_t have = rBound_ - rBase_; + + // We should only take the slow path if we can't satisfy the read + // with the data already in the buffer. + assert(have < want); + + // Copy out whatever we have. + if (have > 0) { + memcpy(buf, rBase_, have); + want -= have; + buf += have; + } + + // Read another frame. + readFrame(); + + // TODO(dreiss): Should we warn when reads cross frames? + + // Hand over whatever we have. + uint32_t give = std::min(want, static_cast<uint32_t>(rBound_ - rBase_)); + memcpy(buf, rBase_, give); + rBase_ += give; + want -= give; + + return (len - want); +} + +void TFramedTransport::readFrame() { + // TODO(dreiss): Think about using readv here, even though it would + // result in (gasp) read-ahead. + + // Read the size of the next frame. + int32_t sz; + transport_->readAll((uint8_t*)&sz, sizeof(sz)); + sz = ntohl(sz); + + if (sz < 0) { + throw TTransportException("Frame size has negative value"); + } + + // Read the frame payload, and reset markers. + if (sz > static_cast<int32_t>(rBufSize_)) { + rBuf_.reset(new uint8_t[sz]); + rBufSize_ = sz; + } + transport_->readAll(rBuf_.get(), sz); + setReadBuffer(rBuf_.get(), sz); +} + +void TFramedTransport::writeSlow(const uint8_t* buf, uint32_t len) { + // Double buffer size until sufficient. 
+ uint32_t have = wBase_ - wBuf_.get(); + while (wBufSize_ < len + have) { + wBufSize_ *= 2; + } + + // TODO(dreiss): Consider modifying this class to use malloc/free + // so we can use realloc here. + + // Allocate new buffer. + uint8_t* new_buf = new uint8_t[wBufSize_]; + + // Copy the old buffer to the new one. + memcpy(new_buf, wBuf_.get(), have); + + // Now point buf to the new one. + wBuf_.reset(new_buf); + wBase_ = wBuf_.get() + have; + wBound_ = wBuf_.get() + wBufSize_; + + // Copy the data into the new buffer. + memcpy(wBase_, buf, len); + wBase_ += len; +} + +void TFramedTransport::flush() { + int32_t sz_hbo, sz_nbo; + assert(wBufSize_ > sizeof(sz_nbo)); + + // Slip the frame size into the start of the buffer. + sz_hbo = wBase_ - (wBuf_.get() + sizeof(sz_nbo)); + sz_nbo = (int32_t)htonl((uint32_t)(sz_hbo)); + memcpy(wBuf_.get(), (uint8_t*)&sz_nbo, sizeof(sz_nbo)); + + if (sz_hbo > 0) { + // Note that we reset wBase_ (with a pad for the frame size) + // prior to the underlying write to ensure we're in a sane state + // (i.e. internal buffer cleaned) if the underlying write throws + // up an exception + wBase_ = wBuf_.get() + sizeof(sz_nbo); + + // Write size and frame body. + transport_->write(wBuf_.get(), sizeof(sz_nbo)+sz_hbo); + } + + // Flush the underlying transport. + transport_->flush(); +} + +const uint8_t* TFramedTransport::borrowSlow(uint8_t* buf, uint32_t* len) { + // Don't try to be clever with shifting buffers. + // If the fast path failed let the protocol use its slow path. + // Besides, who is going to try to borrow across messages? + return NULL; +} + + +void TMemoryBuffer::computeRead(uint32_t len, uint8_t** out_start, uint32_t* out_give) { + // Correct rBound_ so we can use the fast path in the future. + rBound_ = wBase_; + + // Decide how much to give. + uint32_t give = std::min(len, available_read()); + + *out_start = rBase_; + *out_give = give; + + // Preincrement rBase_ so the caller doesn't have to. 
+ rBase_ += give; +} + +uint32_t TMemoryBuffer::readSlow(uint8_t* buf, uint32_t len) { + uint8_t* start; + uint32_t give; + computeRead(len, &start, &give); + + // Copy into the provided buffer. + memcpy(buf, start, give); + + return give; +} + +uint32_t TMemoryBuffer::readAppendToString(std::string& str, uint32_t len) { + // Don't get some stupid assertion failure. + if (buffer_ == NULL) { + return 0; + } + + uint8_t* start; + uint32_t give; + computeRead(len, &start, &give); + + // Append to the provided string. + str.append((char*)start, give); + + return give; +} + +void TMemoryBuffer::ensureCanWrite(uint32_t len) { + // Check available space + uint32_t avail = available_write(); + if (len <= avail) { + return; + } + + if (!owner_) { + throw TTransportException("Insufficient space in external MemoryBuffer"); + } + + // Grow the buffer as necessary. + while (len > avail) { + bufferSize_ *= 2; + wBound_ = buffer_ + bufferSize_; + avail = available_write(); + } + + // Allocate into a new pointer so we don't bork ours if it fails. + void* new_buffer = std::realloc(buffer_, bufferSize_); + if (new_buffer == NULL) { + throw TTransportException("Out of memory."); + } + + ptrdiff_t offset = (uint8_t*)new_buffer - buffer_; + buffer_ += offset; + rBase_ += offset; + rBound_ += offset; + wBase_ += offset; + wBound_ += offset; +} + +void TMemoryBuffer::writeSlow(const uint8_t* buf, uint32_t len) { + ensureCanWrite(len); + + // Copy into the buffer and increment wBase_. 
+ memcpy(wBase_, buf, len); + wBase_ += len; +} + +void TMemoryBuffer::wroteBytes(uint32_t len) { + uint32_t avail = available_write(); + if (len > avail) { + throw TTransportException("Client wrote more bytes than size of buffer."); + } + wBase_ += len; +} + +const uint8_t* TMemoryBuffer::borrowSlow(uint8_t* buf, uint32_t* len) { + rBound_ = wBase_; + if (available_read() >= *len) { + *len = available_read(); + return rBase_; + } + return NULL; +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TBufferTransports.h b/lib/cpp/src/transport/TBufferTransports.h new file mode 100644 index 000000000..1908205ff --- /dev/null +++ b/lib/cpp/src/transport/TBufferTransports.h @@ -0,0 +1,667 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef _THRIFT_TRANSPORT_TBUFFERTRANSPORTS_H_ +#define _THRIFT_TRANSPORT_TBUFFERTRANSPORTS_H_ 1 + +#include <cstring> +#include "boost/scoped_array.hpp" + +#include <transport/TTransport.h> + +#ifdef __GNUC__ +#define TDB_LIKELY(val) (__builtin_expect((val), 1)) +#define TDB_UNLIKELY(val) (__builtin_expect((val), 0)) +#else +#define TDB_LIKELY(val) (val) +#define TDB_UNLIKELY(val) (val) +#endif + +namespace apache { namespace thrift { namespace transport { + + +/** + * Base class for all transports that use read/write buffers for performance. + * + * TBufferBase is designed to implement the fast-path "memcpy" style + * operations that work in the common case. It does so with small and + * (eventually) nonvirtual, inlinable methods. TBufferBase is an abstract + * class. Subclasses are expected to define the "slow path" operations + * that have to be done when the buffers are full or empty. + * + */ +class TBufferBase : public TTransport { + + public: + + /** + * Fast-path read. + * + * When we have enough data buffered to fulfill the read, we can satisfy it + * with a single memcpy, then adjust our internal pointers. If the buffer + * is empty, we call out to our slow path, implemented by a subclass. + * This method is meant to eventually be nonvirtual and inlinable. + */ + uint32_t read(uint8_t* buf, uint32_t len) { + uint8_t* new_rBase = rBase_ + len; + if (TDB_LIKELY(new_rBase <= rBound_)) { + std::memcpy(buf, rBase_, len); + rBase_ = new_rBase; + return len; + } + return readSlow(buf, len); + } + + /** + * Fast-path write. + * + * When we have enough empty space in our buffer to accomodate the write, we + * can satisfy it with a single memcpy, then adjust our internal pointers. + * If the buffer is full, we call out to our slow path, implemented by a + * subclass. This method is meant to eventually be nonvirtual and + * inlinable. 
+ */ + void write(const uint8_t* buf, uint32_t len) { + uint8_t* new_wBase = wBase_ + len; + if (TDB_LIKELY(new_wBase <= wBound_)) { + std::memcpy(wBase_, buf, len); + wBase_ = new_wBase; + return; + } + writeSlow(buf, len); + } + + /** + * Fast-path borrow. A lot like the fast-path read. + */ + const uint8_t* borrow(uint8_t* buf, uint32_t* len) { + if (TDB_LIKELY(static_cast<ptrdiff_t>(*len) <= rBound_ - rBase_)) { + // With strict aliasing, writing to len shouldn't force us to + // refetch rBase_ from memory. TODO(dreiss): Verify this. + *len = rBound_ - rBase_; + return rBase_; + } + return borrowSlow(buf, len); + } + + /** + * Consume doesn't require a slow path. + */ + void consume(uint32_t len) { + if (TDB_LIKELY(static_cast<ptrdiff_t>(len) <= rBound_ - rBase_)) { + rBase_ += len; + } else { + throw TTransportException(TTransportException::BAD_ARGS, + "consume did not follow a borrow."); + } + } + + + protected: + + /// Slow path read. + virtual uint32_t readSlow(uint8_t* buf, uint32_t len) = 0; + + /// Slow path write. + virtual void writeSlow(const uint8_t* buf, uint32_t len) = 0; + + /** + * Slow path borrow. + * + * POSTCONDITION: return == NULL || rBound_ - rBase_ >= *len + */ + virtual const uint8_t* borrowSlow(uint8_t* buf, uint32_t* len) = 0; + + /** + * Trivial constructor. + * + * Initialize pointers safely. Constructing is not a very + * performance-sensitive operation, so it is okay to just leave it to + * the concrete class to set up pointers correctly. + */ + TBufferBase() + : rBase_(NULL) + , rBound_(NULL) + , wBase_(NULL) + , wBound_(NULL) + {} + + /// Convenience mutator for setting the read buffer. + void setReadBuffer(uint8_t* buf, uint32_t len) { + rBase_ = buf; + rBound_ = buf+len; + } + + /// Convenience mutator for setting the write buffer. + void setWriteBuffer(uint8_t* buf, uint32_t len) { + wBase_ = buf; + wBound_ = buf+len; + } + + virtual ~TBufferBase() {} + + /// Reads begin here. 
  uint8_t* rBase_;
  /// Reads may extend to just before here.
  uint8_t* rBound_;

  /// Writes begin here.
  uint8_t* wBase_;
  /// Writes may extend to just before here.
  uint8_t* wBound_;
};


/**
 * Base class for all transports that wrap another transport.
 *
 * Holds the wrapped transport together with separately sized read and
 * write buffers, and forwards open/isOpen/close to the wrapped transport.
 */
class TUnderlyingTransport : public TBufferBase {
 public:
  /// Size, in bytes, of each buffer when the constructor is given no size.
  static const int DEFAULT_BUFFER_SIZE = 512;

  /// True if unread bytes remain in the read buffer; otherwise asks the
  /// wrapped transport.
  virtual bool peek() {
    return (rBase_ < rBound_) || transport_->peek();
  }

  void open() {
    transport_->open();
  }

  bool isOpen() {
    return transport_->isOpen();
  }

  /// Flushes this transport, then closes the wrapped one.
  void close() {
    flush();
    transport_->close();
  }

  boost::shared_ptr<TTransport> getUnderlyingTransport() {
    return transport_;
  }

 protected:
  /// The wrapped transport.
  boost::shared_ptr<TTransport> transport_;

  // The sizes are declared before the arrays, so they are already
  // initialized when the constructors below use them to size the
  // allocations.
  uint32_t rBufSize_;
  uint32_t wBufSize_;
  boost::scoped_array<uint8_t> rBuf_;
  boost::scoped_array<uint8_t> wBuf_;

  /// Use the same size, sz, for the read and the write buffer.
  TUnderlyingTransport(boost::shared_ptr<TTransport> transport, uint32_t sz)
    : transport_(transport)
    , rBufSize_(sz)
    , wBufSize_(sz)
    , rBuf_(new uint8_t[rBufSize_])
    , wBuf_(new uint8_t[wBufSize_]) {}

  /// Use DEFAULT_BUFFER_SIZE for both buffers.
  TUnderlyingTransport(boost::shared_ptr<TTransport> transport)
    : transport_(transport)
    , rBufSize_(DEFAULT_BUFFER_SIZE)
    , wBufSize_(DEFAULT_BUFFER_SIZE)
    , rBuf_(new uint8_t[rBufSize_])
    , wBuf_(new uint8_t[wBufSize_]) {}

  /// Use rsz for the read buffer and wsz for the write buffer.
  TUnderlyingTransport(boost::shared_ptr<TTransport> transport, uint32_t rsz, uint32_t wsz)
    : transport_(transport)
    , rBufSize_(rsz)
    , wBufSize_(wsz)
    , rBuf_(new uint8_t[rBufSize_])
    , wBuf_(new uint8_t[wBufSize_]) {}
};

/**
 * Buffered transport. For reads it will read more data than is requested
 * and will serve future data out of a local buffer. For writes, data is
 * stored to an in memory buffer before being written out.
 *
 */
class TBufferedTransport : public TUnderlyingTransport {
 public:

  /// Use default buffer sizes.
  TBufferedTransport(boost::shared_ptr<TTransport> transport)
    : TUnderlyingTransport(transport)
  {
    initPointers();
  }

  /// Use specified buffer sizes.
  TBufferedTransport(boost::shared_ptr<TTransport> transport, uint32_t sz)
    : TUnderlyingTransport(transport, sz)
  {
    initPointers();
  }

  /// Use specified read and write buffer sizes.
  TBufferedTransport(boost::shared_ptr<TTransport> transport, uint32_t rsz, uint32_t wsz)
    : TUnderlyingTransport(transport, rsz, wsz)
  {
    initPointers();
  }

  /// When the read buffer is exhausted, refills it with one read of up to
  /// rBufSize_ bytes from the wrapped transport; returns whether any unread
  /// bytes are buffered afterwards.
  virtual bool peek() {
    /* shigin: see THRIFT-96 discussion */
    if (rBase_ == rBound_) {
      setReadBuffer(rBuf_.get(), transport_->read(rBuf_.get(), rBufSize_));
    }
    return (rBound_ > rBase_);
  }
  virtual uint32_t readSlow(uint8_t* buf, uint32_t len);

  virtual void writeSlow(const uint8_t* buf, uint32_t len);

  void flush();


  /**
   * The following behavior is currently implemented by TBufferedTransport,
   * but that may change in a future version:
   * 1/ If len is at most rBufSize_, borrow will never return NULL.
   *    Depending on the underlying transport, it could throw an exception
   *    or hang forever.
   * 2/ Some borrow requests may copy bytes internally.  However,
   *    if len is at most rBufSize_/2, none of the copied bytes
   *    will ever have to be copied again.  For optimal performance,
   *    stay under this limit.
   */
  virtual const uint8_t* borrowSlow(uint8_t* buf, uint32_t* len);

 protected:
  /// Start with an empty read buffer and the whole of wBuf_ writable.
  void initPointers() {
    setReadBuffer(rBuf_.get(), 0);
    setWriteBuffer(wBuf_.get(), wBufSize_);
    // Write size never changes.
  }
};


/**
 * Wraps a transport into a buffered one.
 *
 */
class TBufferedTransportFactory : public TTransportFactory {
 public:
  TBufferedTransportFactory() {}

  virtual ~TBufferedTransportFactory() {}

  /**
   * Wraps the transport into a buffered one.
   */
  virtual boost::shared_ptr<TTransport> getTransport(boost::shared_ptr<TTransport> trans) {
    return boost::shared_ptr<TTransport>(new TBufferedTransport(trans));
  }

};


/**
 * Framed transport. All writes go into an in-memory buffer until flush is
 * called, at which point the transport writes the length of the entire
 * binary chunk followed by the data payload. This allows the receiver on the
 * other end to always do fixed-length reads.
 *
 */
class TFramedTransport : public TUnderlyingTransport {
 public:

  /// Use default buffer sizes.
  TFramedTransport(boost::shared_ptr<TTransport> transport)
    : TUnderlyingTransport(transport)
  {
    initPointers();
  }

  /// Use the specified size for both buffers.
  TFramedTransport(boost::shared_ptr<TTransport> transport, uint32_t sz)
    : TUnderlyingTransport(transport, sz)
  {
    initPointers();
  }

  virtual uint32_t readSlow(uint8_t* buf, uint32_t len);

  virtual void writeSlow(const uint8_t* buf, uint32_t len);

  virtual void flush();

  const uint8_t* borrowSlow(uint8_t* buf, uint32_t* len);

 protected:
  /**
   * Reads a frame of input from the underlying stream.
   */
  void readFrame();

  /// Start with no readable data and the whole of wBuf_ writable, then
  /// reserve room at the front of the write buffer for the frame size.
  void initPointers() {
    setReadBuffer(NULL, 0);
    setWriteBuffer(wBuf_.get(), wBufSize_);

    // Pad the buffer so we can insert the size later.
    int32_t pad = 0;
    this->write((uint8_t*)&pad, sizeof(pad));
  }
};

/**
 * Wraps a transport into a framed one.
 *
 */
class TFramedTransportFactory : public TTransportFactory {
 public:
  TFramedTransportFactory() {}

  virtual ~TFramedTransportFactory() {}

  /**
   * Wraps the transport into a framed one.
   */
  virtual boost::shared_ptr<TTransport> getTransport(boost::shared_ptr<TTransport> trans) {
    return boost::shared_ptr<TTransport>(new TFramedTransport(trans));
  }

};


/**
 * A memory buffer is a transport that simply reads from and writes to an
 * in memory buffer.
 * Anytime you call write on it, the data is simply placed
 * into a buffer, and anytime you call read, data is read from that buffer.
 *
 * The buffers are allocated using C constructs malloc,realloc, and the size
 * doubles as necessary.  We've considered using scoped
 *
 */
class TMemoryBuffer : public TBufferBase {
 private:

  // Common initialization done by all constructors.
  //
  // buf   - existing storage, or NULL to have `size` bytes allocated here
  //         (in which case `owner` must be true).
  // size  - capacity of the buffer in bytes.
  // owner - whether this object frees the buffer.
  // wPos  - offset of the write position; the bytes before it are
  //         immediately readable.
  void initCommon(uint8_t* buf, uint32_t size, bool owner, uint32_t wPos) {
    if (buf == NULL && size != 0) {
      assert(owner);
      buf = (uint8_t*)std::malloc(size);
      if (buf == NULL) {
        throw TTransportException("Out of memory");
      }
    }

    buffer_ = buf;
    bufferSize_ = size;

    rBase_ = buffer_;
    rBound_ = buffer_ + wPos;
    // TODO(dreiss): Investigate NULL-ing this if !owner.
    wBase_ = buffer_ + wPos;
    wBound_ = buffer_ + bufferSize_;

    owner_ = owner;

    // rBound_ is really an artifact.  In principle, it should always be
    // equal to wBase_.  We update it in a few places (computeRead, etc.).
  }

 public:
  /// Capacity, in bytes, used by the default constructor and by
  /// resetBuffer(true).
  static const uint32_t defaultSize = 1024;

  /**
   * This enum specifies how a TMemoryBuffer should treat
   * memory passed to it via constructors or resetBuffer.
   *
   * OBSERVE:
   *   TMemoryBuffer will simply store a pointer to the memory.
   *   It is the caller's responsibility to ensure that the pointer
   *   remains valid for the lifetime of the TMemoryBuffer,
   *   and that it is properly cleaned up.
   *   Note that no data can be written to observed buffers.
   *
   * COPY:
   *   TMemoryBuffer will make an internal copy of the buffer.
   *   The caller has no responsibilities.
   *
   * TAKE_OWNERSHIP:
   *   TMemoryBuffer will become the "owner" of the buffer,
   *   and will be responsible for freeing it.
   *   The memory must have been allocated with malloc.
   */
  enum MemoryPolicy
  { OBSERVE = 1
  , COPY = 2
  , TAKE_OWNERSHIP = 3
  };

  /**
   * Construct a TMemoryBuffer with a default-sized buffer,
   * owned by the TMemoryBuffer object.
   */
  TMemoryBuffer() {
    initCommon(NULL, defaultSize, true, 0);
  }

  /**
   * Construct a TMemoryBuffer with a buffer of a specified size,
   * owned by the TMemoryBuffer object.
   *
   * @param sz  The initial size of the buffer.
   */
  TMemoryBuffer(uint32_t sz) {
    initCommon(NULL, sz, true, 0);
  }

  /**
   * Construct a TMemoryBuffer with buf as its initial contents.
   *
   * @param buf    The initial contents of the buffer.
   *               Note that, while buf is a non-const pointer,
   *               TMemoryBuffer will not write to it if policy == OBSERVE,
   *               so it is safe to const_cast<uint8_t*>(whatever).
   * @param sz     The size of @c buf.
   * @param policy See @link MemoryPolicy @endlink .
   *
   * @throws TTransportException (BAD_ARGS) if buf is NULL while sz is
   *         non-zero, or if policy is not a MemoryPolicy value.
   */
  TMemoryBuffer(uint8_t* buf, uint32_t sz, MemoryPolicy policy = OBSERVE) {
    if (buf == NULL && sz != 0) {
      throw TTransportException(TTransportException::BAD_ARGS,
                                "TMemoryBuffer given null buffer with non-zero size.");
    }

    switch (policy) {
      case OBSERVE:
      case TAKE_OWNERSHIP:
        // Adopt the caller's memory; all sz bytes are readable at once.
        initCommon(buf, sz, policy == TAKE_OWNERSHIP, sz);
        break;
      case COPY:
        // Allocate our own storage, then copy the contents in through the
        // normal write path.
        initCommon(NULL, sz, true, 0);
        this->write(buf, sz);
        break;
      default:
        throw TTransportException(TTransportException::BAD_ARGS,
                                  "Invalid MemoryPolicy for TMemoryBuffer");
    }
  }

  /// Frees the buffer only when this object owns it.
  ~TMemoryBuffer() {
    if (owner_) {
      std::free(buffer_);
    }
  }

  /// A memory buffer is always open.
  bool isOpen() {
    return true;
  }

  /// True if written data remains unread.
  bool peek() {
    return (rBase_ < wBase_);
  }

  void open() {}

  void close() {}

  // TODO(dreiss): Make bufPtr const.
  /// Exposes the unread region: *bufPtr points at the read position and *sz
  /// is the number of bytes between it and the write position.
  void getBuffer(uint8_t** bufPtr, uint32_t* sz) {
    *bufPtr = rBase_;
    *sz = wBase_ - rBase_;
  }

  /// Returns a copy of the unread region as a string (empty when there is
  /// no buffer).
  std::string getBufferAsString() {
    if (buffer_ == NULL) {
      return "";
    }
    uint8_t* buf;
    uint32_t sz;
    getBuffer(&buf, &sz);
    return std::string((char*)buf, (std::string::size_type)sz);
  }

  /// Appends the unread region to str; does nothing when there is no buffer.
  void appendBufferToString(std::string& str) {
    if (buffer_ == NULL) {
      return;
    }
    uint8_t* buf;
    uint32_t sz;
    getBuffer(&buf, &sz);
    str.append((char*)buf, sz);
  }

  /**
   * Discards all contents by moving the read and write positions back to
   * the start of the buffer.
   *
   * @param reset_capacity  If true, also reallocates the buffer to
   *                        defaultSize bytes; the buffer must be owned.
   * @throws TTransportException if that reallocation fails.
   */
  void resetBuffer(bool reset_capacity = false) {
    if (reset_capacity)
    {
      assert(owner_);

      void* new_buffer = std::realloc(buffer_, defaultSize);

      if (new_buffer == NULL) {
        throw TTransportException("Out of memory.");
      }

      buffer_ = (uint8_t*) new_buffer;
      bufferSize_ = defaultSize;

      wBound_ = buffer_ + bufferSize_;
    }

    rBase_ = buffer_;
    rBound_ = buffer_;
    wBase_ = buffer_;
    // It isn't safe to write into a buffer we don't own.
    if (!owner_) {
      wBound_ = wBase_;
      bufferSize_ = 0;
    }
  }

  /// See constructor documentation.
  void resetBuffer(uint8_t* buf, uint32_t sz, MemoryPolicy policy = OBSERVE) {
    // Use a variant of the copy-and-swap trick for assignment operators.
    // This is sub-optimal in terms of performance for two reasons:
    //   1/ The constructing and swapping of the (small) values
    //      in the temporary object takes some time, and is not necessary.
    //   2/ If policy == COPY, we allocate the new buffer before
    //      freeing the old one, precluding the possibility of
    //      reusing that memory.
    // I doubt that either of these problems could be optimized away,
    // but the second is probably not a common case, and the first is minor.
    // I don't expect resetBuffer to be a common operation, so I'm willing to
    // bite the performance bullet to make the method this simple.

    // Construct the new buffer.
    TMemoryBuffer new_buffer(buf, sz, policy);
    // Move it into ourself.
    this->swap(new_buffer);
    // Our old self gets destroyed.
  }

  /// Reads up to len bytes via readAppendToString() and returns them as a
  /// new string.
  std::string readAsString(uint32_t len) {
    std::string str;
    (void)readAppendToString(str, len);
    return str;
  }

  uint32_t readAppendToString(std::string& str, uint32_t len);

  /// Once everything written has been read, rewind to the start of the
  /// buffer.
  void readEnd() {
    if (rBase_ == wBase_) {
      resetBuffer();
    }
  }

  /// Number of bytes written but not yet read.
  uint32_t available_read() const {
    // Remember, wBase_ is the real rBound_.
    return wBase_ - rBase_;
  }

  /// Number of bytes that can be written before the buffer is full.
  uint32_t available_write() const {
    return wBound_ - wBase_;
  }

  // Returns a pointer to where the client can write data to append to
  // the TMemoryBuffer, and ensures the buffer is big enough to accommodate a
  // write of the provided length.  The returned pointer is very convenient for
  // passing to read(), recv(), or similar. You must call wroteBytes() as soon
  // as data is written or the buffer will not be aware that data has changed.
  uint8_t* getWritePtr(uint32_t len) {
    ensureCanWrite(len);
    return wBase_;
  }

  // Informs the buffer that the client has written 'len' bytes into storage
  // that had been provided by getWritePtr().
  void wroteBytes(uint32_t len);

 protected:
  // Exchanges every data member with `that`; used by
  // resetBuffer(buf, sz, policy).
  void swap(TMemoryBuffer& that) {
    using std::swap;
    swap(buffer_,     that.buffer_);
    swap(bufferSize_, that.bufferSize_);

    swap(rBase_,      that.rBase_);
    swap(rBound_,     that.rBound_);
    swap(wBase_,      that.wBase_);
    swap(wBound_,     that.wBound_);

    swap(owner_,      that.owner_);
  }

  // Make sure there's at least 'len' bytes available for writing.
  void ensureCanWrite(uint32_t len);

  // Compute the position and available data for reading.
  void computeRead(uint32_t len, uint8_t** out_start, uint32_t* out_give);

  uint32_t readSlow(uint8_t* buf, uint32_t len);

  void writeSlow(const uint8_t* buf, uint32_t len);

  const uint8_t* borrowSlow(uint8_t* buf, uint32_t* len);

  // Data buffer
  uint8_t* buffer_;

  // Allocated buffer size
  uint32_t bufferSize_;

  // Is this object the owner of the buffer?
+ bool owner_; + + // Don't forget to update constrctors, initCommon, and swap if + // you add new members. +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TBUFFERTRANSPORTS_H_ diff --git a/lib/cpp/src/transport/TFDTransport.cpp b/lib/cpp/src/transport/TFDTransport.cpp new file mode 100644 index 000000000..a042f8b74 --- /dev/null +++ b/lib/cpp/src/transport/TFDTransport.cpp @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <cerrno> +#include <exception> + +#include <transport/TFDTransport.h> + +#include <unistd.h> + +using namespace std; + +namespace apache { namespace thrift { namespace transport { + +void TFDTransport::close() { + if (!isOpen()) { + return; + } + + int rv = ::close(fd_); + int errno_copy = errno; + fd_ = -1; + // Have to check uncaught_exception because this is called in the destructor. 
+ if (rv < 0 && !std::uncaught_exception()) { + throw TTransportException(TTransportException::UNKNOWN, + "TFDTransport::close()", + errno_copy); + } +} + +uint32_t TFDTransport::read(uint8_t* buf, uint32_t len) { + ssize_t rv = ::read(fd_, buf, len); + if (rv < 0) { + int errno_copy = errno; + throw TTransportException(TTransportException::UNKNOWN, + "TFDTransport::read()", + errno_copy); + } + return rv; +} + +void TFDTransport::write(const uint8_t* buf, uint32_t len) { + while (len > 0) { + ssize_t rv = ::write(fd_, buf, len); + + if (rv < 0) { + int errno_copy = errno; + throw TTransportException(TTransportException::UNKNOWN, + "TFDTransport::write()", + errno_copy); + } else if (rv == 0) { + throw TTransportException(TTransportException::END_OF_FILE, + "TFDTransport::write()"); + } + + buf += rv; + len -= rv; + } +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TFDTransport.h b/lib/cpp/src/transport/TFDTransport.h new file mode 100644 index 000000000..bda5d82a9 --- /dev/null +++ b/lib/cpp/src/transport/TFDTransport.h @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef _THRIFT_TRANSPORT_TFDTRANSPORT_H_ +#define _THRIFT_TRANSPORT_TFDTRANSPORT_H_ 1 + +#include <string> +#include <sys/time.h> + +#include "TTransport.h" +#include "TServerSocket.h" + +namespace apache { namespace thrift { namespace transport { + +/** + * Dead-simple wrapper around a file descriptor. + * + */ +class TFDTransport : public TTransport { + public: + enum ClosePolicy + { NO_CLOSE_ON_DESTROY = 0 + , CLOSE_ON_DESTROY = 1 + }; + + TFDTransport(int fd, ClosePolicy close_policy = NO_CLOSE_ON_DESTROY) + : fd_(fd) + , close_policy_(close_policy) + {} + + ~TFDTransport() { + if (close_policy_ == CLOSE_ON_DESTROY) { + close(); + } + } + + bool isOpen() { return fd_ >= 0; } + + void open() {} + + void close(); + + uint32_t read(uint8_t* buf, uint32_t len); + + void write(const uint8_t* buf, uint32_t len); + + void setFD(int fd) { fd_ = fd; } + int getFD() { return fd_; } + + protected: + int fd_; + ClosePolicy close_policy_; +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TFDTRANSPORT_H_ diff --git a/lib/cpp/src/transport/TFileTransport.cpp b/lib/cpp/src/transport/TFileTransport.cpp new file mode 100644 index 000000000..f67b9e355 --- /dev/null +++ b/lib/cpp/src/transport/TFileTransport.cpp @@ -0,0 +1,953 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include "TFileTransport.h" +#include "TTransportUtils.h" + +#include <pthread.h> +#ifdef HAVE_SYS_TIME_H +#include <sys/time.h> +#else +#include <time.h> +#endif +#include <fcntl.h> +#include <errno.h> +#include <unistd.h> +#ifdef HAVE_STRINGS_H +#include <strings.h> +#endif +#include <cstdlib> +#include <cstring> +#include <iostream> +#include <sys/stat.h> + +namespace apache { namespace thrift { namespace transport { + +using boost::shared_ptr; +using namespace std; +using namespace apache::thrift::protocol; + +#ifndef HAVE_CLOCK_GETTIME + +/** + * Fake clock_gettime for systems like darwin + * + */ +#define CLOCK_REALTIME 0 +static int clock_gettime(int clk_id /*ignored*/, struct timespec *tp) { + struct timeval now; + + int rv = gettimeofday(&now, NULL); + if (rv != 0) { + return rv; + } + + tp->tv_sec = now.tv_sec; + tp->tv_nsec = now.tv_usec * 1000; + return 0; +} +#endif + +TFileTransport::TFileTransport(string path, bool readOnly) + : readState_() + , readBuff_(NULL) + , currentEvent_(NULL) + , readBuffSize_(DEFAULT_READ_BUFF_SIZE) + , readTimeout_(NO_TAIL_READ_TIMEOUT) + , chunkSize_(DEFAULT_CHUNK_SIZE) + , eventBufferSize_(DEFAULT_EVENT_BUFFER_SIZE) + , flushMaxUs_(DEFAULT_FLUSH_MAX_US) + , flushMaxBytes_(DEFAULT_FLUSH_MAX_BYTES) + , maxEventSize_(DEFAULT_MAX_EVENT_SIZE) + , maxCorruptedEvents_(DEFAULT_MAX_CORRUPTED_EVENTS) + , eofSleepTime_(DEFAULT_EOF_SLEEP_TIME_US) + , corruptedEventSleepTime_(DEFAULT_CORRUPTED_SLEEP_TIME_US) + , writerThreadId_(0) + , dequeueBuffer_(NULL) + , enqueueBuffer_(NULL) + , closing_(false) + , forceFlush_(false) + , filename_(path) + , fd_(0) + , bufferAndThreadInitialized_(false) + , offset_(0) + , lastBadChunk_(0) + , numCorruptedEventsInChunk_(0) + , readOnly_(readOnly) +{ + // initialize all the condition vars/mutexes + pthread_mutex_init(&mutex_, NULL); 
+ pthread_cond_init(¬Full_, NULL); + pthread_cond_init(¬Empty_, NULL); + pthread_cond_init(&flushed_, NULL); + + openLogFile(); +} + +void TFileTransport::resetOutputFile(int fd, string filename, int64_t offset) { + filename_ = filename; + offset_ = offset; + + // check if current file is still open + if (fd_ > 0) { + // flush any events in the queue + flush(); + GlobalOutput.printf("error, current file (%s) not closed", filename_.c_str()); + if (-1 == ::close(fd_)) { + int errno_copy = errno; + GlobalOutput.perror("TFileTransport: resetOutputFile() ::close() ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "TFileTransport: error in file close", errno_copy); + } + } + + if (fd) { + fd_ = fd; + } else { + // open file if the input fd is 0 + openLogFile(); + } +} + + +TFileTransport::~TFileTransport() { + // flush the buffer if a writer thread is active + if (writerThreadId_ > 0) { + // reduce the flush timeout so that closing is quicker + setFlushMaxUs(300*1000); + + // flush output buffer + flush(); + + // set state to closing + closing_ = true; + + // TODO: make sure event queue is empty + // currently only the write buffer is flushed + // we dont actually wait until the queue is empty. 
    // This shouldn't be a big
    // deal in the common case because writing is quick

    pthread_join(writerThreadId_, NULL);
    writerThreadId_ = 0;
  }

  if (dequeueBuffer_) {
    delete dequeueBuffer_;
    dequeueBuffer_ = NULL;
  }

  if (enqueueBuffer_) {
    delete enqueueBuffer_;
    enqueueBuffer_ = NULL;
  }

  if (readBuff_) {
    delete[] readBuff_;
    readBuff_ = NULL;
  }

  if (currentEvent_) {
    delete currentEvent_;
    currentEvent_ = NULL;
  }

  // close logfile
  if (fd_ > 0) {
    if(-1 == ::close(fd_)) {
      GlobalOutput.perror("TFileTransport: ~TFileTransport() ::close() ", errno);
    }
  }
}

/**
 * Starts the writer thread (if none is running) and allocates the two
 * event buffers.
 *
 * @return false if this was already done or the thread could not be
 *         created; true otherwise.
 */
bool TFileTransport::initBufferAndWriteThread() {
  if (bufferAndThreadInitialized_) {
    T_ERROR("Trying to double-init TFileTransport");
    return false;
  }

  if (writerThreadId_ == 0) {
    if (pthread_create(&writerThreadId_, NULL, startWriterThread, (void *)this) != 0) {
      T_ERROR("Could not create writer thread");
      return false;
    }
  }

  dequeueBuffer_ = new TFileTransportBuffer(eventBufferSize_);
  enqueueBuffer_ = new TFileTransportBuffer(eventBufferSize_);
  bufferAndThreadInitialized_ = true;

  return true;
}

/**
 * Queues buf as one event for the writer thread, without waiting for it to
 * be flushed.
 *
 * @throws TTransportException if the file was opened read-only.
 */
void TFileTransport::write(const uint8_t* buf, uint32_t len) {
  if (readOnly_) {
    throw TTransportException("TFileTransport: attempting to write to file opened readonly");
  }

  enqueueEvent(buf, len, false);
}

/**
 * Copies buf into a new event, prefixed with its 4-byte length, and adds it
 * to the enqueue buffer.
 *
 * The event is silently dropped when the transport is closing, when it is
 * empty, or when it exceeds maxEventSize_.
 *
 * @param blockUntilFlush  If true, wait on the flushed_ condition before
 *                         returning.
 */
void TFileTransport::enqueueEvent(const uint8_t* buf, uint32_t eventLen, bool blockUntilFlush) {
  // can't enqueue more events if file is going to close
  if (closing_) {
    return;
  }

  // make sure that event size is valid
  if ( (maxEventSize_ > 0) && (eventLen > maxEventSize_) ) {
    T_ERROR("msg size is greater than max event size: %u > %u\n", eventLen, maxEventSize_);
    return;
  }

  if (eventLen == 0) {
    T_ERROR("cannot enqueue an empty event");
    return;
  }

  eventInfo* toEnqueue = new eventInfo();
  // NOTE(review): the malloc result is not checked before it is written to.
  // readEvent() allocates eventBuff_ with new[] while this path uses
  // malloc -- confirm that eventInfo releases both correctly.
  toEnqueue->eventBuff_ = (uint8_t *)std::malloc((sizeof(uint8_t) * eventLen) + 4);
  // first 4 bytes is the event
length + memcpy(toEnqueue->eventBuff_, (void*)(&eventLen), 4); + // actual event contents + memcpy(toEnqueue->eventBuff_ + 4, buf, eventLen); + toEnqueue->eventSize_ = eventLen + 4; + + // lock mutex + pthread_mutex_lock(&mutex_); + + // make sure that enqueue buffer is initialized and writer thread is running + if (!bufferAndThreadInitialized_) { + if (!initBufferAndWriteThread()) { + delete toEnqueue; + pthread_mutex_unlock(&mutex_); + return; + } + } + + // Can't enqueue while buffer is full + while (enqueueBuffer_->isFull()) { + pthread_cond_wait(¬Full_, &mutex_); + } + + // add to the buffer + if (!enqueueBuffer_->addEvent(toEnqueue)) { + delete toEnqueue; + pthread_mutex_unlock(&mutex_); + return; + } + + // signal anybody who's waiting for the buffer to be non-empty + pthread_cond_signal(¬Empty_); + + if (blockUntilFlush) { + pthread_cond_wait(&flushed_, &mutex_); + } + + // this really should be a loop where it makes sure it got flushed + // because condition variables can get triggered by the os for no reason + // it is probably a non-factor for the time being + pthread_mutex_unlock(&mutex_); +} + +bool TFileTransport::swapEventBuffers(struct timespec* deadline) { + pthread_mutex_lock(&mutex_); + if (deadline != NULL) { + // if we were handed a deadline time struct, do a timed wait + pthread_cond_timedwait(¬Empty_, &mutex_, deadline); + } else { + // just wait until the buffer gets an item + pthread_cond_wait(¬Empty_, &mutex_); + } + + bool swapped = false; + + // could be empty if we timed out + if (!enqueueBuffer_->isEmpty()) { + TFileTransportBuffer *temp = enqueueBuffer_; + enqueueBuffer_ = dequeueBuffer_; + dequeueBuffer_ = temp; + + swapped = true; + } + + // unlock the mutex and signal if required + pthread_mutex_unlock(&mutex_); + + if (swapped) { + pthread_cond_signal(¬Full_); + } + + return swapped; +} + + +void TFileTransport::writerThread() { + // open file if it is not open + if(!fd_) { + openLogFile(); + } + + // set the offset to the correct 
value (EOF) + try { + seekToEnd(); + } catch (TException &te) { + } + + // throw away any partial events + offset_ += readState_.lastDispatchPtr_; + ftruncate(fd_, offset_); + readState_.resetAllValues(); + + // Figure out the next time by which a flush must take place + + struct timespec ts_next_flush; + getNextFlushTime(&ts_next_flush); + uint32_t unflushed = 0; + + while(1) { + // this will only be true when the destructor is being invoked + if(closing_) { + // empty out both the buffers + if (enqueueBuffer_->isEmpty() && dequeueBuffer_->isEmpty()) { + if (-1 == ::close(fd_)) { + int errno_copy = errno; + GlobalOutput.perror("TFileTransport: writerThread() ::close() ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "TFileTransport: error in file close", errno_copy); + } + // just be safe and sync to disk + fsync(fd_); + fd_ = 0; + pthread_exit(NULL); + return; + } + } + + if (swapEventBuffers(&ts_next_flush)) { + eventInfo* outEvent; + while (NULL != (outEvent = dequeueBuffer_->getNext())) { + if (!outEvent) { + T_DEBUG_L(1, "Got an empty event"); + return; + } + + // sanity check on event + if ((maxEventSize_ > 0) && (outEvent->eventSize_ > maxEventSize_)) { + T_ERROR("msg size is greater than max event size: %u > %u\n", outEvent->eventSize_, maxEventSize_); + continue; + } + + // If chunking is required, then make sure that msg does not cross chunk boundary + if ((outEvent->eventSize_ > 0) && (chunkSize_ != 0)) { + + // event size must be less than chunk size + if(outEvent->eventSize_ > chunkSize_) { + T_ERROR("TFileTransport: event size(%u) is greater than chunk size(%u): skipping event", + outEvent->eventSize_, chunkSize_); + continue; + } + + int64_t chunk1 = offset_/chunkSize_; + int64_t chunk2 = (offset_ + outEvent->eventSize_ - 1)/chunkSize_; + + // if adding this event will cross a chunk boundary, pad the chunk with zeros + if (chunk1 != chunk2) { + // refetch the offset to keep in sync + offset_ = lseek(fd_, 0, SEEK_CUR); + 
int32_t padding = (int32_t)((offset_/chunkSize_ + 1)*chunkSize_ - offset_); + + uint8_t zeros[padding]; + bzero(zeros, padding); + if (-1 == ::write(fd_, zeros, padding)) { + int errno_copy = errno; + GlobalOutput.perror("TFileTransport: writerThread() error while padding zeros ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "TFileTransport: error while padding zeros", errno_copy); + } + unflushed += padding; + offset_ += padding; + } + } + + // write the dequeued event to the file + if (outEvent->eventSize_ > 0) { + if (-1 == ::write(fd_, outEvent->eventBuff_, outEvent->eventSize_)) { + int errno_copy = errno; + GlobalOutput.perror("TFileTransport: error while writing event ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "TFileTransport: error while writing event", errno_copy); + } + + unflushed += outEvent->eventSize_; + offset_ += outEvent->eventSize_; + } + } + dequeueBuffer_->reset(); + } + + bool flushTimeElapsed = false; + struct timespec current_time; + clock_gettime(CLOCK_REALTIME, ¤t_time); + + if (current_time.tv_sec > ts_next_flush.tv_sec || + (current_time.tv_sec == ts_next_flush.tv_sec && current_time.tv_nsec > ts_next_flush.tv_nsec)) { + flushTimeElapsed = true; + getNextFlushTime(&ts_next_flush); + } + + // couple of cases from which a flush could be triggered + if ((flushTimeElapsed && unflushed > 0) || + unflushed > flushMaxBytes_ || + forceFlush_) { + + // sync (force flush) file to disk + fsync(fd_); + unflushed = 0; + + // notify anybody waiting for flush completion + forceFlush_ = false; + pthread_cond_broadcast(&flushed_); + } + } +} + +void TFileTransport::flush() { + // file must be open for writing for any flushing to take place + if (writerThreadId_ <= 0) { + return; + } + // wait for flush to take place + pthread_mutex_lock(&mutex_); + + forceFlush_ = true; + + while (forceFlush_) { + pthread_cond_wait(&flushed_, &mutex_); + } + + pthread_mutex_unlock(&mutex_); +} + + +uint32_t 
TFileTransport::readAll(uint8_t* buf, uint32_t len) { + uint32_t have = 0; + uint32_t get = 0; + + while (have < len) { + get = read(buf+have, len-have); + if (get <= 0) { + throw TEOFException(); + } + have += get; + } + + return have; +} + +uint32_t TFileTransport::read(uint8_t* buf, uint32_t len) { + // check if there an event is ready to be read + if (!currentEvent_) { + currentEvent_ = readEvent(); + } + + // did not manage to read an event from the file. This could have happened + // if the timeout expired or there was some other error + if (!currentEvent_) { + return 0; + } + + // read as much of the current event as possible + int32_t remaining = currentEvent_->eventSize_ - currentEvent_->eventBuffPos_; + if (remaining <= (int32_t)len) { + // copy over anything thats remaining + if (remaining > 0) { + memcpy(buf, + currentEvent_->eventBuff_ + currentEvent_->eventBuffPos_, + remaining); + } + delete(currentEvent_); + currentEvent_ = NULL; + return remaining; + } + + // read as much as possible + memcpy(buf, currentEvent_->eventBuff_ + currentEvent_->eventBuffPos_, len); + currentEvent_->eventBuffPos_ += len; + return len; +} + +eventInfo* TFileTransport::readEvent() { + int readTries = 0; + + if (!readBuff_) { + readBuff_ = new uint8_t[readBuffSize_]; + } + + while (1) { + // read from the file if read buffer is exhausted + if (readState_.bufferPtr_ == readState_.bufferLen_) { + // advance the offset pointer + offset_ += readState_.bufferLen_; + readState_.bufferLen_ = ::read(fd_, readBuff_, readBuffSize_); + // if (readState_.bufferLen_) { + // T_DEBUG_L(1, "Amount read: %u (offset: %lu)", readState_.bufferLen_, offset_); + // } + readState_.bufferPtr_ = 0; + readState_.lastDispatchPtr_ = 0; + + // read error + if (readState_.bufferLen_ == -1) { + readState_.resetAllValues(); + GlobalOutput("TFileTransport: error while reading from file"); + throw TTransportException("TFileTransport: error while reading from file"); + } else if (readState_.bufferLen_ == 0) 
{ // EOF + // wait indefinitely if there is no timeout + if (readTimeout_ == TAIL_READ_TIMEOUT) { + usleep(eofSleepTime_); + continue; + } else if (readTimeout_ == NO_TAIL_READ_TIMEOUT) { + // reset state + readState_.resetState(0); + return NULL; + } else if (readTimeout_ > 0) { + // timeout already expired once + if (readTries > 0) { + readState_.resetState(0); + return NULL; + } else { + usleep(readTimeout_ * 1000); + readTries++; + continue; + } + } + } + } + + readTries = 0; + + // attempt to read an event from the buffer + while(readState_.bufferPtr_ < readState_.bufferLen_) { + if (readState_.readingSize_) { + if(readState_.eventSizeBuffPos_ == 0) { + if ( (offset_ + readState_.bufferPtr_)/chunkSize_ != + ((offset_ + readState_.bufferPtr_ + 3)/chunkSize_)) { + // skip one byte towards chunk boundary + // T_DEBUG_L(1, "Skipping a byte"); + readState_.bufferPtr_++; + continue; + } + } + + readState_.eventSizeBuff_[readState_.eventSizeBuffPos_++] = + readBuff_[readState_.bufferPtr_++]; + if (readState_.eventSizeBuffPos_ == 4) { + // 0 length event indicates padding + if (*((uint32_t *)(readState_.eventSizeBuff_)) == 0) { + // T_DEBUG_L(1, "Got padding"); + readState_.resetState(readState_.lastDispatchPtr_); + continue; + } + // got a valid event + readState_.readingSize_ = false; + if (readState_.event_) { + delete(readState_.event_); + } + readState_.event_ = new eventInfo(); + readState_.event_->eventSize_ = *((uint32_t *)(readState_.eventSizeBuff_)); + + // check if the event is corrupted and perform recovery if required + if (isEventCorrupted()) { + performRecovery(); + // start from the top + break; + } + } + } else { + if (!readState_.event_->eventBuff_) { + readState_.event_->eventBuff_ = new uint8_t[readState_.event_->eventSize_]; + readState_.event_->eventBuffPos_ = 0; + } + // take either the entire event or the remaining bytes in the buffer + int reclaimBuffer = min((uint32_t)(readState_.bufferLen_ - readState_.bufferPtr_), + 
readState_.event_->eventSize_ - readState_.event_->eventBuffPos_); + + // copy data from read buffer into event buffer + memcpy(readState_.event_->eventBuff_ + readState_.event_->eventBuffPos_, + readBuff_ + readState_.bufferPtr_, + reclaimBuffer); + + // increment position ptrs + readState_.event_->eventBuffPos_ += reclaimBuffer; + readState_.bufferPtr_ += reclaimBuffer; + + // check if the event has been read in full + if (readState_.event_->eventBuffPos_ == readState_.event_->eventSize_) { + // set the completed event to the current event + eventInfo* completeEvent = readState_.event_; + completeEvent->eventBuffPos_ = 0; + + readState_.event_ = NULL; + readState_.resetState(readState_.bufferPtr_); + + // exit criteria + return completeEvent; + } + } + } + + } +} + +bool TFileTransport::isEventCorrupted() { + // an error is triggered if: + if ( (maxEventSize_ > 0) && (readState_.event_->eventSize_ > maxEventSize_)) { + // 1. Event size is larger than user-speficied max-event size + T_ERROR("Read corrupt event. Event size(%u) greater than max event size (%u)", + readState_.event_->eventSize_, maxEventSize_); + return true; + } else if (readState_.event_->eventSize_ > chunkSize_) { + // 2. Event size is larger than chunk size + T_ERROR("Read corrupt event. Event size(%u) greater than chunk size (%u)", + readState_.event_->eventSize_, chunkSize_); + return true; + } else if( ((offset_ + readState_.bufferPtr_ - 4)/chunkSize_) != + ((offset_ + readState_.bufferPtr_ + readState_.event_->eventSize_ - 1)/chunkSize_) ) { + // 3. size indicates that event crosses chunk boundary + T_ERROR("Read corrupt event. Event crosses chunk boundary. 
Event size:%u Offset:%ld", + readState_.event_->eventSize_, offset_ + readState_.bufferPtr_ + 4); + return true; + } + + return false; +} + +void TFileTransport::performRecovery() { + // perform some kickass recovery + uint32_t curChunk = getCurChunk(); + if (lastBadChunk_ == curChunk) { + numCorruptedEventsInChunk_++; + } else { + lastBadChunk_ = curChunk; + numCorruptedEventsInChunk_ = 1; + } + + if (numCorruptedEventsInChunk_ < maxCorruptedEvents_) { + // maybe there was an error in reading the file from disk + // seek to the beginning of chunk and try again + seekToChunk(curChunk); + } else { + + // just skip ahead to the next chunk if we not already at the last chunk + if (curChunk != (getNumChunks() - 1)) { + seekToChunk(curChunk + 1); + } else if (readTimeout_ == TAIL_READ_TIMEOUT) { + // if tailing the file, wait until there is enough data to start + // the next chunk + while(curChunk == (getNumChunks() - 1)) { + usleep(DEFAULT_CORRUPTED_SLEEP_TIME_US); + } + seekToChunk(curChunk + 1); + } else { + // pretty hosed at this stage, rewind the file back to the last successful + // point and punt on the error + readState_.resetState(readState_.lastDispatchPtr_); + currentEvent_ = NULL; + char errorMsg[1024]; + sprintf(errorMsg, "TFileTransport: log file corrupted at offset: %lu", + offset_ + readState_.lastDispatchPtr_); + GlobalOutput(errorMsg); + throw TTransportException(errorMsg); + } + } + +} + +void TFileTransport::seekToChunk(int32_t chunk) { + if (fd_ <= 0) { + throw TTransportException("File not open"); + } + + int32_t numChunks = getNumChunks(); + + // file is empty, seeking to chunk is pointless + if (numChunks == 0) { + return; + } + + // negative indicates reverse seek (from the end) + if (chunk < 0) { + chunk += numChunks; + } + + // too large a value for reverse seek, just seek to beginning + if (chunk < 0) { + T_DEBUG("Incorrect value for reverse seek. 
Seeking to beginning...", chunk) + chunk = 0; + } + + // cannot seek past EOF + bool seekToEnd = false; + uint32_t minEndOffset = 0; + if (chunk >= numChunks) { + T_DEBUG("Trying to seek past EOF. Seeking to EOF instead..."); + seekToEnd = true; + chunk = numChunks - 1; + // this is the min offset to process events till + minEndOffset = lseek(fd_, 0, SEEK_END); + } + + off_t newOffset = off_t(chunk) * chunkSize_; + offset_ = lseek(fd_, newOffset, SEEK_SET); + readState_.resetAllValues(); + currentEvent_ = NULL; + if (offset_ == -1) { + GlobalOutput("TFileTransport: lseek error in seekToChunk"); + throw TTransportException("TFileTransport: lseek error in seekToChunk"); + } + + // seek to EOF if user wanted to go to last chunk + if (seekToEnd) { + uint32_t oldReadTimeout = getReadTimeout(); + setReadTimeout(NO_TAIL_READ_TIMEOUT); + // keep on reading unti the last event at point of seekChunk call + while (readEvent() && ((offset_ + readState_.bufferPtr_) < minEndOffset)) {}; + setReadTimeout(oldReadTimeout); + } + +} + +void TFileTransport::seekToEnd() { + seekToChunk(getNumChunks()); +} + +uint32_t TFileTransport::getNumChunks() { + if (fd_ <= 0) { + return 0; + } + + struct stat f_info; + int rv = fstat(fd_, &f_info); + + if (rv < 0) { + int errno_copy = errno; + throw TTransportException(TTransportException::UNKNOWN, + "TFileTransport::getNumChunks() (fstat)", + errno_copy); + } + + if (f_info.st_size > 0) { + return ((f_info.st_size)/chunkSize_) + 1; + } + + // empty file has no chunks + return 0; +} + +uint32_t TFileTransport::getCurChunk() { + return offset_/chunkSize_; +} + +// Utility Functions +void TFileTransport::openLogFile() { + mode_t mode = readOnly_ ? S_IRUSR | S_IRGRP | S_IROTH : S_IRUSR | S_IWUSR| S_IRGRP | S_IROTH; + int flags = readOnly_ ? 
O_RDONLY : O_RDWR | O_CREAT | O_APPEND; + fd_ = ::open(filename_.c_str(), flags, mode); + offset_ = 0; + + // make sure open call was successful + if(fd_ == -1) { + int errno_copy = errno; + GlobalOutput.perror("TFileTransport: openLogFile() ::open() file: " + filename_, errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, filename_, errno_copy); + } + +} + +void TFileTransport::getNextFlushTime(struct timespec* ts_next_flush) { + clock_gettime(CLOCK_REALTIME, ts_next_flush); + ts_next_flush->tv_nsec += (flushMaxUs_ % 1000000) * 1000; + if (ts_next_flush->tv_nsec > 1000000000) { + ts_next_flush->tv_nsec -= 1000000000; + ts_next_flush->tv_sec += 1; + } + ts_next_flush->tv_sec += flushMaxUs_ / 1000000; +} + +TFileTransportBuffer::TFileTransportBuffer(uint32_t size) + : bufferMode_(WRITE) + , writePoint_(0) + , readPoint_(0) + , size_(size) +{ + buffer_ = new eventInfo*[size]; +} + +TFileTransportBuffer::~TFileTransportBuffer() { + if (buffer_) { + for (uint32_t i = 0; i < writePoint_; i++) { + delete buffer_[i]; + } + delete[] buffer_; + buffer_ = NULL; + } +} + +bool TFileTransportBuffer::addEvent(eventInfo *event) { + if (bufferMode_ == READ) { + GlobalOutput("Trying to write to a buffer in read mode"); + } + if (writePoint_ < size_) { + buffer_[writePoint_++] = event; + return true; + } else { + // buffer is full + return false; + } +} + +eventInfo* TFileTransportBuffer::getNext() { + if (bufferMode_ == WRITE) { + bufferMode_ = READ; + } + if (readPoint_ < writePoint_) { + return buffer_[readPoint_++]; + } else { + // no more entries + return NULL; + } +} + +void TFileTransportBuffer::reset() { + if (bufferMode_ == WRITE || writePoint_ > readPoint_) { + T_DEBUG("Resetting a buffer with unread entries"); + } + // Clean up the old entries + for (uint32_t i = 0; i < writePoint_; i++) { + delete buffer_[i]; + } + bufferMode_ = WRITE; + writePoint_ = 0; + readPoint_ = 0; +} + +bool TFileTransportBuffer::isFull() { + return writePoint_ == size_; +} + 
+bool TFileTransportBuffer::isEmpty() { + return writePoint_ == 0; +} + +TFileProcessor::TFileProcessor(shared_ptr<TProcessor> processor, + shared_ptr<TProtocolFactory> protocolFactory, + shared_ptr<TFileReaderTransport> inputTransport): + processor_(processor), + inputProtocolFactory_(protocolFactory), + outputProtocolFactory_(protocolFactory), + inputTransport_(inputTransport) { + + // default the output transport to a null transport (common case) + outputTransport_ = shared_ptr<TNullTransport>(new TNullTransport()); +} + +TFileProcessor::TFileProcessor(shared_ptr<TProcessor> processor, + shared_ptr<TProtocolFactory> inputProtocolFactory, + shared_ptr<TProtocolFactory> outputProtocolFactory, + shared_ptr<TFileReaderTransport> inputTransport): + processor_(processor), + inputProtocolFactory_(inputProtocolFactory), + outputProtocolFactory_(outputProtocolFactory), + inputTransport_(inputTransport) { + + // default the output transport to a null transport (common case) + outputTransport_ = shared_ptr<TNullTransport>(new TNullTransport()); +} + +TFileProcessor::TFileProcessor(shared_ptr<TProcessor> processor, + shared_ptr<TProtocolFactory> protocolFactory, + shared_ptr<TFileReaderTransport> inputTransport, + shared_ptr<TTransport> outputTransport): + processor_(processor), + inputProtocolFactory_(protocolFactory), + outputProtocolFactory_(protocolFactory), + inputTransport_(inputTransport), + outputTransport_(outputTransport) {}; + +void TFileProcessor::process(uint32_t numEvents, bool tail) { + shared_ptr<TProtocol> inputProtocol = inputProtocolFactory_->getProtocol(inputTransport_); + shared_ptr<TProtocol> outputProtocol = outputProtocolFactory_->getProtocol(outputTransport_); + + // set the read timeout to 0 if tailing is required + int32_t oldReadTimeout = inputTransport_->getReadTimeout(); + if (tail) { + // save old read timeout so it can be restored + inputTransport_->setReadTimeout(TFileTransport::TAIL_READ_TIMEOUT); + } + + uint32_t numProcessed = 0; + 
while(1) { + // bad form to use exceptions for flow control but there is really + // no other way around it + try { + processor_->process(inputProtocol, outputProtocol); + numProcessed++; + if ( (numEvents > 0) && (numProcessed == numEvents)) { + return; + } + } catch (TEOFException& teof) { + if (!tail) { + break; + } + } catch (TException &te) { + cerr << te.what() << endl; + break; + } + } + + // restore old read timeout + if (tail) { + inputTransport_->setReadTimeout(oldReadTimeout); + } + +} + +void TFileProcessor::processChunk() { + shared_ptr<TProtocol> inputProtocol = inputProtocolFactory_->getProtocol(inputTransport_); + shared_ptr<TProtocol> outputProtocol = outputProtocolFactory_->getProtocol(outputTransport_); + + uint32_t curChunk = inputTransport_->getCurChunk(); + + while(1) { + // bad form to use exceptions for flow control but there is really + // no other way around it + try { + processor_->process(inputProtocol, outputProtocol); + if (curChunk != inputTransport_->getCurChunk()) { + break; + } + } catch (TEOFException& teof) { + break; + } catch (TException &te) { + cerr << te.what() << endl; + break; + } + } +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TFileTransport.h b/lib/cpp/src/transport/TFileTransport.h new file mode 100644 index 000000000..fbaf2cd0d --- /dev/null +++ b/lib/cpp/src/transport/TFileTransport.h @@ -0,0 +1,440 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TFILETRANSPORT_H_ +#define _THRIFT_TRANSPORT_TFILETRANSPORT_H_ 1 + +#include "TTransport.h" +#include "Thrift.h" +#include "TProcessor.h" + +#include <string> +#include <stdio.h> + +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace transport { + +using apache::thrift::TProcessor; +using apache::thrift::protocol::TProtocolFactory; + +// Data pertaining to a single event +typedef struct eventInfo { + uint8_t* eventBuff_; + uint32_t eventSize_; + uint32_t eventBuffPos_; + + eventInfo():eventBuff_(NULL), eventSize_(0), eventBuffPos_(0){}; + ~eventInfo() { + if (eventBuff_) { + delete[] eventBuff_; + } + } +} eventInfo; + +// information about current read state +typedef struct readState { + eventInfo* event_; + + // keep track of event size + uint8_t eventSizeBuff_[4]; + uint8_t eventSizeBuffPos_; + bool readingSize_; + + // read buffer variables + int32_t bufferPtr_; + int32_t bufferLen_; + + // last successful dispatch point + int32_t lastDispatchPtr_; + + void resetState(uint32_t lastDispatchPtr) { + readingSize_ = true; + eventSizeBuffPos_ = 0; + lastDispatchPtr_ = lastDispatchPtr; + } + + void resetAllValues() { + resetState(0); + bufferPtr_ = 0; + bufferLen_ = 0; + if (event_) { + delete(event_); + } + event_ = 0; + } + + readState() { + event_ = 0; + resetAllValues(); + } + + ~readState() { + if (event_) { + delete(event_); + } + } + +} readState; + +/** + * TFileTransportBuffer - buffer class used by TFileTransport for queueing up events + * to be written to disk. 
Should be used in the following way: + * 1) Buffer created + * 2) Buffer written to (addEvent) + * 3) Buffer read from (getNext) + * 4) Buffer reset (reset) + * 5) Go back to 2, or destroy buffer + * + * The buffer should never be written to after it is read from, unless it is reset first. + * Note: The above rules are enforced mainly for debugging its sole client TFileTransport + * which uses the buffer in this way. + * + */ +class TFileTransportBuffer { + public: + TFileTransportBuffer(uint32_t size); + ~TFileTransportBuffer(); + + bool addEvent(eventInfo *event); + eventInfo* getNext(); + void reset(); + bool isFull(); + bool isEmpty(); + + private: + TFileTransportBuffer(); // should not be used + + enum mode { + WRITE, + READ + }; + mode bufferMode_; + + uint32_t writePoint_; + uint32_t readPoint_; + uint32_t size_; + eventInfo** buffer_; +}; + +/** + * Abstract interface for transports used to read files + */ +class TFileReaderTransport : virtual public TTransport { + public: + virtual int32_t getReadTimeout() = 0; + virtual void setReadTimeout(int32_t readTimeout) = 0; + + virtual uint32_t getNumChunks() = 0; + virtual uint32_t getCurChunk() = 0; + virtual void seekToChunk(int32_t chunk) = 0; + virtual void seekToEnd() = 0; +}; + +/** + * Abstract interface for transports used to write files + */ +class TFileWriterTransport : virtual public TTransport { + public: + virtual uint32_t getChunkSize() = 0; + virtual void setChunkSize(uint32_t chunkSize) = 0; +}; + +/** + * File implementation of a transport. Reads and writes are done to a + * file on disk. + * + */ +class TFileTransport : public TFileReaderTransport, + public TFileWriterTransport { + public: + TFileTransport(std::string path, bool readOnly=false); + ~TFileTransport(); + + // TODO: what is the correct behaviour for this? 
+ // the log file is generally always open + bool isOpen() { + return true; + } + + void write(const uint8_t* buf, uint32_t len); + void flush(); + + uint32_t readAll(uint8_t* buf, uint32_t len); + uint32_t read(uint8_t* buf, uint32_t len); + + // log-file specific functions + void seekToChunk(int32_t chunk); + void seekToEnd(); + uint32_t getNumChunks(); + uint32_t getCurChunk(); + + // for changing the output file + void resetOutputFile(int fd, std::string filename, int64_t offset); + + // Setter/Getter functions for user-controllable options + void setReadBuffSize(uint32_t readBuffSize) { + if (readBuffSize) { + readBuffSize_ = readBuffSize; + } + } + uint32_t getReadBuffSize() { + return readBuffSize_; + } + + static const int32_t TAIL_READ_TIMEOUT = -1; + static const int32_t NO_TAIL_READ_TIMEOUT = 0; + void setReadTimeout(int32_t readTimeout) { + readTimeout_ = readTimeout; + } + int32_t getReadTimeout() { + return readTimeout_; + } + + void setChunkSize(uint32_t chunkSize) { + if (chunkSize) { + chunkSize_ = chunkSize; + } + } + uint32_t getChunkSize() { + return chunkSize_; + } + + void setEventBufferSize(uint32_t bufferSize) { + if (bufferAndThreadInitialized_) { + GlobalOutput("Cannot change the buffer size after writer thread started"); + return; + } + eventBufferSize_ = bufferSize; + } + + uint32_t getEventBufferSize() { + return eventBufferSize_; + } + + void setFlushMaxUs(uint32_t flushMaxUs) { + if (flushMaxUs) { + flushMaxUs_ = flushMaxUs; + } + } + uint32_t getFlushMaxUs() { + return flushMaxUs_; + } + + void setFlushMaxBytes(uint32_t flushMaxBytes) { + if (flushMaxBytes) { + flushMaxBytes_ = flushMaxBytes; + } + } + uint32_t getFlushMaxBytes() { + return flushMaxBytes_; + } + + void setMaxEventSize(uint32_t maxEventSize) { + maxEventSize_ = maxEventSize; + } + uint32_t getMaxEventSize() { + return maxEventSize_; + } + + void setMaxCorruptedEvents(uint32_t maxCorruptedEvents) { + maxCorruptedEvents_ = maxCorruptedEvents; + } + uint32_t 
getMaxCorruptedEvents() { + return maxCorruptedEvents_; + } + + void setEofSleepTimeUs(uint32_t eofSleepTime) { + if (eofSleepTime) { + eofSleepTime_ = eofSleepTime; + } + } + uint32_t getEofSleepTimeUs() { + return eofSleepTime_; + } + + private: + // helper functions for writing to a file + void enqueueEvent(const uint8_t* buf, uint32_t eventLen, bool blockUntilFlush); + bool swapEventBuffers(struct timespec* deadline); + bool initBufferAndWriteThread(); + + // control for writer thread + static void* startWriterThread(void* ptr) { + (((TFileTransport*)ptr)->writerThread()); + return 0; + } + void writerThread(); + + // helper functions for reading from a file + eventInfo* readEvent(); + + // event corruption-related functions + bool isEventCorrupted(); + void performRecovery(); + + // Utility functions + void openLogFile(); + void getNextFlushTime(struct timespec* ts_next_flush); + + // Class variables + readState readState_; + uint8_t* readBuff_; + eventInfo* currentEvent_; + + uint32_t readBuffSize_; + static const uint32_t DEFAULT_READ_BUFF_SIZE = 1 * 1024 * 1024; + + int32_t readTimeout_; + static const int32_t DEFAULT_READ_TIMEOUT_MS = 200; + + // size of chunks that file will be split up into + uint32_t chunkSize_; + static const uint32_t DEFAULT_CHUNK_SIZE = 16 * 1024 * 1024; + + // size of event buffers + uint32_t eventBufferSize_; + static const uint32_t DEFAULT_EVENT_BUFFER_SIZE = 10000; + + // max number of microseconds that can pass without flushing + uint32_t flushMaxUs_; + static const uint32_t DEFAULT_FLUSH_MAX_US = 3000000; + + // max number of bytes that can be written without flushing + uint32_t flushMaxBytes_; + static const uint32_t DEFAULT_FLUSH_MAX_BYTES = 1000 * 1024; + + // max event size + uint32_t maxEventSize_; + static const uint32_t DEFAULT_MAX_EVENT_SIZE = 0; + + // max number of corrupted events per chunk + uint32_t maxCorruptedEvents_; + static const uint32_t DEFAULT_MAX_CORRUPTED_EVENTS = 0; + + // sleep duration when EOF is hit 
+ uint32_t eofSleepTime_; + static const uint32_t DEFAULT_EOF_SLEEP_TIME_US = 500 * 1000; + + // sleep duration when a corrupted event is encountered + uint32_t corruptedEventSleepTime_; + static const uint32_t DEFAULT_CORRUPTED_SLEEP_TIME_US = 1 * 1000 * 1000; + + // writer thread id + pthread_t writerThreadId_; + + // buffers to hold data before it is flushed. Each element of the buffer stores a msg that + // needs to be written to the file. The buffers are swapped by the writer thread. + TFileTransportBuffer *dequeueBuffer_; + TFileTransportBuffer *enqueueBuffer_; + + // conditions used to block when the buffer is full or empty + pthread_cond_t notFull_, notEmpty_; + volatile bool closing_; + + // To keep track of whether the buffer has been flushed + pthread_cond_t flushed_; + volatile bool forceFlush_; + + // Mutex that is grabbed when enqueueing and swapping the read/write buffers + pthread_mutex_t mutex_; + + // File information + std::string filename_; + int fd_; + + // Whether the writer thread and buffers have been initialized + bool bufferAndThreadInitialized_; + + // Offset within the file + off_t offset_; + + // event corruption information + uint32_t lastBadChunk_; + uint32_t numCorruptedEventsInChunk_; + + bool readOnly_; +}; + +// Exception thrown when EOF is hit +class TEOFException : public TTransportException { + public: + TEOFException(): + TTransportException(TTransportException::END_OF_FILE) {}; +}; + + +// wrapper class to process events from a file containing thrift events +class TFileProcessor { + public: + /** + * Constructor that defaults output transport to null transport + * + * @param processor processes log-file events + * @param protocolFactory protocol factory + * @param inputTransport file transport + */ + TFileProcessor(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TProtocolFactory> protocolFactory, + boost::shared_ptr<TFileReaderTransport> inputTransport); + + TFileProcessor(boost::shared_ptr<TProcessor> processor, 
+ boost::shared_ptr<TProtocolFactory> inputProtocolFactory, + boost::shared_ptr<TProtocolFactory> outputProtocolFactory, + boost::shared_ptr<TFileReaderTransport> inputTransport); + + /** + * Constructor + * + * @param processor processes log-file events + * @param protocolFactory protocol factory + * @param inputTransport input file transport + * @param output output transport + */ + TFileProcessor(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TProtocolFactory> protocolFactory, + boost::shared_ptr<TFileReaderTransport> inputTransport, + boost::shared_ptr<TTransport> outputTransport); + + /** + * processes events from the file + * + * @param numEvents number of events to process (0 for unlimited) + * @param tail tails the file if true + */ + void process(uint32_t numEvents, bool tail); + + /** + * process events until the end of the chunk + * + */ + void processChunk(); + + private: + boost::shared_ptr<TProcessor> processor_; + boost::shared_ptr<TProtocolFactory> inputProtocolFactory_; + boost::shared_ptr<TProtocolFactory> outputProtocolFactory_; + boost::shared_ptr<TFileReaderTransport> inputTransport_; + boost::shared_ptr<TTransport> outputTransport_; +}; + + +}}} // apache::thrift::transport + +#endif // _THRIFT_TRANSPORT_TFILETRANSPORT_H_ diff --git a/lib/cpp/src/transport/THttpClient.cpp b/lib/cpp/src/transport/THttpClient.cpp new file mode 100644 index 000000000..59f233968 --- /dev/null +++ b/lib/cpp/src/transport/THttpClient.cpp @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <cstdlib> +#include <sstream> + +#include "THttpClient.h" +#include "TSocket.h" + +namespace apache { namespace thrift { namespace transport { + +using namespace std; + +/** + * Http client implementation. + * + */ + +// Yeah, yeah, hacky to put these here, I know. +static const char* CRLF = "\r\n"; +static const int CRLF_LEN = 2; + +THttpClient::THttpClient(boost::shared_ptr<TTransport> transport, string host, string path) : + transport_(transport), + host_(host), + path_(path), + readHeaders_(true), + chunked_(false), + chunkedDone_(false), + chunkSize_(0), + contentLength_(0), + httpBuf_(NULL), + httpPos_(0), + httpBufLen_(0), + httpBufSize_(1024) { + init(); +} + +THttpClient::THttpClient(string host, int port, string path) : + host_(host), + path_(path), + readHeaders_(true), + chunked_(false), + chunkedDone_(false), + chunkSize_(0), + contentLength_(0), + httpBuf_(NULL), + httpPos_(0), + httpBufLen_(0), + httpBufSize_(1024) { + transport_ = boost::shared_ptr<TTransport>(new TSocket(host, port)); + init(); +} + +void THttpClient::init() { + httpBuf_ = (char*)std::malloc(httpBufSize_+1); + if (httpBuf_ == NULL) { + throw TTransportException("Out of memory."); + } + httpBuf_[httpBufLen_] = '\0'; +} + +THttpClient::~THttpClient() { + if (httpBuf_ != NULL) { + std::free(httpBuf_); + } +} + +uint32_t THttpClient::read(uint8_t* buf, uint32_t len) { + if (readBuffer_.available_read() == 0) { + readBuffer_.resetBuffer(); + uint32_t got = readMoreData(); + if (got == 0) { + return 0; + } + } + return readBuffer_.read(buf, len); +} + 
+void THttpClient::readEnd() { + // Read any pending chunked data (footers etc.) + if (chunked_) { + while (!chunkedDone_) { + readChunked(); + } + } +} + +uint32_t THttpClient::readMoreData() { + // Get more data! + refill(); + + if (readHeaders_) { + readHeaders(); + } + + if (chunked_) { + return readChunked(); + } else { + return readContent(contentLength_); + } +} + +uint32_t THttpClient::readChunked() { + uint32_t length = 0; + + char* line = readLine(); + uint32_t chunkSize = parseChunkSize(line); + if (chunkSize == 0) { + readChunkedFooters(); + } else { + // Read data content + length += readContent(chunkSize); + // Read trailing CRLF after content + readLine(); + } + return length; +} + +void THttpClient::readChunkedFooters() { + // End of data, read footer lines until a blank one appears + while (true) { + char* line = readLine(); + if (strlen(line) == 0) { + chunkedDone_ = true; + break; + } + } +} + +uint32_t THttpClient::parseChunkSize(char* line) { + char* semi = strchr(line, ';'); + if (semi != NULL) { + *semi = '\0'; + } + int size = 0; + sscanf(line, "%x", &size); + return (uint32_t)size; +} + +uint32_t THttpClient::readContent(uint32_t size) { + uint32_t need = size; + while (need > 0) { + uint32_t avail = httpBufLen_ - httpPos_; + if (avail == 0) { + // We have given all the data, reset position to head of the buffer + httpPos_ = 0; + httpBufLen_ = 0; + refill(); + + // Now have available however much we read + avail = httpBufLen_; + } + uint32_t give = avail; + if (need < give) { + give = need; + } + readBuffer_.write((uint8_t*)(httpBuf_+httpPos_), give); + httpPos_ += give; + need -= give; + } + return size; +} + +char* THttpClient::readLine() { + while (true) { + char* eol = NULL; + + eol = strstr(httpBuf_+httpPos_, CRLF); + + // No CRLF yet? 
+ if (eol == NULL) { + // Shift whatever we have now to front and refill + shift(); + refill(); + } else { + // Return pointer to next line + *eol = '\0'; + char* line = httpBuf_+httpPos_; + httpPos_ = (eol-httpBuf_) + CRLF_LEN; + return line; + } + } + +} + +void THttpClient::shift() { + if (httpBufLen_ > httpPos_) { + // Shift down remaining data and read more + uint32_t length = httpBufLen_ - httpPos_; + memmove(httpBuf_, httpBuf_+httpPos_, length); + httpBufLen_ = length; + } else { + httpBufLen_ = 0; + } + httpPos_ = 0; + httpBuf_[httpBufLen_] = '\0'; +} + +void THttpClient::refill() { + uint32_t avail = httpBufSize_ - httpBufLen_; + if (avail <= (httpBufSize_ / 4)) { + httpBufSize_ *= 2; + httpBuf_ = (char*)std::realloc(httpBuf_, httpBufSize_+1); + if (httpBuf_ == NULL) { + throw TTransportException("Out of memory."); + } + } + + // Read more data + uint32_t got = transport_->read((uint8_t*)(httpBuf_+httpBufLen_), httpBufSize_-httpBufLen_); + httpBufLen_ += got; + httpBuf_[httpBufLen_] = '\0'; + + if (got == 0) { + throw TTransportException("Could not refill buffer"); + } +} + +void THttpClient::readHeaders() { + // Initialize headers state variables + contentLength_ = 0; + chunked_ = false; + chunkedDone_ = false; + chunkSize_ = 0; + + // Control state flow + bool statusLine = true; + bool finished = false; + + // Loop until headers are finished + while (true) { + char* line = readLine(); + + if (strlen(line) == 0) { + if (finished) { + readHeaders_ = false; + return; + } else { + // Must have been an HTTP 100, keep going for another status line + statusLine = true; + } + } else { + if (statusLine) { + statusLine = false; + finished = parseStatusLine(line); + } else { + parseHeader(line); + } + } + } +} + +bool THttpClient::parseStatusLine(char* status) { + char* http = status; + + char* code = strchr(http, ' '); + if (code == NULL) { + throw TTransportException(string("Bad Status: ") + status); + } + + *code = '\0'; + while (*(code++) == ' '); + + char* msg 
= strchr(code, ' '); + if (msg == NULL) { + throw TTransportException(string("Bad Status: ") + status); + } + *msg = '\0'; + + if (strcmp(code, "200") == 0) { + // HTTP 200 = OK, we got the response + return true; + } else if (strcmp(code, "100") == 0) { + // HTTP 100 = continue, just keep reading + return false; + } else { + throw TTransportException(string("Bad Status: ") + status); + } +} + +void THttpClient::parseHeader(char* header) { + char* colon = strchr(header, ':'); + if (colon == NULL) { + return; + } + uint32_t sz = colon - header; + char* value = colon+1; + + if (strncmp(header, "Transfer-Encoding", sz) == 0) { + if (strstr(value, "chunked") != NULL) { + chunked_ = true; + } + } else if (strncmp(header, "Content-Length", sz) == 0) { + chunked_ = false; + contentLength_ = atoi(value); + } +} + +void THttpClient::write(const uint8_t* buf, uint32_t len) { + writeBuffer_.write(buf, len); +} + +void THttpClient::flush() { + // Fetch the contents of the write buffer + uint8_t* buf; + uint32_t len; + writeBuffer_.getBuffer(&buf, &len); + + // Construct the HTTP header + std::ostringstream h; + h << + "POST " << path_ << " HTTP/1.1" << CRLF << + "Host: " << host_ << CRLF << + "Content-Type: application/x-thrift" << CRLF << + "Content-Length: " << len << CRLF << + "Accept: application/x-thrift" << CRLF << + "User-Agent: C++/THttpClient" << CRLF << + CRLF; + string header = h.str(); + + // Write the header, then the data, then flush + transport_->write((const uint8_t*)header.c_str(), header.size()); + transport_->write(buf, len); + transport_->flush(); + + // Reset the buffer and header variables + writeBuffer_.resetBuffer(); + readHeaders_ = true; +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/THttpClient.h b/lib/cpp/src/transport/THttpClient.h new file mode 100644 index 000000000..f4be4c1a6 --- /dev/null +++ b/lib/cpp/src/transport/THttpClient.h @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * 
or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_THTTPCLIENT_H_ +#define _THRIFT_TRANSPORT_THTTPCLIENT_H_ 1 + +#include <transport/TBufferTransports.h> + +namespace apache { namespace thrift { namespace transport { + +/** + * HTTP client implementation of the thrift transport. This was irritating + * to write, but the alternatives in C++ land are daunting. Linking CURL + * requires 23 dynamic libraries last time I checked (WTF?!?). All we have + * here is a VERY basic HTTP/1.1 client which supports HTTP 100 Continue, + * chunked transfer encoding, keepalive, etc. Tested against Apache. 
+ * + */ +class THttpClient : public TTransport { + public: + THttpClient(boost::shared_ptr<TTransport> transport, std::string host, std::string path=""); + + THttpClient(std::string host, int port, std::string path=""); + + virtual ~THttpClient(); + + void open() { + transport_->open(); + } + + bool isOpen() { + return transport_->isOpen(); + } + + bool peek() { + return transport_->peek(); + } + + void close() { + transport_->close(); + } + + uint32_t read(uint8_t* buf, uint32_t len); + + void readEnd(); + + void write(const uint8_t* buf, uint32_t len); + + void flush(); + + private: + void init(); + + protected: + + boost::shared_ptr<TTransport> transport_; + + TMemoryBuffer writeBuffer_; + TMemoryBuffer readBuffer_; + + std::string host_; + std::string path_; + + bool readHeaders_; + bool chunked_; + bool chunkedDone_; + uint32_t chunkSize_; + uint32_t contentLength_; + + char* httpBuf_; + uint32_t httpPos_; + uint32_t httpBufLen_; + uint32_t httpBufSize_; + + uint32_t readMoreData(); + char* readLine(); + + void readHeaders(); + void parseHeader(char* header); + bool parseStatusLine(char* status); + + uint32_t readChunked(); + void readChunkedFooters(); + uint32_t parseChunkSize(char* line); + + uint32_t readContent(uint32_t size); + + void refill(); + void shift(); + +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_THTTPCLIENT_H_ diff --git a/lib/cpp/src/transport/TServerSocket.cpp b/lib/cpp/src/transport/TServerSocket.cpp new file mode 100644 index 000000000..9b47aa539 --- /dev/null +++ b/lib/cpp/src/transport/TServerSocket.cpp @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <cstring> +#include <sys/socket.h> +#include <sys/poll.h> +#include <sys/types.h> +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <netdb.h> +#include <fcntl.h> +#include <errno.h> + +#include "TSocket.h" +#include "TServerSocket.h" +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace transport { + +using namespace std; +using boost::shared_ptr; + +TServerSocket::TServerSocket(int port) : + port_(port), + serverSocket_(-1), + acceptBacklog_(1024), + sendTimeout_(0), + recvTimeout_(0), + retryLimit_(0), + retryDelay_(0), + tcpSendBuffer_(0), + tcpRecvBuffer_(0), + intSock1_(-1), + intSock2_(-1) {} + +TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout) : + port_(port), + serverSocket_(-1), + acceptBacklog_(1024), + sendTimeout_(sendTimeout), + recvTimeout_(recvTimeout), + retryLimit_(0), + retryDelay_(0), + tcpSendBuffer_(0), + tcpRecvBuffer_(0), + intSock1_(-1), + intSock2_(-1) {} + +TServerSocket::~TServerSocket() { + close(); +} + +void TServerSocket::setSendTimeout(int sendTimeout) { + sendTimeout_ = sendTimeout; +} + +void TServerSocket::setRecvTimeout(int recvTimeout) { + recvTimeout_ = recvTimeout; +} + +void TServerSocket::setRetryLimit(int retryLimit) { + retryLimit_ = retryLimit; +} + +void TServerSocket::setRetryDelay(int retryDelay) { + retryDelay_ = retryDelay; +} + +void TServerSocket::setTcpSendBuffer(int tcpSendBuffer) { + tcpSendBuffer_ = tcpSendBuffer; +} + +void TServerSocket::setTcpRecvBuffer(int tcpRecvBuffer) { + tcpRecvBuffer_ = tcpRecvBuffer; +} + 
namespace {

/**
 * Owns a getaddrinfo() result list and frees it on scope exit, so that
 * every early return or throw in listen() releases the list.
 */
class AddrInfoReleaser {
 public:
  explicit AddrInfoReleaser(struct addrinfo* list) : list_(list) {}

  ~AddrInfoReleaser() {
    if (list_ != NULL) {
      freeaddrinfo(list_);
    }
  }

 private:
  // Non-copyable: the list must be freed exactly once.
  AddrInfoReleaser(const AddrInfoReleaser&);
  AddrInfoReleaser& operator=(const AddrInfoReleaser&);

  struct addrinfo* list_;
};

} // anonymous namespace

/**
 * Creates the interrupt socket pair, then creates, configures, binds and
 * starts listening on the server socket for port_.
 *
 * A socketpair() failure is only logged; the transport then works without
 * interrupt support. Any other failure (except IPV6_V6ONLY, which is only
 * logged) closes all descriptors and throws.
 *
 * @throws TTransportException (NOT_OPEN) if the address cannot be resolved
 *         or the socket cannot be created, configured, bound or put into
 *         the listening state
 */
void TServerSocket::listen() {
  // interrupt() writes to intSock1_, acceptImpl() polls intSock2_.
  int sv[2];
  if (-1 == socketpair(AF_LOCAL, SOCK_STREAM, 0, sv)) {
    GlobalOutput.perror("TServerSocket::listen() socketpair() ", errno);
    intSock1_ = -1;
    intSock2_ = -1;
  } else {
    intSock1_ = sv[1];
    intSock2_ = sv[0];
  }

  struct addrinfo hints, *res, *res0;
  int error;
  // Large enough for any int; an out-of-range port is then rejected by
  // getaddrinfo() instead of overflowing the buffer.
  char port[32];
  std::memset(&hints, 0, sizeof(hints));
  hints.ai_family = PF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
  snprintf(port, sizeof(port), "%d", port_);

  // Wildcard address
  error = getaddrinfo(NULL, port, &hints, &res0);
  if (error) {
    GlobalOutput.printf("getaddrinfo %d: %s", error, gai_strerror(error));
    close();
    throw TTransportException(TTransportException::NOT_OPEN, "Could not resolve host for server socket.");
  }
  // From here on the list is released on every exit path.
  AddrInfoReleaser addrInfoReleaser(res0);

  // Pick the ipv6 address first since ipv4 addresses can be mapped
  // into ipv6 space.
  for (res = res0; res; res = res->ai_next) {
    if (res->ai_family == AF_INET6 || res->ai_next == NULL)
      break;
  }

  serverSocket_ = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
  if (serverSocket_ == -1) {
    int errno_copy = errno;
    GlobalOutput.perror("TServerSocket::listen() socket() ", errno_copy);
    close();
    throw TTransportException(TTransportException::NOT_OPEN, "Could not create server socket.", errno_copy);
  }

  // Set reusaddress to prevent 2MSL delay on accept
  int one = 1;
  if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_REUSEADDR,
                       &one, sizeof(one))) {
    int errno_copy = errno;
    GlobalOutput.perror("TServerSocket::listen() setsockopt() SO_REUSEADDR ", errno_copy);
    close();
    throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_REUSEADDR", errno_copy);
  }

  // Set TCP buffer sizes
  if (tcpSendBuffer_ > 0) {
    if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_SNDBUF,
                         &tcpSendBuffer_, sizeof(tcpSendBuffer_))) {
      int errno_copy = errno;
      GlobalOutput.perror("TServerSocket::listen() setsockopt() SO_SNDBUF ", errno_copy);
      close();
      throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_SNDBUF", errno_copy);
    }
  }

  if (tcpRecvBuffer_ > 0) {
    if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_RCVBUF,
                         &tcpRecvBuffer_, sizeof(tcpRecvBuffer_))) {
      int errno_copy = errno;
      GlobalOutput.perror("TServerSocket::listen() setsockopt() SO_RCVBUF ", errno_copy);
      close();
      throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_RCVBUF", errno_copy);
    }
  }

  // Defer accept. TCP_DEFER_ACCEPT is a TCP-level option, so it must be
  // set at IPPROTO_TCP; at SOL_SOCKET the same number names another option.
  #ifdef TCP_DEFER_ACCEPT
  if (-1 == setsockopt(serverSocket_, IPPROTO_TCP, TCP_DEFER_ACCEPT,
                       &one, sizeof(one))) {
    int errno_copy = errno;
    GlobalOutput.perror("TServerSocket::listen() setsockopt() TCP_DEFER_ACCEPT ", errno_copy);
    close();
    throw TTransportException(TTransportException::NOT_OPEN, "Could not set TCP_DEFER_ACCEPT", errno_copy);
  }
  #endif // #ifdef TCP_DEFER_ACCEPT

  // Allow IPv4-mapped connections on an IPv6 socket. The option only
  // exists for AF_INET6 sockets; failure is logged but not fatal.
  #ifdef IPV6_V6ONLY
  if (res->ai_family == AF_INET6) {
    int zero = 0;
    if (-1 == setsockopt(serverSocket_, IPPROTO_IPV6, IPV6_V6ONLY,
                         &zero, sizeof(zero))) {
      GlobalOutput.perror("TServerSocket::listen() IPV6_V6ONLY ", errno);
    }
  }
  #endif // #ifdef IPV6_V6ONLY

  // Turn linger off, don't want to block on calls to close
  struct linger ling = {0, 0};
  if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_LINGER,
                       &ling, sizeof(ling))) {
    int errno_copy = errno;
    GlobalOutput.perror("TServerSocket::listen() setsockopt() SO_LINGER ", errno_copy);
    close();
    throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_LINGER", errno_copy);
  }

  // TCP Nodelay, speed over bandwidth
  if (-1 == setsockopt(serverSocket_, IPPROTO_TCP, TCP_NODELAY,
                       &one, sizeof(one))) {
    int errno_copy = errno;
    GlobalOutput.perror("TServerSocket::listen() setsockopt() TCP_NODELAY ", errno_copy);
    close();
    throw TTransportException(TTransportException::NOT_OPEN, "Could not set TCP_NODELAY", errno_copy);
  }

  // Set NONBLOCK on the accept socket
  int flags = fcntl(serverSocket_, F_GETFL, 0);
  if (flags == -1) {
    int errno_copy = errno;
    GlobalOutput.perror("TServerSocket::listen() fcntl() F_GETFL ", errno_copy);
    close();
    throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed", errno_copy);
  }

  if (-1 == fcntl(serverSocket_, F_SETFL, flags | O_NONBLOCK)) {
    int errno_copy = errno;
    GlobalOutput.perror("TServerSocket::listen() fcntl() O_NONBLOCK ", errno_copy);
    close();
    throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed", errno_copy);
  }

  // prepare the port information
  // we may want to try to bind more than once, since SO_REUSEADDR doesn't
  // always seem to work. The client can configure the retry variables.
  bool bound = false;
  int bindErrno = 0;
  int retries = 0;
  do {
    if (0 == bind(serverSocket_, res->ai_addr, res->ai_addrlen)) {
      bound = true;
      break;
    }
    bindErrno = errno;

    // use short circuit evaluation here to only sleep if we need to
  } while ((retries++ < retryLimit_) && (sleep(retryDelay_) == 0));

  // throw an error if we failed to bind properly. Success is tracked
  // explicitly because an interrupted sleep() also ends the loop, possibly
  // before the retry limit has been reached.
  if (!bound) {
    char errbuf[1024];
    snprintf(errbuf, sizeof(errbuf), "TServerSocket::listen() BIND %d", port_);
    GlobalOutput(errbuf);
    close();
    throw TTransportException(TTransportException::NOT_OPEN, "Could not bind", bindErrno);
  }

  // Call listen
  if (-1 == ::listen(serverSocket_, acceptBacklog_)) {
    int errno_copy = errno;
    GlobalOutput.perror("TServerSocket::listen() listen() ", errno_copy);
    close();
    throw TTransportException(TTransportException::NOT_OPEN, "Could not listen", errno_copy);
  }

  // The socket is now listening!
}

/**
 * Blocks in poll() until a connection is pending or interrupt() is called,
 * then accepts the connection and wraps it in a blocking TSocket.
 *
 * @return the accepted client as a TSocket, with the configured send and
 *         receive timeouts applied when they are positive
 * @throws TTransportException NOT_OPEN if not listening, INTERRUPTED if
 *         interrupt() was called, UNKNOWN on poll/accept/fcntl failure
 */
shared_ptr<TTransport> TServerSocket::acceptImpl() {
  if (serverSocket_ < 0) {
    throw TTransportException(TTransportException::NOT_OPEN, "TServerSocket not listening");
  }

  struct pollfd fds[2];

  int maxEintrs = 5;
  int numEintrs = 0;

  while (true) {
    std::memset(fds, 0 , sizeof(fds));
    fds[0].fd = serverSocket_;
    fds[0].events = POLLIN;
    // Only hand the second entry to poll() when it holds a real
    // descriptor; a zeroed entry would refer to fd 0.
    nfds_t numFds = 1;
    if (intSock2_ >= 0) {
      fds[1].fd = intSock2_;
      fds[1].events = POLLIN;
      numFds = 2;
    }
    int ret = poll(fds, numFds, -1);

    if (ret < 0) {
      // error cases
      if (errno == EINTR && (numEintrs++ < maxEintrs)) {
        // EINTR needs to be handled manually and we can tolerate
        // a certain number
        continue;
      }
      int errno_copy = errno;
      GlobalOutput.perror("TServerSocket::acceptImpl() poll() ", errno_copy);
      throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy);
    } else if (ret > 0) {
      // Check for an interrupt signal
      if (intSock2_ >= 0 && (fds[1].revents & POLLIN)) {
        int8_t buf;
        if (-1 == recv(intSock2_, &buf, sizeof(int8_t), 0)) {
          GlobalOutput.perror("TServerSocket::acceptImpl() recv() interrupt ", errno);
        }
        throw TTransportException(TTransportException::INTERRUPTED);
      }

      // Check for the actual server socket being ready
      if (fds[0].revents & POLLIN) {
        break;
      }
    } else {
      GlobalOutput("TServerSocket::acceptImpl() poll 0");
      throw TTransportException(TTransportException::UNKNOWN);
    }
  }

  struct sockaddr_storage clientAddress;
  socklen_t size = sizeof(clientAddress);
  int clientSocket = ::accept(serverSocket_,
                              (struct sockaddr *) &clientAddress,
                              &size);

  if (clientSocket < 0) {
    int errno_copy = errno;
    GlobalOutput.perror("TServerSocket::acceptImpl() ::accept() ", errno_copy);
    throw TTransportException(TTransportException::UNKNOWN, "accept()", errno_copy);
  }

  // Make sure client socket is blocking. The descriptor is not yet owned
  // by a TSocket, so it has to be closed by hand on failure.
  int flags = fcntl(clientSocket, F_GETFL, 0);
  if (flags == -1) {
    int errno_copy = errno;
    GlobalOutput.perror("TServerSocket::acceptImpl() fcntl() F_GETFL ", errno_copy);
    ::close(clientSocket);
    throw TTransportException(TTransportException::UNKNOWN, "fcntl(F_GETFL)", errno_copy);
  }

  if (-1 == fcntl(clientSocket, F_SETFL, flags & ~O_NONBLOCK)) {
    int errno_copy = errno;
    GlobalOutput.perror("TServerSocket::acceptImpl() fcntl() F_SETFL ~O_NONBLOCK ", errno_copy);
    ::close(clientSocket);
    throw TTransportException(TTransportException::UNKNOWN, "fcntl(F_SETFL)", errno_copy);
  }

  shared_ptr<TSocket> client(new TSocket(clientSocket));
  if (sendTimeout_ > 0) {
    client->setSendTimeout(sendTimeout_);
  }
  if (recvTimeout_ > 0) {
    client->setRecvTimeout(recvTimeout_);
  }

  return client;
}

/**
 * Wakes a thread blocked in acceptImpl() by sending one byte over the
 * interrupt socket pair. Does nothing if the pair was never created;
 * a send failure is only logged.
 */
void TServerSocket::interrupt() {
  if (intSock1_ >= 0) {
    int8_t byte = 0;
    if (-1 == send(intSock1_, &byte, sizeof(int8_t), 0)) {
      GlobalOutput.perror("TServerSocket::interrupt() send() ", errno);
    }
  }
}

/**
 * Shuts down and closes the listening socket and closes both interrupt
 * sockets, then marks all three descriptors as -1. Safe to call repeatedly.
 */
void TServerSocket::close() {
  if (serverSocket_ >= 0) {
    shutdown(serverSocket_, SHUT_RDWR);
    ::close(serverSocket_);
  }
  if (intSock1_ >= 0) {
    ::close(intSock1_);
  }
  if (intSock2_ >= 0) {
    ::close(intSock2_);
  }
  serverSocket_ = -1;
  intSock1_ = -1;
  intSock2_ = -1;
}

}}} // apache::thrift::transport
diff --git a/lib/cpp/src/transport/TServerSocket.h b/lib/cpp/src/transport/TServerSocket.h new file mode 100644 index 000000000..a6be01737 --- /dev/null +++ b/lib/cpp/src/transport/TServerSocket.h @@ -0,0 +1,76 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TSERVERSOCKET_H_ +#define _THRIFT_TRANSPORT_TSERVERSOCKET_H_ 1 + +#include "TServerTransport.h" +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace transport { + +class TSocket; + +/** + * Server socket implementation of TServerTransport. Wrapper around a unix + * socket listen and accept calls. + * + */ +class TServerSocket : public TServerTransport { + public: + TServerSocket(int port); + TServerSocket(int port, int sendTimeout, int recvTimeout); + + ~TServerSocket(); + + void setSendTimeout(int sendTimeout); + void setRecvTimeout(int recvTimeout); + + void setRetryLimit(int retryLimit); + void setRetryDelay(int retryDelay); + + void setTcpSendBuffer(int tcpSendBuffer); + void setTcpRecvBuffer(int tcpRecvBuffer); + + void listen(); + void close(); + + void interrupt(); + + protected: + boost::shared_ptr<TTransport> acceptImpl(); + + private: + int port_; + int serverSocket_; + int acceptBacklog_; + int sendTimeout_; + int recvTimeout_; + int retryLimit_; + int retryDelay_; + int tcpSendBuffer_; + int tcpRecvBuffer_; + + int intSock1_; + int intSock2_; +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TSERVERSOCKET_H_ diff --git a/lib/cpp/src/transport/TServerTransport.h b/lib/cpp/src/transport/TServerTransport.h new file mode 100644 index 000000000..40bbc6c78 --- /dev/null +++ b/lib/cpp/src/transport/TServerTransport.h @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TSERVERTRANSPORT_H_ +#define _THRIFT_TRANSPORT_TSERVERTRANSPORT_H_ 1 + +#include "TTransport.h" +#include "TTransportException.h" +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace transport { + +/** + * Server transport framework. A server needs to have some facility for + * creating base transports to read/write from. + * + */ +class TServerTransport { + public: + virtual ~TServerTransport() {} + + /** + * Starts the server transport listening for new connections. Prior to this + * call most transports will not return anything when accept is called. + * + * @throws TTransportException if we were unable to listen + */ + virtual void listen() {} + + /** + * Gets a new dynamically allocated transport object and passes it to the + * caller. Note that it is the explicit duty of the caller to free the + * allocated object. The returned TTransport object must always be in the + * opened state. NULL should never be returned, instead an Exception should + * always be thrown. 
+ * + * @return A new TTransport object + * @throws TTransportException if there is an error + */ + boost::shared_ptr<TTransport> accept() { + boost::shared_ptr<TTransport> result = acceptImpl(); + if (result == NULL) { + throw TTransportException("accept() may not return NULL"); + } + return result; + } + + /** + * For "smart" TServerTransport implementations that work in a multi + * threaded context this can be used to break out of an accept() call. + * It is expected that the transport will throw a TTransportException + * with the interrupted error code. + */ + virtual void interrupt() {} + + /** + * Closes this transport such that future calls to accept will do nothing. + */ + virtual void close() = 0; + + protected: + TServerTransport() {} + + /** + * Subclasses should implement this function for accept. + * + * @return A newly allocated TTransport object + * @throw TTransportException If an error occurs + */ + virtual boost::shared_ptr<TTransport> acceptImpl() = 0; + +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TSERVERTRANSPORT_H_ diff --git a/lib/cpp/src/transport/TShortReadTransport.h b/lib/cpp/src/transport/TShortReadTransport.h new file mode 100644 index 000000000..3df8a57ca --- /dev/null +++ b/lib/cpp/src/transport/TShortReadTransport.h @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TSHORTREADTRANSPORT_H_ +#define _THRIFT_TRANSPORT_TSHORTREADTRANSPORT_H_ 1 + +#include <cstdlib> + +#include <transport/TTransport.h> + +namespace apache { namespace thrift { namespace transport { namespace test { + +/** + * This class is only meant for testing. It wraps another transport. + * Calls to read are passed through with some probability. Otherwise, + * the read amount is randomly reduced before being passed through. + * + */ +class TShortReadTransport : public TTransport { + public: + TShortReadTransport(boost::shared_ptr<TTransport> transport, double full_prob) + : transport_(transport) + , fullProb_(full_prob) + {} + + bool isOpen() { + return transport_->isOpen(); + } + + bool peek() { + return transport_->peek(); + } + + void open() { + transport_->open(); + } + + void close() { + transport_->close(); + } + + uint32_t read(uint8_t* buf, uint32_t len) { + if (len == 0) { + return 0; + } + + if (rand()/(double)RAND_MAX >= fullProb_) { + len = 1 + rand()%len; + } + return transport_->read(buf, len); + } + + void write(const uint8_t* buf, uint32_t len) { + transport_->write(buf, len); + } + + void flush() { + transport_->flush(); + } + + const uint8_t* borrow(uint8_t* buf, uint32_t* len) { + return transport_->borrow(buf, len); + } + + void consume(uint32_t len) { + return transport_->consume(len); + } + + boost::shared_ptr<TTransport> getUnderlyingTransport() { + return transport_; + } + + protected: + boost::shared_ptr<TTransport> transport_; + double fullProb_; +}; + +}}}} // 
apache::thrift::transport::test + +#endif // #ifndef _THRIFT_TRANSPORT_TSHORTREADTRANSPORT_H_ diff --git a/lib/cpp/src/transport/TSimpleFileTransport.cpp b/lib/cpp/src/transport/TSimpleFileTransport.cpp new file mode 100644 index 000000000..e58a57430 --- /dev/null +++ b/lib/cpp/src/transport/TSimpleFileTransport.cpp @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "TSimpleFileTransport.h" + +#include <sys/types.h> +#include <sys/stat.h> +#include <fcntl.h> + +namespace apache { namespace thrift { namespace transport { + +TSimpleFileTransport:: +TSimpleFileTransport(const std::string& path, bool read, bool write) + : TFDTransport(-1, TFDTransport::CLOSE_ON_DESTROY) { + int flags = 0; + if (read && write) { + flags = O_RDWR; + } else if (read) { + flags = O_RDONLY; + } else if (write) { + flags = O_WRONLY; + } else { + throw TTransportException("Neither READ nor WRITE specified"); + } + if (write) { + flags |= O_CREAT | O_APPEND; + } + int fd = ::open(path.c_str(), + flags, + S_IRUSR | S_IWUSR| S_IRGRP | S_IROTH); + if (fd < 0) { + throw TTransportException("failed to open file for writing: " + path); + } + setFD(fd); + open(); +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TSimpleFileTransport.h b/lib/cpp/src/transport/TSimpleFileTransport.h new file mode 100644 index 000000000..6cc52ea1a --- /dev/null +++ b/lib/cpp/src/transport/TSimpleFileTransport.h @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef _THRIFT_TRANSPORT_TSIMPLEFILETRANSPORT_H_ +#define _THRIFT_TRANSPORT_TSIMPLEFILETRANSPORT_H_ 1 + +#include "TFDTransport.h" + +namespace apache { namespace thrift { namespace transport { + +/** + * Dead-simple wrapper around a file. + * + * Writeable files are opened with O_CREAT and O_APPEND + */ +class TSimpleFileTransport : public TFDTransport { + public: + TSimpleFileTransport(const std::string& path, + bool read = true, + bool write = false); +}; + +}}} // apache::thrift::transport + +#endif // _THRIFT_TRANSPORT_TSIMPLEFILETRANSPORT_H_ diff --git a/lib/cpp/src/transport/TSocket.cpp b/lib/cpp/src/transport/TSocket.cpp new file mode 100644 index 000000000..3395dabdc --- /dev/null +++ b/lib/cpp/src/transport/TSocket.cpp @@ -0,0 +1,589 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include <config.h> +#include <cstring> +#include <sstream> +#include <sys/socket.h> +#include <sys/poll.h> +#include <sys/types.h> +#include <arpa/inet.h> +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <netdb.h> +#include <unistd.h> +#include <errno.h> +#include <fcntl.h> + +#include "concurrency/Monitor.h" +#include "TSocket.h" +#include "TTransportException.h" + +namespace apache { namespace thrift { namespace transport { + +using namespace std; + +// Global var to track total socket sys calls +uint32_t g_socket_syscalls = 0; + +/** + * TSocket implementation. + * + */ + +TSocket::TSocket(string host, int port) : + host_(host), + port_(port), + socket_(-1), + connTimeout_(0), + sendTimeout_(0), + recvTimeout_(0), + lingerOn_(1), + lingerVal_(0), + noDelay_(1), + maxRecvRetries_(5) { + recvTimeval_.tv_sec = (int)(recvTimeout_/1000); + recvTimeval_.tv_usec = (int)((recvTimeout_%1000)*1000); +} + +TSocket::TSocket() : + host_(""), + port_(0), + socket_(-1), + connTimeout_(0), + sendTimeout_(0), + recvTimeout_(0), + lingerOn_(1), + lingerVal_(0), + noDelay_(1), + maxRecvRetries_(5) { + recvTimeval_.tv_sec = (int)(recvTimeout_/1000); + recvTimeval_.tv_usec = (int)((recvTimeout_%1000)*1000); +} + +TSocket::TSocket(int socket) : + host_(""), + port_(0), + socket_(socket), + connTimeout_(0), + sendTimeout_(0), + recvTimeout_(0), + lingerOn_(1), + lingerVal_(0), + noDelay_(1), + maxRecvRetries_(5) { + recvTimeval_.tv_sec = (int)(recvTimeout_/1000); + recvTimeval_.tv_usec = (int)((recvTimeout_%1000)*1000); +} + +TSocket::~TSocket() { + close(); +} + +bool TSocket::isOpen() { + return (socket_ >= 0); +} + +bool TSocket::peek() { + if (!isOpen()) { + return false; + } + uint8_t buf; + int r = recv(socket_, &buf, 1, MSG_PEEK); + if (r == -1) { + int errno_copy = errno; + #ifdef __FreeBSD__ + /* shigin: + * freebsd returns -1 and ECONNRESET if socket was closed by + * the other side + */ + if (errno_copy == ECONNRESET) + { + close(); + return false; + } 
+ #endif + GlobalOutput.perror("TSocket::peek() recv() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "recv()", errno_copy); + } + return (r > 0); +} + +void TSocket::openConnection(struct addrinfo *res) { + if (isOpen()) { + throw TTransportException(TTransportException::ALREADY_OPEN); + } + + socket_ = socket(res->ai_family, res->ai_socktype, res->ai_protocol); + if (socket_ == -1) { + int errno_copy = errno; + GlobalOutput.perror("TSocket::open() socket() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "socket()", errno_copy); + } + + // Send timeout + if (sendTimeout_ > 0) { + setSendTimeout(sendTimeout_); + } + + // Recv timeout + if (recvTimeout_ > 0) { + setRecvTimeout(recvTimeout_); + } + + // Linger + setLinger(lingerOn_, lingerVal_); + + // No delay + setNoDelay(noDelay_); + + // Set the socket to be non blocking for connect if a timeout exists + int flags = fcntl(socket_, F_GETFL, 0); + if (connTimeout_ > 0) { + if (-1 == fcntl(socket_, F_SETFL, flags | O_NONBLOCK)) { + int errno_copy = errno; + GlobalOutput.perror("TSocket::open() fcntl() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed", errno_copy); + } + } else { + if (-1 == fcntl(socket_, F_SETFL, flags & ~O_NONBLOCK)) { + int errno_copy = errno; + GlobalOutput.perror("TSocket::open() fcntl " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed", errno_copy); + } + } + + // Connect the socket + int ret = connect(socket_, res->ai_addr, res->ai_addrlen); + + // success case + if (ret == 0) { + goto done; + } + + if (errno != EINPROGRESS) { + int errno_copy = errno; + GlobalOutput.perror("TSocket::open() connect() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "connect() failed", errno_copy); + } + + + struct pollfd fds[1]; + std::memset(fds, 0 , 
sizeof(fds)); + fds[0].fd = socket_; + fds[0].events = POLLOUT; + ret = poll(fds, 1, connTimeout_); + + if (ret > 0) { + // Ensure the socket is connected and that there are no errors set + int val; + socklen_t lon; + lon = sizeof(int); + int ret2 = getsockopt(socket_, SOL_SOCKET, SO_ERROR, (void *)&val, &lon); + if (ret2 == -1) { + int errno_copy = errno; + GlobalOutput.perror("TSocket::open() getsockopt() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "getsockopt()", errno_copy); + } + // no errors on socket, go to town + if (val == 0) { + goto done; + } + GlobalOutput.perror("TSocket::open() error on socket (after poll) " + getSocketInfo(), val); + throw TTransportException(TTransportException::NOT_OPEN, "socket open() error", val); + } else if (ret == 0) { + // socket timed out + string errStr = "TSocket::open() timed out " + getSocketInfo(); + GlobalOutput(errStr.c_str()); + throw TTransportException(TTransportException::NOT_OPEN, "open() timed out"); + } else { + // error on poll() + int errno_copy = errno; + GlobalOutput.perror("TSocket::open() poll() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "poll() failed", errno_copy); + } + + done: + // Set socket back to normal mode (blocking) + fcntl(socket_, F_SETFL, flags); +} + +void TSocket::open() { + if (isOpen()) { + throw TTransportException(TTransportException::ALREADY_OPEN); + } + + // Validate port number + if (port_ < 0 || port_ > 65536) { + throw TTransportException(TTransportException::NOT_OPEN, "Specified port is invalid"); + } + + struct addrinfo hints, *res, *res0; + res = NULL; + res0 = NULL; + int error; + char port[sizeof("65536")]; + std::memset(&hints, 0, sizeof(hints)); + hints.ai_family = PF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; + sprintf(port, "%d", port_); + + error = getaddrinfo(host_.c_str(), port, &hints, &res0); + + if (error) { + string 
errStr = "TSocket::open() getaddrinfo() " + getSocketInfo() + string(gai_strerror(error)); + GlobalOutput(errStr.c_str()); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not resolve host for client socket."); + } + + // Cycle through all the returned addresses until one + // connects or push the exception up. + for (res = res0; res; res = res->ai_next) { + try { + openConnection(res); + break; + } catch (TTransportException& ttx) { + if (res->ai_next) { + close(); + } else { + close(); + freeaddrinfo(res0); // cleanup on failure + throw; + } + } + } + + // Free address structure memory + freeaddrinfo(res0); +} + +void TSocket::close() { + if (socket_ >= 0) { + shutdown(socket_, SHUT_RDWR); + ::close(socket_); + } + socket_ = -1; +} + +uint32_t TSocket::read(uint8_t* buf, uint32_t len) { + if (socket_ < 0) { + throw TTransportException(TTransportException::NOT_OPEN, "Called read on non-open socket"); + } + + int32_t retries = 0; + + // EAGAIN can be signalled both when a timeout has occurred and when + // the system is out of resources (an awesome undocumented feature). + // The following is an approximation of the time interval under which + // EAGAIN is taken to indicate an out of resources error. + uint32_t eagainThresholdMicros = 0; + if (recvTimeout_) { + // if a readTimeout is specified along with a max number of recv retries, then + // the threshold will ensure that the read timeout is not exceeded even in the + // case of resource errors + eagainThresholdMicros = (recvTimeout_*1000)/ ((maxRecvRetries_>0) ? 
maxRecvRetries_ : 2); + } + + try_again: + // Read from the socket + struct timeval begin; + gettimeofday(&begin, NULL); + int got = recv(socket_, buf, len, 0); + int errno_copy = errno; //gettimeofday can change errno + struct timeval end; + gettimeofday(&end, NULL); + uint32_t readElapsedMicros = (((end.tv_sec - begin.tv_sec) * 1000 * 1000) + + (((uint64_t)(end.tv_usec - begin.tv_usec)))); + ++g_socket_syscalls; + + // Check for error on read + if (got < 0) { + if (errno_copy == EAGAIN) { + // check if this is the lack of resources or timeout case + if (!eagainThresholdMicros || (readElapsedMicros < eagainThresholdMicros)) { + if (retries++ < maxRecvRetries_) { + usleep(50); + goto try_again; + } else { + throw TTransportException(TTransportException::TIMED_OUT, + "EAGAIN (unavailable resources)"); + } + } else { + // infer that timeout has been hit + throw TTransportException(TTransportException::TIMED_OUT, + "EAGAIN (timed out)"); + } + } + + // If interrupted, try again + if (errno_copy == EINTR && retries++ < maxRecvRetries_) { + goto try_again; + } + + // Now it's not a try again case, but a real probblez + GlobalOutput.perror("TSocket::read() recv() " + getSocketInfo(), errno_copy); + + // If we disconnect with no linger time + if (errno_copy == ECONNRESET) { + #ifdef __FreeBSD__ + /* shigin: freebsd doesn't follow POSIX semantic of recv and fails with + * ECONNRESET if peer performed shutdown + */ + close(); + return 0; + #else + throw TTransportException(TTransportException::NOT_OPEN, "ECONNRESET"); + #endif + } + + // This ish isn't open + if (errno_copy == ENOTCONN) { + throw TTransportException(TTransportException::NOT_OPEN, "ENOTCONN"); + } + + // Timed out! 
+ if (errno_copy == ETIMEDOUT) { + throw TTransportException(TTransportException::TIMED_OUT, "ETIMEDOUT"); + } + + // Some other error, whatevz + throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy); + } + + // The remote host has closed the socket + if (got == 0) { + close(); + return 0; + } + + // Pack data into string + return got; +} + +void TSocket::write(const uint8_t* buf, uint32_t len) { + if (socket_ < 0) { + throw TTransportException(TTransportException::NOT_OPEN, "Called write on non-open socket"); + } + + uint32_t sent = 0; + + while (sent < len) { + + int flags = 0; + #ifdef MSG_NOSIGNAL + // Note the use of MSG_NOSIGNAL to suppress SIGPIPE errors, instead we + // check for the EPIPE return condition and close the socket in that case + flags |= MSG_NOSIGNAL; + #endif // ifdef MSG_NOSIGNAL + + int b = send(socket_, buf + sent, len - sent, flags); + ++g_socket_syscalls; + + // Fail on a send error + if (b < 0) { + int errno_copy = errno; + GlobalOutput.perror("TSocket::write() send() " + getSocketInfo(), errno_copy); + + if (errno == EPIPE || errno == ECONNRESET || errno == ENOTCONN) { + close(); + throw TTransportException(TTransportException::NOT_OPEN, "write() send()", errno_copy); + } + + throw TTransportException(TTransportException::UNKNOWN, "write() send()", errno_copy); + } + + // Fail on blocked send + if (b == 0) { + throw TTransportException(TTransportException::NOT_OPEN, "Socket send returned 0."); + } + sent += b; + } +} + +std::string TSocket::getHost() { + return host_; +} + +int TSocket::getPort() { + return port_; +} + +void TSocket::setHost(string host) { + host_ = host; +} + +void TSocket::setPort(int port) { + port_ = port; +} + +void TSocket::setLinger(bool on, int linger) { + lingerOn_ = on; + lingerVal_ = linger; + if (socket_ < 0) { + return; + } + + struct linger l = {(lingerOn_ ? 
1 : 0), lingerVal_}; + int ret = setsockopt(socket_, SOL_SOCKET, SO_LINGER, &l, sizeof(l)); + if (ret == -1) { + int errno_copy = errno; // Copy errno because we're allocating memory. + GlobalOutput.perror("TSocket::setLinger() setsockopt() " + getSocketInfo(), errno_copy); + } +} + +void TSocket::setNoDelay(bool noDelay) { + noDelay_ = noDelay; + if (socket_ < 0) { + return; + } + + // Set socket to NODELAY + int v = noDelay_ ? 1 : 0; + int ret = setsockopt(socket_, IPPROTO_TCP, TCP_NODELAY, &v, sizeof(v)); + if (ret == -1) { + int errno_copy = errno; // Copy errno because we're allocating memory. + GlobalOutput.perror("TSocket::setNoDelay() setsockopt() " + getSocketInfo(), errno_copy); + } +} + +void TSocket::setConnTimeout(int ms) { + connTimeout_ = ms; +} + +void TSocket::setRecvTimeout(int ms) { + if (ms < 0) { + char errBuf[512]; + sprintf(errBuf, "TSocket::setRecvTimeout with negative input: %d\n", ms); + GlobalOutput(errBuf); + return; + } + recvTimeout_ = ms; + + if (socket_ < 0) { + return; + } + + recvTimeval_.tv_sec = (int)(recvTimeout_/1000); + recvTimeval_.tv_usec = (int)((recvTimeout_%1000)*1000); + + // Copy because poll may modify + struct timeval r = recvTimeval_; + int ret = setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, &r, sizeof(r)); + if (ret == -1) { + int errno_copy = errno; // Copy errno because we're allocating memory. + GlobalOutput.perror("TSocket::setRecvTimeout() setsockopt() " + getSocketInfo(), errno_copy); + } +} + +void TSocket::setSendTimeout(int ms) { + if (ms < 0) { + char errBuf[512]; + sprintf(errBuf, "TSocket::setSendTimeout with negative input: %d\n", ms); + GlobalOutput(errBuf); + return; + } + sendTimeout_ = ms; + + if (socket_ < 0) { + return; + } + + struct timeval s = {(int)(sendTimeout_/1000), + (int)((sendTimeout_%1000)*1000)}; + int ret = setsockopt(socket_, SOL_SOCKET, SO_SNDTIMEO, &s, sizeof(s)); + if (ret == -1) { + int errno_copy = errno; // Copy errno because we're allocating memory. 
+ GlobalOutput.perror("TSocket::setSendTimeout() setsockopt() " + getSocketInfo(), errno_copy); + } +} + +void TSocket::setMaxRecvRetries(int maxRecvRetries) { + maxRecvRetries_ = maxRecvRetries; +} + +string TSocket::getSocketInfo() { + std::ostringstream oss; + oss << "<Host: " << host_ << " Port: " << port_ << ">"; + return oss.str(); +} + +std::string TSocket::getPeerHost() { + if (peerHost_.empty()) { + struct sockaddr_storage addr; + socklen_t addrLen = sizeof(addr); + + if (socket_ < 0) { + return host_; + } + + int rv = getpeername(socket_, (sockaddr*) &addr, &addrLen); + + if (rv != 0) { + return peerHost_; + } + + char clienthost[NI_MAXHOST]; + char clientservice[NI_MAXSERV]; + + getnameinfo((sockaddr*) &addr, addrLen, + clienthost, sizeof(clienthost), + clientservice, sizeof(clientservice), 0); + + peerHost_ = clienthost; + } + return peerHost_; +} + +std::string TSocket::getPeerAddress() { + if (peerAddress_.empty()) { + struct sockaddr_storage addr; + socklen_t addrLen = sizeof(addr); + + if (socket_ < 0) { + return peerAddress_; + } + + int rv = getpeername(socket_, (sockaddr*) &addr, &addrLen); + + if (rv != 0) { + return peerAddress_; + } + + char clienthost[NI_MAXHOST]; + char clientservice[NI_MAXSERV]; + + getnameinfo((sockaddr*) &addr, addrLen, + clienthost, sizeof(clienthost), + clientservice, sizeof(clientservice), + NI_NUMERICHOST|NI_NUMERICSERV); + + peerAddress_ = clienthost; + peerPort_ = std::atoi(clientservice); + } + return peerAddress_; +} + +int TSocket::getPeerPort() { + getPeerAddress(); + return peerPort_; +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TSocket.h b/lib/cpp/src/transport/TSocket.h new file mode 100644 index 000000000..b0f445aa3 --- /dev/null +++ b/lib/cpp/src/transport/TSocket.h @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TSOCKET_H_ +#define _THRIFT_TRANSPORT_TSOCKET_H_ 1 + +#include <string> +#include <sys/time.h> + +#include "TTransport.h" +#include "TServerSocket.h" + +namespace apache { namespace thrift { namespace transport { + +/** + * TCP Socket implementation of the TTransport interface. + * + */ +class TSocket : public TTransport { + /** + * We allow the TServerSocket acceptImpl() method to access the private + * members of a socket so that it can access the TSocket(int socket) + * constructor which creates a socket object from the raw UNIX socket + * handle. + */ + friend class TServerSocket; + + public: + /** + * Constructs a new socket. Note that this does NOT actually connect the + * socket. + * + */ + TSocket(); + + /** + * Constructs a new socket. Note that this does NOT actually connect the + * socket. + * + * @param host An IP address or hostname to connect to + * @param port The port to connect on + */ + TSocket(std::string host, int port); + + /** + * Destroyes the socket object, closing it if necessary. + */ + virtual ~TSocket(); + + /** + * Whether the socket is alive. + * + * @return Is the socket alive? + */ + bool isOpen(); + + /** + * Calls select on the socket to see if there is more data available. 
+ */ + bool peek(); + + /** + * Creates and opens the UNIX socket. + * + * @throws TTransportException If the socket could not connect + */ + virtual void open(); + + /** + * Shuts down communications on the socket. + */ + void close(); + + /** + * Reads from the underlying socket. + */ + uint32_t read(uint8_t* buf, uint32_t len); + + /** + * Writes to the underlying socket. + */ + void write(const uint8_t* buf, uint32_t len); + + /** + * Get the host that the socket is connected to + * + * @return string host identifier + */ + std::string getHost(); + + /** + * Get the port that the socket is connected to + * + * @return int port number + */ + int getPort(); + + /** + * Set the host that socket will connect to + * + * @param host host identifier + */ + void setHost(std::string host); + + /** + * Set the port that socket will connect to + * + * @param port port number + */ + void setPort(int port); + + /** + * Controls whether the linger option is set on the socket. + * + * @param on Whether SO_LINGER is on + * @param linger If linger is active, the number of seconds to linger for + */ + void setLinger(bool on, int linger); + + /** + * Whether to enable/disable Nagle's algorithm. + * + * @param noDelay Whether or not to disable the algorithm. 
+ * @return + */ + void setNoDelay(bool noDelay); + + /** + * Set the connect timeout + */ + void setConnTimeout(int ms); + + /** + * Set the receive timeout + */ + void setRecvTimeout(int ms); + + /** + * Set the send timeout + */ + void setSendTimeout(int ms); + + /** + * Set the max number of recv retries in case of an EAGAIN + * error + */ + void setMaxRecvRetries(int maxRecvRetries); + + /** + * Get socket information formated as a string <Host: x Port: x> + */ + std::string getSocketInfo(); + + /** + * Returns the DNS name of the host to which the socket is connected + */ + std::string getPeerHost(); + + /** + * Returns the address of the host to which the socket is connected + */ + std::string getPeerAddress(); + + /** + * Returns the port of the host to which the socket is connected + **/ + int getPeerPort(); + + + protected: + /** + * Constructor to create socket from raw UNIX handle. Never called directly + * but used by the TServerSocket class. + */ + TSocket(int socket); + + /** connect, called by open */ + void openConnection(struct addrinfo *res); + + /** Host to connect to */ + std::string host_; + + /** Peer hostname */ + std::string peerHost_; + + /** Peer address */ + std::string peerAddress_; + + /** Peer port */ + int peerPort_; + + /** Port number to connect on */ + int port_; + + /** Underlying UNIX socket handle */ + int socket_; + + /** Connect timeout in ms */ + int connTimeout_; + + /** Send timeout in ms */ + int sendTimeout_; + + /** Recv timeout in ms */ + int recvTimeout_; + + /** Linger on */ + bool lingerOn_; + + /** Linger val */ + int lingerVal_; + + /** Nodelay */ + bool noDelay_; + + /** Recv EGAIN retries */ + int maxRecvRetries_; + + /** Recv timeout timeval */ + struct timeval recvTimeval_; +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TSOCKET_H_ + diff --git a/lib/cpp/src/transport/TSocketPool.cpp b/lib/cpp/src/transport/TSocketPool.cpp new file mode 100644 index 000000000..1150282bb --- 
/dev/null +++ b/lib/cpp/src/transport/TSocketPool.cpp @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <algorithm> +#include <iostream> + +#include "TSocketPool.h" + +namespace apache { namespace thrift { namespace transport { + +using namespace std; + +using boost::shared_ptr; + +/** + * TSocketPoolServer implementation + * + */ +TSocketPoolServer::TSocketPoolServer() + : host_(""), + port_(0), + socket_(-1), + lastFailTime_(0), + consecutiveFailures_(0) {} + +/** + * Constructor for TSocketPool server + */ +TSocketPoolServer::TSocketPoolServer(const string &host, int port) + : host_(host), + port_(port), + socket_(-1), + lastFailTime_(0), + consecutiveFailures_(0) {} + +/** + * TSocketPool implementation. 
+ * + */ + +TSocketPool::TSocketPool() : TSocket(), + numRetries_(1), + retryInterval_(60), + maxConsecutiveFailures_(1), + randomize_(true), + alwaysTryLast_(true) { +} + +TSocketPool::TSocketPool(const vector<string> &hosts, + const vector<int> &ports) : TSocket(), + numRetries_(1), + retryInterval_(60), + maxConsecutiveFailures_(1), + randomize_(true), + alwaysTryLast_(true) +{ + if (hosts.size() != ports.size()) { + GlobalOutput("TSocketPool::TSocketPool: hosts.size != ports.size"); + throw TTransportException(TTransportException::BAD_ARGS); + } + + for (unsigned int i = 0; i < hosts.size(); ++i) { + addServer(hosts[i], ports[i]); + } +} + +TSocketPool::TSocketPool(const vector<pair<string, int> >& servers) : TSocket(), + numRetries_(1), + retryInterval_(60), + maxConsecutiveFailures_(1), + randomize_(true), + alwaysTryLast_(true) +{ + for (unsigned i = 0; i < servers.size(); ++i) { + addServer(servers[i].first, servers[i].second); + } +} + +TSocketPool::TSocketPool(const vector< shared_ptr<TSocketPoolServer> >& servers) : TSocket(), + servers_(servers), + numRetries_(1), + retryInterval_(60), + maxConsecutiveFailures_(1), + randomize_(true), + alwaysTryLast_(true) +{ +} + +TSocketPool::TSocketPool(const string& host, int port) : TSocket(), + numRetries_(1), + retryInterval_(60), + maxConsecutiveFailures_(1), + randomize_(true), + alwaysTryLast_(true) +{ + addServer(host, port); +} + +TSocketPool::~TSocketPool() { + vector< shared_ptr<TSocketPoolServer> >::const_iterator iter = servers_.begin(); + vector< shared_ptr<TSocketPoolServer> >::const_iterator iterEnd = servers_.end(); + for (; iter != iterEnd; ++iter) { + setCurrentServer(*iter); + TSocketPool::close(); + } +} + +void TSocketPool::addServer(const string& host, int port) { + servers_.push_back(shared_ptr<TSocketPoolServer>(new TSocketPoolServer(host, port))); +} + +void TSocketPool::setServers(const vector< shared_ptr<TSocketPoolServer> >& servers) { + servers_ = servers; +} + +void 
TSocketPool::getServers(vector< shared_ptr<TSocketPoolServer> >& servers) { + servers = servers_; +} + +void TSocketPool::setNumRetries(int numRetries) { + numRetries_ = numRetries; +} + +void TSocketPool::setRetryInterval(int retryInterval) { + retryInterval_ = retryInterval; +} + + +void TSocketPool::setMaxConsecutiveFailures(int maxConsecutiveFailures) { + maxConsecutiveFailures_ = maxConsecutiveFailures; +} + +void TSocketPool::setRandomize(bool randomize) { + randomize_ = randomize; +} + +void TSocketPool::setAlwaysTryLast(bool alwaysTryLast) { + alwaysTryLast_ = alwaysTryLast; +} + +void TSocketPool::setCurrentServer(const shared_ptr<TSocketPoolServer> &server) { + currentServer_ = server; + host_ = server->host_; + port_ = server->port_; + socket_ = server->socket_; +} + +/* TODO: without apc we ignore a lot of functionality from the php version */ +void TSocketPool::open() { + if (randomize_) { + random_shuffle(servers_.begin(), servers_.end()); + } + + unsigned int numServers = servers_.size(); + for (unsigned int i = 0; i < numServers; ++i) { + + shared_ptr<TSocketPoolServer> &server = servers_[i]; + bool retryIntervalPassed = (server->lastFailTime_ == 0); + bool isLastServer = alwaysTryLast_ ? 
(i == (numServers - 1)) : false; + + // Impersonate the server socket + setCurrentServer(server); + + if (isOpen()) { + // already open means we're done + return; + } + + if (server->lastFailTime_ > 0) { + // The server was marked as down, so check if enough time has elapsed to retry + int elapsedTime = time(NULL) - server->lastFailTime_; + if (elapsedTime > retryInterval_) { + retryIntervalPassed = true; + } + } + + if (retryIntervalPassed || isLastServer) { + for (int j = 0; j < numRetries_; ++j) { + try { + TSocket::open(); + + // Copy over the opened socket so that we can keep it persistent + server->socket_ = socket_; + + // reset lastFailTime_ is required + if (server->lastFailTime_) { + server->lastFailTime_ = 0; + } + + // success + return; + } catch (TException e) { + string errStr = "TSocketPool::open failed "+getSocketInfo()+": "+e.what(); + GlobalOutput(errStr.c_str()); + // connection failed + } + } + + ++server->consecutiveFailures_; + if (server->consecutiveFailures_ > maxConsecutiveFailures_) { + // Mark server as down + server->consecutiveFailures_ = 0; + server->lastFailTime_ = time(NULL); + } + } + } + + GlobalOutput("TSocketPool::open: all connections failed"); + throw TTransportException(TTransportException::NOT_OPEN); +} + +void TSocketPool::close() { + if (isOpen()) { + TSocket::close(); + currentServer_->socket_ = -1; + } +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TSocketPool.h b/lib/cpp/src/transport/TSocketPool.h new file mode 100644 index 000000000..8c506695a --- /dev/null +++ b/lib/cpp/src/transport/TSocketPool.h @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TSOCKETPOOL_H_ +#define _THRIFT_TRANSPORT_TSOCKETPOOL_H_ 1 + +#include <vector> +#include "TSocket.h" + +namespace apache { namespace thrift { namespace transport { + + /** + * Class to hold server information for TSocketPool + * + */ +class TSocketPoolServer { + + public: + /** + * Default constructor for server info + */ + TSocketPoolServer(); + + /** + * Constructor for TSocketPool server + */ + TSocketPoolServer(const std::string &host, int port); + + // Host name + std::string host_; + + // Port to connect on + int port_; + + // Socket for the server + int socket_; + + // Last time connecting to this server failed + int lastFailTime_; + + // Number of consecutive times connecting to this server failed + int consecutiveFailures_; +}; + +/** + * TCP Socket implementation of the TTransport interface. 
+ * + */ +class TSocketPool : public TSocket { + + public: + + /** + * Socket pool constructor + */ + TSocketPool(); + + /** + * Socket pool constructor + * + * @param hosts list of host names + * @param ports list of port names + */ + TSocketPool(const std::vector<std::string> &hosts, + const std::vector<int> &ports); + + /** + * Socket pool constructor + * + * @param servers list of pairs of host name and port + */ + TSocketPool(const std::vector<std::pair<std::string, int> >& servers); + + /** + * Socket pool constructor + * + * @param servers list of TSocketPoolServers + */ + TSocketPool(const std::vector< boost::shared_ptr<TSocketPoolServer> >& servers); + + /** + * Socket pool constructor + * + * @param host single host + * @param port single port + */ + TSocketPool(const std::string& host, int port); + + /** + * Destroyes the socket object, closing it if necessary. + */ + virtual ~TSocketPool(); + + /** + * Add a server to the pool + */ + void addServer(const std::string& host, int port); + + /** + * Set list of servers in this pool + */ + void setServers(const std::vector< boost::shared_ptr<TSocketPoolServer> >& servers); + + /** + * Get list of servers in this pool + */ + void getServers(std::vector< boost::shared_ptr<TSocketPoolServer> >& servers); + + /** + * Sets how many times to keep retrying a host in the connect function. + */ + void setNumRetries(int numRetries); + + /** + * Sets how long to wait until retrying a host if it was marked down + */ + void setRetryInterval(int retryInterval); + + /** + * Sets how many times to keep retrying a host before marking it as down. + */ + void setMaxConsecutiveFailures(int maxConsecutiveFailures); + + /** + * Turns randomization in connect order on or off. + */ + void setRandomize(bool randomize); + + /** + * Whether to always try the last server. + */ + void setAlwaysTryLast(bool alwaysTryLast); + + /** + * Creates and opens the UNIX socket. 
+ */ + void open(); + + /* + * Closes the UNIX socket + */ + void close(); + + protected: + + void setCurrentServer(const boost::shared_ptr<TSocketPoolServer> &server); + + /** List of servers to connect to */ + std::vector< boost::shared_ptr<TSocketPoolServer> > servers_; + + /** Current server */ + boost::shared_ptr<TSocketPoolServer> currentServer_; + + /** How many times to retry each host in connect */ + int numRetries_; + + /** Retry interval in seconds, how long to not try a host if it has been + * marked as down. + */ + int retryInterval_; + + /** Max consecutive failures before marking a host down. */ + int maxConsecutiveFailures_; + + /** Try hosts in order? or Randomized? */ + bool randomize_; + + /** Always try last host, even if marked down? */ + bool alwaysTryLast_; +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TSOCKETPOOL_H_ + diff --git a/lib/cpp/src/transport/TTransport.h b/lib/cpp/src/transport/TTransport.h new file mode 100644 index 000000000..eb0d5df8a --- /dev/null +++ b/lib/cpp/src/transport/TTransport.h @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef _THRIFT_TRANSPORT_TTRANSPORT_H_ +#define _THRIFT_TRANSPORT_TTRANSPORT_H_ 1 + +#include <Thrift.h> +#include <boost/shared_ptr.hpp> +#include <transport/TTransportException.h> +#include <string> + +namespace apache { namespace thrift { namespace transport { + +/** + * Generic interface for a method of transporting data. A TTransport may be + * capable of either reading or writing, but not necessarily both. + * + */ +class TTransport { + public: + /** + * Virtual deconstructor. + */ + virtual ~TTransport() {} + + /** + * Whether this transport is open. + */ + virtual bool isOpen() { + return false; + } + + /** + * Tests whether there is more data to read or if the remote side is + * still open. By default this is true whenever the transport is open, + * but implementations should add logic to test for this condition where + * possible (i.e. on a socket). + * This is used by a server to check if it should listen for another + * request. + */ + virtual bool peek() { + return isOpen(); + } + + /** + * Opens the transport for communications. + * + * @return bool Whether the transport was successfully opened + * @throws TTransportException if opening failed + */ + virtual void open() { + throw TTransportException(TTransportException::NOT_OPEN, "Cannot open base TTransport."); + } + + /** + * Closes the transport. + */ + virtual void close() { + throw TTransportException(TTransportException::NOT_OPEN, "Cannot close base TTransport."); + } + + /** + * Attempt to read up to the specified number of bytes into the string. + * + * @param buf Reference to the location to write the data + * @param len How many bytes to read + * @return How many bytes were actually read + * @throws TTransportException If an error occurs + */ + virtual uint32_t read(uint8_t* /* buf */, uint32_t /* len */) { + throw TTransportException(TTransportException::NOT_OPEN, "Base TTransport cannot read."); + } + + /** + * Reads the given amount of data in its entirety no matter what. 
+ * + * @param s Reference to location for read data + * @param len How many bytes to read + * @return How many bytes read, which must be equal to size + * @throws TTransportException If insufficient data was read + */ + virtual uint32_t readAll(uint8_t* buf, uint32_t len) { + uint32_t have = 0; + uint32_t get = 0; + + while (have < len) { + get = read(buf+have, len-have); + if (get <= 0) { + throw TTransportException("No more data to read."); + } + have += get; + } + + return have; + } + + /** + * Called when read is completed. + * This can be over-ridden to perform a transport-specific action + * e.g. logging the request to a file + * + */ + virtual void readEnd() { + // default behaviour is to do nothing + return; + } + + /** + * Writes the string in its entirety to the buffer. + * + * @param buf The data to write out + * @throws TTransportException if an error occurs + */ + virtual void write(const uint8_t* /* buf */, uint32_t /* len */) { + throw TTransportException(TTransportException::NOT_OPEN, "Base TTransport cannot write."); + } + + /** + * Called when write is completed. + * This can be over-ridden to perform a transport-specific action + * at the end of a request. + * + */ + virtual void writeEnd() { + // default behaviour is to do nothing + return; + } + + /** + * Flushes any pending data to be written. Typically used with buffered + * transport mechanisms. + * + * @throws TTransportException if an error occurs + */ + virtual void flush() {} + + /** + * Attempts to return a pointer to \c len bytes, possibly copied into \c buf. + * Does not consume the bytes read (i.e.: a later read will return the same + * data). This method is meant to support protocols that need to read + * variable-length fields. They can attempt to borrow the maximum amount of + * data that they will need, then consume (see next method) what they + * actually use. 
Some transports will not support this method and others + * will fail occasionally, so protocols must be prepared to use read if + * borrow fails. + * + * @oaram buf A buffer where the data can be stored if needed. + * If borrow doesn't return buf, then the contents of + * buf after the call are undefined. + * @param len *len should initially contain the number of bytes to borrow. + * If borrow succeeds, *len will contain the number of bytes + * available in the returned pointer. This will be at least + * what was requested, but may be more if borrow returns + * a pointer to an internal buffer, rather than buf. + * If borrow fails, the contents of *len are undefined. + * @return If the borrow succeeds, return a pointer to the borrowed data. + * This might be equal to \c buf, or it might be a pointer into + * the transport's internal buffers. + * @throws TTransportException if an error occurs + */ + virtual const uint8_t* borrow(uint8_t* /* buf */, uint32_t* /* len */) { + return NULL; + } + + /** + * Remove len bytes from the transport. This should always follow a borrow + * of at least len bytes, and should always succeed. + * TODO(dreiss): Is there any transport that could borrow but fail to + * consume, or that would require a buffer to dump the consumed data? + * + * @param len How many bytes to consume + * @throws TTransportException If an error occurs + */ + virtual void consume(uint32_t /* len */) { + throw TTransportException(TTransportException::NOT_OPEN, "Base TTransport cannot consume."); + } + + protected: + /** + * Simple constructor. + */ + TTransport() {} +}; + +/** + * Generic factory class to make an input and output transport out of a + * source transport. Commonly used inside servers to make input and output + * streams out of raw clients. + * + */ +class TTransportFactory { + public: + TTransportFactory() {} + + virtual ~TTransportFactory() {} + + /** + * Default implementation does nothing, just returns the transport given. 
+ */ + virtual boost::shared_ptr<TTransport> getTransport(boost::shared_ptr<TTransport> trans) { + return trans; + } + +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TTRANSPORT_H_ diff --git a/lib/cpp/src/transport/TTransportException.cpp b/lib/cpp/src/transport/TTransportException.cpp new file mode 100644 index 000000000..f0aaedc2f --- /dev/null +++ b/lib/cpp/src/transport/TTransportException.cpp @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <transport/TTransportException.h> +#include <boost/lexical_cast.hpp> +#include <cstring> +#include <config.h> + +using std::string; +using boost::lexical_cast; + +namespace apache { namespace thrift { namespace transport { + +}}} // apache::thrift::transport + diff --git a/lib/cpp/src/transport/TTransportException.h b/lib/cpp/src/transport/TTransportException.h new file mode 100644 index 000000000..330785cea --- /dev/null +++ b/lib/cpp/src/transport/TTransportException.h @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TTRANSPORTEXCEPTION_H_ +#define _THRIFT_TRANSPORT_TTRANSPORTEXCEPTION_H_ 1 + +#include <string> +#include <Thrift.h> + +namespace apache { namespace thrift { namespace transport { + +/** + * Class to encapsulate all the possible types of transport errors that may + * occur in various transport systems. This provides a sort of generic + * wrapper around the shitty UNIX E_ error codes that lets a common code + * base of error handling to be used for various types of transports, i.e. + * pipes etc. + * + */ +class TTransportException : public apache::thrift::TException { + public: + /** + * Error codes for the various types of exceptions. 
+ */ + enum TTransportExceptionType + { UNKNOWN = 0 + , NOT_OPEN = 1 + , ALREADY_OPEN = 2 + , TIMED_OUT = 3 + , END_OF_FILE = 4 + , INTERRUPTED = 5 + , BAD_ARGS = 6 + , CORRUPTED_DATA = 7 + , INTERNAL_ERROR = 8 + }; + + TTransportException() : + apache::thrift::TException(), + type_(UNKNOWN) {} + + TTransportException(TTransportExceptionType type) : + apache::thrift::TException(), + type_(type) {} + + TTransportException(const std::string& message) : + apache::thrift::TException(message), + type_(UNKNOWN) {} + + TTransportException(TTransportExceptionType type, const std::string& message) : + apache::thrift::TException(message), + type_(type) {} + + TTransportException(TTransportExceptionType type, + const std::string& message, + int errno_copy) : + apache::thrift::TException(message + ": " + TOutput::strerror_s(errno_copy)), + type_(type) {} + + virtual ~TTransportException() throw() {} + + /** + * Returns an error code that provides information about the type of error + * that has occurred. 
+ * + * @return Error code + */ + TTransportExceptionType getType() const throw() { + return type_; + } + + virtual const char* what() const throw() { + if (message_.empty()) { + switch (type_) { + case UNKNOWN : return "TTransportException: Unknown transport exception"; + case NOT_OPEN : return "TTransportException: Transport not open"; + case ALREADY_OPEN : return "TTransportException: Transport already open"; + case TIMED_OUT : return "TTransportException: Timed out"; + case END_OF_FILE : return "TTransportException: End of file"; + case INTERRUPTED : return "TTransportException: Interrupted"; + case BAD_ARGS : return "TTransportException: Invalid arguments"; + case CORRUPTED_DATA : return "TTransportException: Corrupted Data"; + case INTERNAL_ERROR : return "TTransportException: Internal error"; + default : return "TTransportException: (Invalid exception type)"; + } + } else { + return message_.c_str(); + } + } + + protected: + /** Just like strerror_r but returns a C++ string object. */ + std::string strerror_s(int errno_copy); + + /** Error code */ + TTransportExceptionType type_; + +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TTRANSPORTEXCEPTION_H_ diff --git a/lib/cpp/src/transport/TTransportUtils.cpp b/lib/cpp/src/transport/TTransportUtils.cpp new file mode 100644 index 000000000..a840fa6c1 --- /dev/null +++ b/lib/cpp/src/transport/TTransportUtils.cpp @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <transport/TTransportUtils.h> + +using std::string; + +namespace apache { namespace thrift { namespace transport { + +uint32_t TPipedTransport::read(uint8_t* buf, uint32_t len) { + uint32_t need = len; + + // We don't have enough data yet + if (rLen_-rPos_ < need) { + // Copy out whatever we have + if (rLen_-rPos_ > 0) { + memcpy(buf, rBuf_+rPos_, rLen_-rPos_); + need -= rLen_-rPos_; + buf += rLen_-rPos_; + rPos_ = rLen_; + } + + // Double the size of the underlying buffer if it is full + if (rLen_ == rBufSize_) { + rBufSize_ *=2; + rBuf_ = (uint8_t *)std::realloc(rBuf_, sizeof(uint8_t) * rBufSize_); + } + + // try to fill up the buffer + rLen_ += srcTrans_->read(rBuf_+rPos_, rBufSize_ - rPos_); + } + + + // Hand over whatever we have + uint32_t give = need; + if (rLen_-rPos_ < give) { + give = rLen_-rPos_; + } + if (give > 0) { + memcpy(buf, rBuf_+rPos_, give); + rPos_ += give; + need -= give; + } + + return (len - need); +} + +void TPipedTransport::write(const uint8_t* buf, uint32_t len) { + if (len == 0) { + return; + } + + // Make the buffer as big as it needs to be + if ((len + wLen_) >= wBufSize_) { + uint32_t newBufSize = wBufSize_*2; + while ((len + wLen_) >= newBufSize) { + newBufSize *= 2; + } + wBuf_ = (uint8_t *)std::realloc(wBuf_, sizeof(uint8_t) * newBufSize); + wBufSize_ = newBufSize; + } + + // Copy into the buffer + memcpy(wBuf_ + wLen_, buf, len); + wLen_ += len; +} + +void TPipedTransport::flush() { + // Write out any data waiting in the write buffer + if (wLen_ > 0) { + srcTrans_->write(wBuf_, wLen_); + wLen_ = 
0; + } + + // Flush the underlying transport + srcTrans_->flush(); +} + +TPipedFileReaderTransport::TPipedFileReaderTransport(boost::shared_ptr<TFileReaderTransport> srcTrans, boost::shared_ptr<TTransport> dstTrans) + : TPipedTransport(srcTrans, dstTrans), + srcTrans_(srcTrans) { +} + +TPipedFileReaderTransport::~TPipedFileReaderTransport() { +} + +bool TPipedFileReaderTransport::isOpen() { + return TPipedTransport::isOpen(); +} + +bool TPipedFileReaderTransport::peek() { + return TPipedTransport::peek(); +} + +void TPipedFileReaderTransport::open() { + TPipedTransport::open(); +} + +void TPipedFileReaderTransport::close() { + TPipedTransport::close(); +} + +uint32_t TPipedFileReaderTransport::read(uint8_t* buf, uint32_t len) { + return TPipedTransport::read(buf, len); +} + +uint32_t TPipedFileReaderTransport::readAll(uint8_t* buf, uint32_t len) { + uint32_t have = 0; + uint32_t get = 0; + + while (have < len) { + get = read(buf+have, len-have); + if (get <= 0) { + throw TEOFException(); + } + have += get; + } + + return have; +} + +void TPipedFileReaderTransport::readEnd() { + TPipedTransport::readEnd(); +} + +void TPipedFileReaderTransport::write(const uint8_t* buf, uint32_t len) { + TPipedTransport::write(buf, len); +} + +void TPipedFileReaderTransport::writeEnd() { + TPipedTransport::writeEnd(); +} + +void TPipedFileReaderTransport::flush() { + TPipedTransport::flush(); +} + +int32_t TPipedFileReaderTransport::getReadTimeout() { + return srcTrans_->getReadTimeout(); +} + +void TPipedFileReaderTransport::setReadTimeout(int32_t readTimeout) { + srcTrans_->setReadTimeout(readTimeout); +} + +uint32_t TPipedFileReaderTransport::getNumChunks() { + return srcTrans_->getNumChunks(); +} + +uint32_t TPipedFileReaderTransport::getCurChunk() { + return srcTrans_->getCurChunk(); +} + +void TPipedFileReaderTransport::seekToChunk(int32_t chunk) { + srcTrans_->seekToChunk(chunk); +} + +void TPipedFileReaderTransport::seekToEnd() { + srcTrans_->seekToEnd(); +} + +}}} // 
apache::thrift::transport diff --git a/lib/cpp/src/transport/TTransportUtils.h b/lib/cpp/src/transport/TTransportUtils.h new file mode 100644 index 000000000..d65c91674 --- /dev/null +++ b/lib/cpp/src/transport/TTransportUtils.h @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TTRANSPORTUTILS_H_ +#define _THRIFT_TRANSPORT_TTRANSPORTUTILS_H_ 1 + +#include <cstdlib> +#include <cstring> +#include <string> +#include <algorithm> +#include <transport/TTransport.h> +// Include the buffered transports that used to be defined here. +#include <transport/TBufferTransports.h> +#include <transport/TFileTransport.h> + +namespace apache { namespace thrift { namespace transport { + +/** + * The null transport is a dummy transport that doesn't actually do anything. + * It's sort of an analogy to /dev/null, you can never read anything from it + * and it will let you write anything you want to it, though it won't actually + * go anywhere. 
+ * + */ +class TNullTransport : public TTransport { + public: + TNullTransport() {} + + ~TNullTransport() {} + + bool isOpen() { + return true; + } + + void open() {} + + void write(const uint8_t* /* buf */, uint32_t /* len */) { + return; + } + +}; + + +/** + * TPipedTransport. This transport allows piping of a request from one + * transport to another either when readEnd() or writeEnd(). The typical + * use case for this is to log a request or a reply to disk. + * The underlying buffer expands to a keep a copy of the entire + * request/response. + * + */ +class TPipedTransport : virtual public TTransport { + public: + TPipedTransport(boost::shared_ptr<TTransport> srcTrans, + boost::shared_ptr<TTransport> dstTrans) : + srcTrans_(srcTrans), + dstTrans_(dstTrans), + rBufSize_(512), rPos_(0), rLen_(0), + wBufSize_(512), wLen_(0) { + + // default is to to pipe the request when readEnd() is called + pipeOnRead_ = true; + pipeOnWrite_ = false; + + rBuf_ = (uint8_t*) std::malloc(sizeof(uint8_t) * rBufSize_); + wBuf_ = (uint8_t*) std::malloc(sizeof(uint8_t) * wBufSize_); + } + + TPipedTransport(boost::shared_ptr<TTransport> srcTrans, + boost::shared_ptr<TTransport> dstTrans, + uint32_t sz) : + srcTrans_(srcTrans), + dstTrans_(dstTrans), + rBufSize_(512), rPos_(0), rLen_(0), + wBufSize_(sz), wLen_(0) { + + rBuf_ = (uint8_t*) std::malloc(sizeof(uint8_t) * rBufSize_); + wBuf_ = (uint8_t*) std::malloc(sizeof(uint8_t) * wBufSize_); + } + + ~TPipedTransport() { + std::free(rBuf_); + std::free(wBuf_); + } + + bool isOpen() { + return srcTrans_->isOpen(); + } + + bool peek() { + if (rPos_ >= rLen_) { + // Double the size of the underlying buffer if it is full + if (rLen_ == rBufSize_) { + rBufSize_ *=2; + rBuf_ = (uint8_t *)std::realloc(rBuf_, sizeof(uint8_t) * rBufSize_); + } + + // try to fill up the buffer + rLen_ += srcTrans_->read(rBuf_+rPos_, rBufSize_ - rPos_); + } + return (rLen_ > rPos_); + } + + + void open() { + srcTrans_->open(); + } + + void close() { + 
srcTrans_->close(); + } + + void setPipeOnRead(bool pipeVal) { + pipeOnRead_ = pipeVal; + } + + void setPipeOnWrite(bool pipeVal) { + pipeOnWrite_ = pipeVal; + } + + uint32_t read(uint8_t* buf, uint32_t len); + + void readEnd() { + + if (pipeOnRead_) { + dstTrans_->write(rBuf_, rPos_); + dstTrans_->flush(); + } + + srcTrans_->readEnd(); + + // If requests are being pipelined, copy down our read-ahead data, + // then reset our state. + int read_ahead = rLen_ - rPos_; + memcpy(rBuf_, rBuf_ + rPos_, read_ahead); + rPos_ = 0; + rLen_ = read_ahead; + } + + void write(const uint8_t* buf, uint32_t len); + + void writeEnd() { + if (pipeOnWrite_) { + dstTrans_->write(wBuf_, wLen_); + dstTrans_->flush(); + } + } + + void flush(); + + boost::shared_ptr<TTransport> getTargetTransport() { + return dstTrans_; + } + + protected: + boost::shared_ptr<TTransport> srcTrans_; + boost::shared_ptr<TTransport> dstTrans_; + + uint8_t* rBuf_; + uint32_t rBufSize_; + uint32_t rPos_; + uint32_t rLen_; + + uint8_t* wBuf_; + uint32_t wBufSize_; + uint32_t wLen_; + + bool pipeOnRead_; + bool pipeOnWrite_; +}; + + +/** + * Wraps a transport into a pipedTransport instance. + * + */ +class TPipedTransportFactory : public TTransportFactory { + public: + TPipedTransportFactory() {} + TPipedTransportFactory(boost::shared_ptr<TTransport> dstTrans) { + initializeTargetTransport(dstTrans); + } + virtual ~TPipedTransportFactory() {} + + /** + * Wraps the base transport into a piped transport. + */ + virtual boost::shared_ptr<TTransport> getTransport(boost::shared_ptr<TTransport> srcTrans) { + return boost::shared_ptr<TTransport>(new TPipedTransport(srcTrans, dstTrans_)); + } + + virtual void initializeTargetTransport(boost::shared_ptr<TTransport> dstTrans) { + if (dstTrans_.get() == NULL) { + dstTrans_ = dstTrans; + } else { + throw TException("Target transport already initialized"); + } + } + + protected: + boost::shared_ptr<TTransport> dstTrans_; +}; + +/** + * TPipedFileTransport. 
This is just like a TTransport, except that + * it is a templatized class, so that clients who rely on a specific + * TTransport can still access the original transport. + * + */ +class TPipedFileReaderTransport : public TPipedTransport, + public TFileReaderTransport { + public: + TPipedFileReaderTransport(boost::shared_ptr<TFileReaderTransport> srcTrans, boost::shared_ptr<TTransport> dstTrans); + + ~TPipedFileReaderTransport(); + + // TTransport functions + bool isOpen(); + bool peek(); + void open(); + void close(); + uint32_t read(uint8_t* buf, uint32_t len); + uint32_t readAll(uint8_t* buf, uint32_t len); + void readEnd(); + void write(const uint8_t* buf, uint32_t len); + void writeEnd(); + void flush(); + + // TFileReaderTransport functions + int32_t getReadTimeout(); + void setReadTimeout(int32_t readTimeout); + uint32_t getNumChunks(); + uint32_t getCurChunk(); + void seekToChunk(int32_t chunk); + void seekToEnd(); + + protected: + // shouldn't be used + TPipedFileReaderTransport(); + boost::shared_ptr<TFileReaderTransport> srcTrans_; +}; + +/** + * Creates a TPipedFileReaderTransport from a filepath and a destination transport + * + */ +class TPipedFileReaderTransportFactory : public TPipedTransportFactory { + public: + TPipedFileReaderTransportFactory() {} + TPipedFileReaderTransportFactory(boost::shared_ptr<TTransport> dstTrans) + : TPipedTransportFactory(dstTrans) + {} + virtual ~TPipedFileReaderTransportFactory() {} + + boost::shared_ptr<TTransport> getTransport(boost::shared_ptr<TTransport> srcTrans) { + boost::shared_ptr<TFileReaderTransport> pFileReaderTransport = boost::dynamic_pointer_cast<TFileReaderTransport>(srcTrans); + if (pFileReaderTransport.get() != NULL) { + return getFileReaderTransport(pFileReaderTransport); + } else { + return boost::shared_ptr<TTransport>(); + } + } + + boost::shared_ptr<TFileReaderTransport> getFileReaderTransport(boost::shared_ptr<TFileReaderTransport> srcTrans) { + return boost::shared_ptr<TFileReaderTransport>(new 
TPipedFileReaderTransport(srcTrans, dstTrans_)); + } +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TTRANSPORTUTILS_H_ diff --git a/lib/cpp/src/transport/TZlibTransport.cpp b/lib/cpp/src/transport/TZlibTransport.cpp new file mode 100644 index 000000000..2f14e906b --- /dev/null +++ b/lib/cpp/src/transport/TZlibTransport.cpp @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <cassert> +#include <cstring> +#include <algorithm> +#include <transport/TZlibTransport.h> +#include <zlib.h> + +using std::string; + +namespace apache { namespace thrift { namespace transport { + +// Don't call this outside of the constructor. 
+void TZlibTransport::initZlib() { + int rv; + bool r_init = false; + try { + rstream_ = new z_stream; + wstream_ = new z_stream; + + rstream_->zalloc = Z_NULL; + wstream_->zalloc = Z_NULL; + rstream_->zfree = Z_NULL; + wstream_->zfree = Z_NULL; + rstream_->opaque = Z_NULL; + wstream_->opaque = Z_NULL; + + rstream_->next_in = crbuf_; + wstream_->next_in = uwbuf_; + rstream_->next_out = urbuf_; + wstream_->next_out = cwbuf_; + rstream_->avail_in = 0; + wstream_->avail_in = 0; + rstream_->avail_out = urbuf_size_; + wstream_->avail_out = cwbuf_size_; + + rv = inflateInit(rstream_); + checkZlibRv(rv, rstream_->msg); + + // Have to set this flag so we know whether to de-initialize. + r_init = true; + + rv = deflateInit(wstream_, Z_DEFAULT_COMPRESSION); + checkZlibRv(rv, wstream_->msg); + } + + catch (...) { + if (r_init) { + rv = inflateEnd(rstream_); + checkZlibRvNothrow(rv, rstream_->msg); + } + // There is no way we can get here if wstream_ was initialized. + + throw; + } +} + +inline void TZlibTransport::checkZlibRv(int status, const char* message) { + if (status != Z_OK) { + throw TZlibTransportException(status, message); + } +} + +inline void TZlibTransport::checkZlibRvNothrow(int status, const char* message) { + if (status != Z_OK) { + string output = "TZlibTransport: zlib failure in destructor: " + + TZlibTransportException::errorMessage(status, message); + GlobalOutput(output.c_str()); + } +} + +TZlibTransport::~TZlibTransport() { + int rv; + rv = inflateEnd(rstream_); + checkZlibRvNothrow(rv, rstream_->msg); + rv = deflateEnd(wstream_); + checkZlibRvNothrow(rv, wstream_->msg); + + delete[] urbuf_; + delete[] crbuf_; + delete[] uwbuf_; + delete[] cwbuf_; + delete rstream_; + delete wstream_; +} + +bool TZlibTransport::isOpen() { + return (readAvail() > 0) || transport_->isOpen(); +} + +// READING STRATEGY +// +// We have two buffers for reading: one containing the compressed data (crbuf_) +// and one containing the uncompressed data (urbuf_). 
When read is called, +// we repeat the following steps until we have satisfied the request: +// - Copy data from urbuf_ into the caller's buffer. +// - If we had enough, return. +// - If urbuf_ is empty, read some data into it from the underlying transport. +// - Inflate data from crbuf_ into urbuf_. +// +// In standalone objects, we set input_ended_ to true when inflate returns +// Z_STREAM_END. This allows to make sure that a checksum was verified. + +inline int TZlibTransport::readAvail() { + return urbuf_size_ - rstream_->avail_out - urpos_; +} + +uint32_t TZlibTransport::read(uint8_t* buf, uint32_t len) { + int need = len; + + // TODO(dreiss): Skip urbuf on big reads. + + while (true) { + // Copy out whatever we have available, then give them the min of + // what we have and what they want, then advance indices. + int give = std::min(readAvail(), need); + memcpy(buf, urbuf_ + urpos_, give); + need -= give; + buf += give; + urpos_ += give; + + // If they were satisfied, we are done. + if (need == 0) { + return len; + } + + // If we get to this point, we need to get some more data. + + // If zlib has reported the end of a stream, we can't really do any more. + if (input_ended_) { + return len - need; + } + + // The uncompressed read buffer is empty, so reset the stream fields. + rstream_->next_out = urbuf_; + rstream_->avail_out = urbuf_size_; + urpos_ = 0; + + // If we don't have any more compressed data available, + // read some from the underlying transport. + if (rstream_->avail_in == 0) { + uint32_t got = transport_->read(crbuf_, crbuf_size_); + if (got == 0) { + return len - need; + } + rstream_->next_in = crbuf_; + rstream_->avail_in = got; + } + + // We have some compressed data now. Uncompress it. + int zlib_rv = inflate(rstream_, Z_SYNC_FLUSH); + + if (zlib_rv == Z_STREAM_END) { + if (standalone_) { + input_ended_ = true; + } + } else { + checkZlibRv(zlib_rv, rstream_->msg); + } + + // Okay. The read buffer should have whatever we can give it now. 
+ // Loop back to the start and try to give some more. + } +} + + +// WRITING STRATEGY +// +// We buffer up small writes before sending them to zlib, so our logic is: +// - Is the write big? +// - Send the buffer to zlib. +// - Send this data to zlib. +// - Is the write small? +// - Is there insufficient space in the buffer for it? +// - Send the buffer to zlib. +// - Copy the data to the buffer. +// +// We have two buffers for writing also: the uncompressed buffer (mentioned +// above) and the compressed buffer. When sending data to zlib we loop over +// the following until the source (uncompressed buffer or big write) is empty: +// - Is there no more space in the compressed buffer? +// - Write the compressed buffer to the underlying transport. +// - Deflate from the source into the compressed buffer. + +void TZlibTransport::write(const uint8_t* buf, uint32_t len) { + // zlib's "deflate" function has enough logic in it that I think + // we're better off (performance-wise) buffering up small writes. + if ((int)len > MIN_DIRECT_DEFLATE_SIZE) { + flushToZlib(uwbuf_, uwpos_); + uwpos_ = 0; + flushToZlib(buf, len); + } else if (len > 0) { + if (uwbuf_size_ - uwpos_ < (int)len) { + flushToZlib(uwbuf_, uwpos_); + uwpos_ = 0; + } + memcpy(uwbuf_ + uwpos_, buf, len); + uwpos_ += len; + } +} + +void TZlibTransport::flush() { + flushToZlib(uwbuf_, uwpos_, true); + assert((int)wstream_->avail_out != cwbuf_size_); + transport_->write(cwbuf_, cwbuf_size_ - wstream_->avail_out); + transport_->flush(); +} + +void TZlibTransport::flushToZlib(const uint8_t* buf, int len, bool finish) { + int flush = (finish ? Z_FINISH : Z_NO_FLUSH); + + wstream_->next_in = const_cast<uint8_t*>(buf); + wstream_->avail_in = len; + + while (wstream_->avail_in > 0 || finish) { + // If our ouput buffer is full, flush to the underlying transport. 
+ if (wstream_->avail_out == 0) { + transport_->write(cwbuf_, cwbuf_size_); + wstream_->next_out = cwbuf_; + wstream_->avail_out = cwbuf_size_; + } + + int zlib_rv = deflate(wstream_, flush); + + if (finish && zlib_rv == Z_STREAM_END) { + assert(wstream_->avail_in == 0); + break; + } + + checkZlibRv(zlib_rv, wstream_->msg); + } +} + +const uint8_t* TZlibTransport::borrow(uint8_t* buf, uint32_t* len) { + // Don't try to be clever with shifting buffers. + // If we have enough data, give a pointer to it, + // otherwise let the protcol use its slow path. + if (readAvail() >= (int)*len) { + *len = (uint32_t)readAvail(); + return urbuf_ + urpos_; + } + return NULL; +} + +void TZlibTransport::consume(uint32_t len) { + if (readAvail() >= (int)len) { + urpos_ += len; + } else { + throw TTransportException(TTransportException::BAD_ARGS, + "consume did not follow a borrow."); + } +} + +void TZlibTransport::verifyChecksum() { + if (!standalone_) { + throw TTransportException( + TTransportException::BAD_ARGS, + "TZLibTransport can only verify checksums for standalone objects."); + } + + if (!input_ended_) { + // This should only be called when reading is complete, + // but it's possible that the whole checksum has not been fed to zlib yet. + // We try to read an extra byte here to force zlib to finish the stream. + // It might not always be easy to "unread" this byte, + // but we throw an exception if we get it, which is not really + // a recoverable error, so it doesn't matter. + uint8_t buf[1]; + uint32_t got = this->read(buf, sizeof(buf)); + if (got || !input_ended_) { + throw TTransportException( + TTransportException::CORRUPTED_DATA, + "Zlib stream not complete."); + } + } + + // If the checksum had been bad, we would have gotten an error while + // inflating. 
+} + + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TZlibTransport.h b/lib/cpp/src/transport/TZlibTransport.h new file mode 100644 index 000000000..1439d9de7 --- /dev/null +++ b/lib/cpp/src/transport/TZlibTransport.h @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ +#define _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ 1 + +#include <boost/lexical_cast.hpp> +#include <transport/TTransport.h> + +struct z_stream_s; + +namespace apache { namespace thrift { namespace transport { + +class TZlibTransportException : public TTransportException { + public: + TZlibTransportException(int status, const char* msg) : + TTransportException(TTransportException::INTERNAL_ERROR, + errorMessage(status, msg)), + zlib_status_(status), + zlib_msg_(msg == NULL ? 
"(null)" : msg) {} + + virtual ~TZlibTransportException() throw() {} + + int getZlibStatus() { return zlib_status_; } + std::string getZlibMessage() { return zlib_msg_; } + + static std::string errorMessage(int status, const char* msg) { + std::string rv = "zlib error: "; + if (msg) { + rv += msg; + } else { + rv += "(no message)"; + } + rv += " (status = "; + rv += boost::lexical_cast<std::string>(status); + rv += ")"; + return rv; + } + + int zlib_status_; + std::string zlib_msg_; +}; + +/** + * This transport uses zlib's compressed format on the "far" side. + * + * There are two kinds of TZlibTransport objects: + * - Standalone objects are used to encode self-contained chunks of data + * (like structures). They include checksums. + * - Non-standalone transports are used for RPC. They are not implemented yet. + * + * TODO(dreiss): Don't do an extra copy of the compressed data if + * the underlying transport is TBuffered or TMemory. + * + */ +class TZlibTransport : public TTransport { + public: + + /** + * @param transport The transport to read compressed data from + * and write compressed data to. + * @param use_for_rpc True if this object will be used for RPC, + * false if this is a standalone object. + * @param urbuf_size Uncompressed buffer size for reading. + * @param crbuf_size Compressed buffer size for reading. + * @param uwbuf_size Uncompressed buffer size for writing. + * @param cwbuf_size Compressed buffer size for writing. + * + * TODO(dreiss): Write a constructor that isn't a pain. 
+ */ + TZlibTransport(boost::shared_ptr<TTransport> transport, + bool use_for_rpc, + int urbuf_size = DEFAULT_URBUF_SIZE, + int crbuf_size = DEFAULT_CRBUF_SIZE, + int uwbuf_size = DEFAULT_UWBUF_SIZE, + int cwbuf_size = DEFAULT_CWBUF_SIZE) : + transport_(transport), + standalone_(!use_for_rpc), + urpos_(0), + uwpos_(0), + input_ended_(false), + output_flushed_(false), + urbuf_size_(urbuf_size), + crbuf_size_(crbuf_size), + uwbuf_size_(uwbuf_size), + cwbuf_size_(cwbuf_size), + urbuf_(NULL), + crbuf_(NULL), + uwbuf_(NULL), + cwbuf_(NULL), + rstream_(NULL), + wstream_(NULL) + { + + if (!standalone_) { + throw TTransportException( + TTransportException::BAD_ARGS, + "TZLibTransport has not been tested for RPC."); + } + + if (uwbuf_size_ < MIN_DIRECT_DEFLATE_SIZE) { + // Have to copy this into a local because of a linking issue. + int minimum = MIN_DIRECT_DEFLATE_SIZE; + throw TTransportException( + TTransportException::BAD_ARGS, + "TZLibTransport: uncompressed write buffer must be at least" + + boost::lexical_cast<std::string>(minimum) + "."); + } + + try { + urbuf_ = new uint8_t[urbuf_size]; + crbuf_ = new uint8_t[crbuf_size]; + uwbuf_ = new uint8_t[uwbuf_size]; + cwbuf_ = new uint8_t[cwbuf_size]; + + // Don't call this outside of the constructor. + initZlib(); + + } catch (...) { + delete[] urbuf_; + delete[] crbuf_; + delete[] uwbuf_; + delete[] cwbuf_; + throw; + } + } + + // Don't call this outside of the constructor. + void initZlib(); + + ~TZlibTransport(); + + bool isOpen(); + + void open() { + transport_->open(); + } + + void close() { + transport_->close(); + } + + uint32_t read(uint8_t* buf, uint32_t len); + + void write(const uint8_t* buf, uint32_t len); + + void flush(); + + const uint8_t* borrow(uint8_t* buf, uint32_t* len); + + void consume(uint32_t len); + + void verifyChecksum(); + + /** + * TODO(someone_smart): Choose smart defaults. 
+ */ + static const int DEFAULT_URBUF_SIZE = 128; + static const int DEFAULT_CRBUF_SIZE = 1024; + static const int DEFAULT_UWBUF_SIZE = 128; + static const int DEFAULT_CWBUF_SIZE = 1024; + + protected: + + inline void checkZlibRv(int status, const char* msg); + inline void checkZlibRvNothrow(int status, const char* msg); + inline int readAvail(); + void flushToZlib(const uint8_t* buf, int len, bool finish = false); + + // Writes smaller than this are buffered up. + // Larger (or equal) writes are dumped straight to zlib. + static const int MIN_DIRECT_DEFLATE_SIZE = 32; + + boost::shared_ptr<TTransport> transport_; + bool standalone_; + + int urpos_; + int uwpos_; + + /// True iff zlib has reached the end of a stream. + /// This is only ever true in standalone protcol objects. + bool input_ended_; + /// True iff we have flushed the output stream. + /// This is only ever true in standalone protcol objects. + bool output_flushed_; + + int urbuf_size_; + int crbuf_size_; + int uwbuf_size_; + int cwbuf_size_; + + uint8_t* urbuf_; + uint8_t* crbuf_; + uint8_t* uwbuf_; + uint8_t* cwbuf_; + + struct z_stream_s* rstream_; + struct z_stream_s* wstream_; +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ |