diff options
author | Tyler Seip <Tyler.Seip@mongodb.com> | 2021-10-29 20:44:40 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2021-10-29 21:27:26 +0000 |
commit | 786482da93c3e5e58b1c690cb060f00c60864f69 (patch) | |
tree | 9d63c02f889251d2de875e71713f6d53fa2b2040 /src/mongo/transport | |
parent | 6ed9f24dc8470f6f7010f7b5cf6dba8db7c9e376 (diff) | |
download | mongo-786482da93c3e5e58b1c690cb060f00c60864f69.tar.gz |
SERVER-60677: Implement parser for Proxy Protocol V1 and V2 headers
Diffstat (limited to 'src/mongo/transport')
-rw-r--r-- | src/mongo/transport/SConscript | 2 | ||||
-rw-r--r-- | src/mongo/transport/proxy_protocol_header_parser.cpp | 453 | ||||
-rw-r--r-- | src/mongo/transport/proxy_protocol_header_parser.h | 105 | ||||
-rw-r--r-- | src/mongo/transport/proxy_protocol_header_parser_test.cpp | 573 |
4 files changed, 1133 insertions, 0 deletions
diff --git a/src/mongo/transport/SConscript b/src/mongo/transport/SConscript index b166e728719..295bf0f5680 100644 --- a/src/mongo/transport/SConscript +++ b/src/mongo/transport/SConscript @@ -55,6 +55,7 @@ tlEnv.Library( 'transport_layer_asio.cpp', 'asio_utils.cpp', 'session_asio.cpp', + 'proxy_protocol_header_parser.cpp', 'transport_options.idl', ], LIBDEPS=[ @@ -186,6 +187,7 @@ tlEnv.CppUnitTest( 'service_executor_test.cpp', 'max_conns_override_test.cpp', 'service_state_machine_test.cpp', + 'proxy_protocol_header_parser_test.cpp', ], LIBDEPS=[ '$BUILD_DIR/mongo/base', diff --git a/src/mongo/transport/proxy_protocol_header_parser.cpp b/src/mongo/transport/proxy_protocol_header_parser.cpp new file mode 100644 index 00000000000..2a0b672f924 --- /dev/null +++ b/src/mongo/transport/proxy_protocol_header_parser.cpp @@ -0,0 +1,453 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kNetwork + +#include "mongo/transport/proxy_protocol_header_parser.h" + +#include <boost/optional.hpp> +#include <cstring> +#include <exception> +#include <fmt/format.h> + +#ifndef _WIN32 +#include <netinet/in.h> +#include <sys/un.h> +#endif + +#include "mongo/base/parse_number.h" +#include "mongo/base/string_data.h" +#include "mongo/logv2/log.h" +#include "mongo/platform/endian.h" +#include "mongo/util/assert_util.h" + +namespace mongo::transport { + +using namespace fmt::literals; + +namespace { +StringData parseToken(StringData& s, const char c) { + size_t pos = s.find(c); + uassert(ErrorCodes::FailedToParse, + "Proxy Protocol Version 1 address string malformed: {}"_format(s), + pos != std::string::npos); + StringData result = s.substr(0, pos); + s = s.substr(pos + 1); + return result; +} +} // namespace + +namespace proxy_protocol_details { + +void validateIpv4Address(StringData addr) { + StringData buffer = addr; + const NumberParser octetParser = + NumberParser().skipWhitespace(false).base(10).allowTrailingText(false); + try { + for (size_t i = 0; i < 4; ++i) { + unsigned octet = 0; + if (i == 3) { + uassertStatusOK(octetParser(buffer, &octet)); + } else { + uassertStatusOK(octetParser(parseToken(buffer, '.'), &octet)); + } + uassert( + ErrorCodes::FailedToParse, + "Proxy Protocol Version 1 address string specified malformed IPv4 address: {}"_format( + addr), + octet <= 255); + } + } catch (const ExceptionFor<ErrorCodes::FailedToParse>&) { + uasserted( + ErrorCodes::FailedToParse, + "Proxy Protocol Version 1 address string specified malformed IPv4 address: {}"_format( + addr)); + } +} + +void validateIpv6Address(StringData addr) { + static constexpr StringData doubleColon = "::"_sd; + + auto validateHexadectets = [](StringData buffer) -> size_t { + auto validateHexadectet = [](StringData hexadectet) { + const NumberParser hexadectetParser = + NumberParser().skipWhitespace(false).base(16).allowTrailingText(false); + unsigned value = 0; + uassertStatusOK(hexadectetParser(hexadectet, &value)); + uassert( + ErrorCodes::FailedToParse, + "Proxy Protocol Version 1 address string contains malformed IPv6 hexadectet: {}"_format( + hexadectet), + hexadectet.size() == 4 && value >= 0); + }; + + if (buffer.empty()) + return 0; + + uassert( + ErrorCodes::FailedToParse, + "Proxy Protocol Version 1 address string contains malformed IPv6 hexadectet: {}"_format( + buffer), + buffer.find(doubleColon) == std::string::npos); + + size_t numHexadectets = 0; + while (!buffer.empty()) { + if (const size_t pos = buffer.find(':'); pos != std::string::npos) { + validateHexadectet(buffer.substr(0, pos)); + ++numHexadectets; + buffer = buffer.substr(pos + 1); + } else { + validateHexadectet(buffer); + return numHexadectets + 1; + } + } + uasserted( + ErrorCodes::FailedToParse, + "Proxy Protocol Version 1 address string contains malformed IPv6 hexadectet: {}"_format( + buffer)); + }; + + // There can be at most one double colon in our address. Split on the first + // one and validate neither half has another implicitly. + try { + if (const auto pos = addr.find(doubleColon); pos != std::string::npos) { + const size_t numHexadectets = validateHexadectets(addr.substr(0, pos)) + + validateHexadectets(addr.substr(pos + doubleColon.size())); + uassert( + ErrorCodes::FailedToParse, + "Proxy Protocol Version 1 address string specified malformed IPv6 address: {}"_format( + addr), + numHexadectets < 8); + } else { + const size_t numHexadectets = validateHexadectets(addr); + uassert( + ErrorCodes::FailedToParse, + "Proxy Protocol Version 1 address string specified malformed IPv6 address: {}"_format( + addr), + numHexadectets == 8); + } + } catch (const ExceptionFor<ErrorCodes::FailedToParse>&) { + uasserted( + ErrorCodes::FailedToParse, + "Proxy Protocol Version 1 address string specified malformed IPv6 address: {}"_format( + addr)); + } +} + +} // namespace proxy_protocol_details + +namespace { + +// Interprets the first sizeof(T) bytes of data as a T and returns it, advancing the data cursor +// by the same amount. Does not account for endianness of the data buffer. +template <typename T> +T extract(StringData& data) { + MONGO_STATIC_ASSERT(std::is_trivially_copyable_v<T>); + static constexpr size_t numBytes = sizeof(T); + if (data.size() < numBytes) { + throw std::out_of_range("Not enough space to extract object of size {}"_format(numBytes)); + } + + T result; + memcpy(&result, data.rawData(), numBytes); + data = data.substr(numBytes); + return result; +} + +constexpr StringData kV1Start = "PROXY"_sd; + +bool parseV1Buffer(StringData& buffer, boost::optional<ProxiedEndpoints>& endpoints) { + buffer = buffer.substr(kV1Start.size()); + if (buffer.empty()) + return false; + + // Scan the buffer for a newline and prepare an output buffer which begins just past + // the line. + static constexpr StringData crlf = "\x0D\x0A"_sd; + const auto crlfPos = buffer.find(crlf); + + static constexpr size_t kMaximumV1HeaderSize = 107; + static constexpr size_t kMaximumV1InetLineSize = kMaximumV1HeaderSize - kV1Start.size(); + if (crlfPos == std::string::npos) { + // If we couldn't find a newline sequence, then fail if there cannot be enough room + // for one to appear in the future. + uassert(ErrorCodes::FailedToParse, + "No terminating newline found in Proxy Protocol header V1: {}"_format(buffer), + buffer.size() <= kMaximumV1InetLineSize); + return false; + } else { + // If we could, then fail if the sequence doesn't occur within the maximum line length. + uassert(ErrorCodes::FailedToParse, + "No terminating newline found in Proxy Protocol header V1: {}"_format(buffer), + crlfPos + crlf.size() <= kMaximumV1InetLineSize); + } + + // Prepare a result buffer pointing to just after the crlf sequence. + const auto resultBuffer = buffer.substr(crlfPos + crlf.size()); + + static constexpr StringData kTcp4Prefix = " TCP4 "_sd; + static constexpr StringData kTcp6Prefix = " TCP6 "_sd; + int aFamily = AF_UNSPEC; + if (buffer.startsWith(kTcp4Prefix)) { + aFamily = AF_INET; + buffer = buffer.substr(kTcp4Prefix.size()); + } else if (buffer.startsWith(kTcp6Prefix)) { + aFamily = AF_INET6; + buffer = buffer.substr(kTcp6Prefix.size()); + } else if (buffer.startsWith(" UNKNOWN"_sd)) { + buffer = resultBuffer; + endpoints = {}; + return true; + } else { + uasserted(ErrorCodes::FailedToParse, + "Proxy Protocol Version 1 address string malformed: {}"_format(buffer)); + } + + // The remainder of the string should now tokenize into four substrings: + // srcAddr dstAddr srcPort dstPort + const StringData srcAddr = parseToken(buffer, ' '); + const StringData dstAddr = parseToken(buffer, ' '); + + invariant(aFamily == AF_INET || aFamily == AF_INET6); + if (aFamily == AF_INET) { + proxy_protocol_details::validateIpv4Address(srcAddr); + proxy_protocol_details::validateIpv4Address(dstAddr); + } else { + proxy_protocol_details::validateIpv6Address(srcAddr); + proxy_protocol_details::validateIpv6Address(dstAddr); + } + + const StringData srcPortStr = parseToken(buffer, ' '); + const StringData dstPortStr = parseToken(buffer, '\r'); + + const NumberParser portParser = + NumberParser().skipWhitespace(false).base(10).allowTrailingText(false); + unsigned srcPort, dstPort = 0; + uassertStatusOK(portParser(srcPortStr, &srcPort)); + uassertStatusOK(portParser(dstPortStr, &dstPort)); + + auto validatePort = [](int port) { + uassert(ErrorCodes::FailedToParse, + "Proxy Protocol Version 1 address string specified invalid port: {}"_format(port), + port <= 65535); + }; + validatePort(srcPort); + validatePort(dstPort); + + buffer = resultBuffer; + try { + endpoints = ProxiedEndpoints{SockAddr::create(srcAddr, srcPort, aFamily), + SockAddr::create(dstAddr, dstPort, aFamily)}; + return true; + } catch (const ExceptionFor<ErrorCodes::HostUnreachable>&) { + // SockAddr can throw on construction if the address passed in is malformed. + uasserted(ErrorCodes::FailedToParse, + "Proxy Protocol Version 1 address string specified unreachable host: {}"_format( + buffer)); + } +} + +// Since this string contains a null, it's critical we use a literal here. +constexpr StringData kV2Start = "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"_sd; + +bool parseV2Buffer(StringData& buffer, boost::optional<ProxiedEndpoints>& endpoints) { + buffer = buffer.substr(kV2Start.size()); + if (buffer.empty()) + return false; + + try { + const char protocolVersionAndCommandByte = extract<char>(buffer); + + bool isLocal = false; + // The high nibble must be 2 (version 2) and the low nibble must be either 0 or 1 (local or + // remote). + if (protocolVersionAndCommandByte == '\x20') { + isLocal = true; + } else if (protocolVersionAndCommandByte != '\x21') { + uasserted( + ErrorCodes::FailedToParse, + "Invalid version or command byte given in Proxy Protocol header V2: {}"_format( + protocolVersionAndCommandByte)); + } + + if (buffer.empty()) + return false; + + const uint8_t transportProtocolAndAddressFamilyByte = extract<uint8_t>(buffer); + + int aFamily = 0; + { + // Discard the family if this is a local connection. + uint8_t aFamilyByte = isLocal ? 0 : (transportProtocolAndAddressFamilyByte & 0xF0) >> 4; + switch (aFamilyByte) { + case 0: + aFamily = AF_UNSPEC; + break; + case 1: + aFamily = AF_INET; + break; + case 2: + aFamily = AF_INET6; + break; + case 3: + aFamily = AF_UNIX; + break; + default: + uasserted(ErrorCodes::FailedToParse, + "Invalid address family given in Proxy Protocol header V2: {}"_format( + aFamilyByte)); + } + } + + uint8_t protocol = (transportProtocolAndAddressFamilyByte & 0xF); + uassert(ErrorCodes::FailedToParse, + "Invalid protocol given in Proxy Protocol header V2: {}"_format(protocol), + protocol <= 0x2); + + // If protocol is unspecified, we should also ignore address information. + if (protocol == 0) { + aFamily = AF_UNSPEC; + } + + if (buffer.size() < sizeof(uint16_t)) + return false; + + const size_t length = endian::bigToNative(extract<uint16_t>(buffer)); + if (buffer.size() < length) + return false; + + // Prepare an output buffer that skips past the end of the header. + // We'll assign this to the buffer if we fully succeed in parsing the header. + const auto resultBuffer = buffer.substr(length); + + switch (aFamily) { + case AF_UNSPEC: + break; + case AF_INET: { + // The proxy protocol allocates 12 bytes to represent a pair of IPv4 addresses + // along with their ports. + static constexpr size_t kIPv4ProxyProtocolSize = 12; + MONGO_STATIC_ASSERT(2 * (sizeof(in_addr) + sizeof(uint16_t)) == + kIPv4ProxyProtocolSize); + uassert(ErrorCodes::FailedToParse, + "Proxy Protocol Version 2 address string too short: {}"_format(buffer), + length >= kIPv4ProxyProtocolSize); + sockaddr_in src_addr{}; + sockaddr_in dst_addr{}; + src_addr.sin_family = dst_addr.sin_family = AF_INET; + // These are specified by the protocol to be in network byte order, which + // is what sin_addr/sin_port expect, so we copy them directly. + src_addr.sin_addr = extract<in_addr>(buffer); + dst_addr.sin_addr = extract<in_addr>(buffer); + src_addr.sin_port = extract<uint16_t>(buffer); + dst_addr.sin_port = extract<uint16_t>(buffer); + endpoints = ProxiedEndpoints{SockAddr((sockaddr*)&src_addr, sizeof(sockaddr_in)), + SockAddr((sockaddr*)&dst_addr, sizeof(sockaddr_in))}; + break; + } + case AF_INET6: { + // The proxy protocol allocates 36 bytes to represent a pair of IPv6 addresses + // along with their ports. + static constexpr size_t kIPv6ProxyProtocolSize = 36; + MONGO_STATIC_ASSERT(2 * (sizeof(in6_addr) + sizeof(uint16_t)) == + kIPv6ProxyProtocolSize); + uassert(ErrorCodes::FailedToParse, + "Proxy Protocol Version 2 address string too short: {}"_format(buffer), + length >= kIPv6ProxyProtocolSize); + sockaddr_in6 src_addr{}; + sockaddr_in6 dst_addr{}; + src_addr.sin6_family = dst_addr.sin6_family = AF_INET6; + // These are specified by the protocol to be in network byte order, which + // is what sin_addr/sin_port expect, so we copy them directly. + src_addr.sin6_addr = extract<in6_addr>(buffer); + dst_addr.sin6_addr = extract<in6_addr>(buffer); + src_addr.sin6_port = extract<uint16_t>(buffer); + dst_addr.sin6_port = extract<uint16_t>(buffer); + endpoints = ProxiedEndpoints{SockAddr((sockaddr*)&src_addr, sizeof(sockaddr_in6)), + SockAddr((sockaddr*)&dst_addr, sizeof(sockaddr_in6))}; + break; + } + case AF_UNIX: { + // The proxy protocol allocates 216 bytes to represent a pair of UNIX address, + // but we don't assert type sizes here because some platforms don't support + // UNIX addresses of this length - they are checked in parseSockAddrUn. + static constexpr size_t kUnixProxyProtocolSize = 216; + uassert(ErrorCodes::FailedToParse, + "Proxy Protocol Version 2 address string too short: {}"_format(buffer), + length >= kUnixProxyProtocolSize); + const auto src_addr = proxy_protocol_details::parseSockAddrUn( + buffer.substr(0, proxy_protocol_details::kMaxUnixPathLength)); + const auto dst_addr = proxy_protocol_details::parseSockAddrUn( + buffer.substr(proxy_protocol_details::kMaxUnixPathLength, + proxy_protocol_details::kMaxUnixPathLength)); + + endpoints = ProxiedEndpoints{SockAddr((sockaddr*)&src_addr, sizeof(sockaddr_un)), + SockAddr((sockaddr*)&dst_addr, sizeof(sockaddr_un))}; + break; + } + default: + MONGO_UNREACHABLE; + } + buffer = resultBuffer; + return true; + } catch (const std::out_of_range&) { + return false; + } +} + +} // namespace + +boost::optional<ParserResults> parseProxyProtocolHeader(StringData buffer) { + // Check if the buffer presented is V1, V2, or neither. + const size_t originalBufferSize = buffer.size(); + + ParserResults results; + bool complete = false; + if (buffer.startsWith(kV1Start)) { + complete = parseV1Buffer(buffer, results.endpoints); + } else if (buffer.startsWith(kV2Start)) { + complete = parseV2Buffer(buffer, results.endpoints); + } else { + uassert(ErrorCodes::FailedToParse, + "Initial Proxy Protocol header bytes invalid: {}" + "; Make sure your proxy is configured to emit a Proxy " + "Protocol header"_format(buffer), + kV1Start.startsWith(buffer) || kV2Start.startsWith(buffer)); + } + + if (complete) { + results.bytesParsed = originalBufferSize - buffer.size(); + return results; + } else { + return {}; + } +} + + +} // namespace mongo::transport diff --git a/src/mongo/transport/proxy_protocol_header_parser.h b/src/mongo/transport/proxy_protocol_header_parser.h new file mode 100644 index 00000000000..f3a3e6126ab --- /dev/null +++ b/src/mongo/transport/proxy_protocol_header_parser.h @@ -0,0 +1,105 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <algorithm> +#include <boost/optional.hpp> +#include <fmt/format.h> + +#ifndef _WIN32 +#include <sys/un.h> +#endif + +#include "mongo/base/error_codes.h" +#include "mongo/util/assert_util.h" +#include "mongo/util/net/sockaddr.h" + +namespace mongo::transport { + +/** + * Represents the true endpoints that a proxy using the Proxy Protocol is proxying for us. + */ +struct ProxiedEndpoints { + // The true origin of the connection, i.e. the IP address of the client behind the + // proxy. + SockAddr sourceAddress; + // The true destination of the connection, almost always the address that the proxy + // is listening on. + SockAddr destinationAddress; +}; + +/** + * Contains the results of parsing a Proxy Protocol header. bytesParsed contains the + * length of the parsed header, and endpoints contains any endpoint information that the + * header optionally contained. + */ +struct ParserResults { + // The endpoint metadata should be populated iff parsing is complete, the connection + // is marked as remote, and the connection is not marked as UNKNOWN. + boost::optional<ProxiedEndpoints> endpoints = {}; + size_t bytesParsed = 0; +}; + +/** + * Parses a string potentially starting with a proxy protocol header (either V1 or V2). + * If the string begins with a partial but incomplete header, returns an empty optional; + * otherwise, returns a ParserResults with the results of the parse. + * + * Will throw eagerly on a malformed header. + */ +boost::optional<ParserResults> parseProxyProtocolHeader(StringData buffer); + +namespace proxy_protocol_details { +// The maximum number of bytes ever needed by a proxy protocol header; represents +// the minimum TCP MTU. +static constexpr size_t kBytesToFetch = 536; +static constexpr size_t kMaxUnixPathLength = 108; + +template <typename AddrUn = sockaddr_un> +AddrUn parseSockAddrUn(StringData buffer) { + using namespace fmt::literals; + + AddrUn addr{}; + addr.sun_family = AF_UNIX; + + StringData path = buffer.substr(0, buffer.find('\0')); + uassert(ErrorCodes::FailedToParse, + "Provided unix path longer than system supports: {}"_format(buffer), + path.size() < sizeof(AddrUn::sun_path)); + std::copy(path.begin(), path.end(), addr.sun_path); + return addr; +} + +void validateIpv4Address(StringData addr); +void validateIpv6Address(StringData addr); + +} // namespace proxy_protocol_details + +} // namespace mongo::transport diff --git a/src/mongo/transport/proxy_protocol_header_parser_test.cpp b/src/mongo/transport/proxy_protocol_header_parser_test.cpp new file mode 100644 index 00000000000..8de76feb091 --- /dev/null +++ b/src/mongo/transport/proxy_protocol_header_parser_test.cpp @@ -0,0 +1,573 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/transport/proxy_protocol_header_parser.h" + +#include "mongo/unittest/assert_that.h" +#include "mongo/unittest/unittest.h" +#include "mongo/util/assert_util.h" +#include "mongo/util/shared_buffer.h" + +namespace mongo::transport { +namespace { + +using namespace unittest::match; +using namespace fmt::literals; + +template <typename MSrc, typename MDst> +class ProxiedEndpointsAre : public Matcher { +public: + explicit ProxiedEndpointsAre(MSrc&& s, MDst&& d) : _src(std::move(s)), _dst(std::move(d)) {} + + std::string describe() const { + return "ProxiedEndpointsAre({}, {})"_format(_src.describe(), _dst.describe()); + } + + MatchResult match(const ProxiedEndpoints& e) const { + return StructuredBindingsAre<MSrc, MDst>(_src, _dst).match(e); + } + +private: + MSrc _src; + MDst _dst; +}; + +ParserResults parseAllPrefixes(StringData s) { + boost::optional<ParserResults> results; + for (size_t len = 0; len <= s.size(); ++len) { + StringData sub = s.substr(0, len); + results = parseProxyProtocolHeader(sub); + if (len < s.size()) { + ASSERT_FALSE(results) << "size={}, sub={}"_format(len, sub); + } + } + ASSERT_TRUE(results); + return *results; +} + +void parseStringExpectFailure(StringData s, std::string regex) { + try { + parseAllPrefixes(s); + FAIL("Expected to throw"); + } catch (const DBException& ex) { + ASSERT_THAT(ex.toStatus(), StatusIs(Eq(ErrorCodes::FailedToParse), ContainsRegex(regex))); + } +} + +boost::optional<ProxiedEndpoints> parseStringExpectSuccess(StringData s) { + const ParserResults results = parseAllPrefixes(s); + ASSERT_THAT(results.bytesParsed, Eq(s.size())); + + // Also test that adding garbage to the end doesn't increase the bytesParsed amount. + const boost::optional<ParserResults> possibleResultsWithGarbage = + parseProxyProtocolHeader(s + "garbage"); + ASSERT_TRUE(possibleResultsWithGarbage); + const ParserResults resultsWithGarbage = *possibleResultsWithGarbage; + ASSERT_THAT(resultsWithGarbage.bytesParsed, Eq(s.size())); + if (results.endpoints) { + ASSERT_THAT(*results.endpoints, + ProxiedEndpointsAre(Eq(resultsWithGarbage.endpoints->sourceAddress), + Eq(resultsWithGarbage.endpoints->destinationAddress))); + } else { + ASSERT_FALSE(resultsWithGarbage.endpoints); + } + + return resultsWithGarbage.endpoints; +} + +TEST(ProxyProtocolHeaderParser, MalformedIpv4Addresses) { + StringData testCases[] = {"1", + "1.1", + "1.1.1", + "1.1.1.1.1", + "1.1.1.1.", + ".1.1.1.1", + "1234.1.1.1", + "1.1234.1.1", + "1.1.1234.1", + "1.1.1.1234", + "1.1.1.a", + "1.1.1.256", + "256.1.1.1", + "1.1..1.1", + "-0.1.1.1", + "-1.1.1.1", + ""}; + + for (const auto& testCase : testCases) { + try { + proxy_protocol_details::validateIpv4Address(testCase); + FAIL("Expected to throw"); + } catch (const DBException& ex) { + ASSERT_THAT(ex.toStatus(), + StatusIs(Eq(ErrorCodes::FailedToParse), ContainsRegex("malformed"))); + } + } +} + +TEST(ProxyProtocolHeaderParser, WellFormedIpv4Addresses) { + StringData testCases[] = { + "1.1.1.1", "0.0.0.0", "255.255.255.255", "0.255.0.255", "127.0.1.1", "1.12.123.0"}; + + for (const auto& testCase : testCases) { + proxy_protocol_details::validateIpv4Address(testCase); + } +} + +TEST(ProxyProtocolHeaderParser, MalformedIpv6Addresses) { + StringData testCases[] = {"0000", + "0000:0000", + "0000:0000:0000", + "0000:0000:0000:0000:0000", + "0000:0000:0000:0000:0000:0000", + "0000:0000:0000:0000:0000:0000:0000", + "0000:0000:0000:0000:0000:0000:0000:0000:0000", + "0000:0000:0000:0000:0000:0000:0000:", + ":0000:0000:0000:0000:0000:0000:0000", + "00000:0000:0000:0000:0000:0000:0000:0000", + "0000:0000:0000:0000:0000:0000:0000:00000", + "0000:-0000:0000:0000:0000:0000:0000:0000", + "0000:-000:0000:0000:0000:0000:0000:0000", + "000g:0000:0000:0000:0000:0000:0000:0000", + "0000:0000:0000:0000:0000:0000:0000:000g", + "0000::0000:0000:0000:0000:0000:0000:0000", + "0000:0000:0000:0000:0000:0000:0000::0000", + "0000:0000:0000:0000:0000:0000:0000:0000::", + "::0000:0000:0000:0000:0000:0000:0000:0000", + "::0000::", + "0000::0000::0000:0000:0000:0000:0000", + "0000::0000::0000:0000:0000:0000:0000:0000", + "::0000:", + ":0000::", + ":::", + ""}; + + for (const auto& testCase : testCases) { + try { + proxy_protocol_details::validateIpv6Address(testCase); + FAIL("Expected to throw"); + } catch (const DBException& ex) { + ASSERT_THAT(ex.toStatus(), + StatusIs(Eq(ErrorCodes::FailedToParse), ContainsRegex("malformed"))); + } + } +} + +TEST(ProxyProtocolHeaderParser, WellFormedIpv6Addresses) { + StringData testCases[] = {"::", + "::0000", + "::0000:0000", + "::0000:0000:0000", + "::0000:0000:0000:0000", + "::0000:0000:0000:0000:0000", + "::0000:0000:0000:0000:0000:0000", + "::0000:0000:0000:0000:0000:0000:0000", + "0000:0000:0000:0000:0000:0000:0000::", + "0000:0000:0000:0000:0000:0000::", + "0000:0000:0000:0000:0000::", + "0000:0000:0000:0000::", + "0000:0000:0000::", + "0000:0000::", + "0000::", + "0000::0000", + "0000::0000:0000", + "0000::0000:0000:0000", + "0000::0000:0000:0000:0000", + "0000::0000:0000:0000:0000:0000", + "0000::0000:0000:0000:0000:0000:0000", + "0000:0000:0000::0000:0000:0000", + "0000:0000:0000::0000:0000", + "0000:0000:0000:0000:0000:0000:0000:0000", + "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", + "ffff::", + "0123:4567:89ab:cdef::", + "::0123:4567:89ab:cdef", + "0123:4567::89ab:cdef"}; + + for (const auto& testCase : testCases) { + proxy_protocol_details::validateIpv6Address(testCase); + } +} + +TEST(ProxyProtocolHeaderParser, MalformedV1Headers) { + std::pair<std::string, std::string> testCases[] = { + {"PORXY ", "header bytes invalid"}, + + {"PROXY " + std::string(200, '1'), "No terminating newline"}, + + // Even if there is a terminating newline, it has to happen before the longest possible + // header length is seen. + {"PROXY UNKNOWN " + std::string(92, '1') + "\r\n", "No terminating newline"}, + + {"PROXY " + std::string(50, '\r') + "1" + "\r\r\n", "address string malformed"}, + {"PROXY TCP4 \r\n", "address string malformed"}, + {"PROXY TCP4 1.1.1.1\r\n", "address string malformed"}, + {"PROXY TCP4 1.1.1.1 1.1.1.1 10\r\n", "address string malformed"}, + + {"PROXY TCP4 12800000000 28 10 10\r\n", "malformed IPv4"}, + {"PROXY TCP4 128 28000000000000 10 10\r\n", "malformed IPv4"}, + {"PROXY TCP4 1.1.1.1 notanip 10 300\r\n", "malformed IPv4"}, + {"PROXY TCP4 a:b:c:d 20 10 300\r\n", "malformed IPv4"}, + {"PROXY TCP4 1.1.1.1 1.1.1.1 -10 300\r\n", "Negative"}, + {"PROXY TCP4 1.1.1.1 1.1.1.1 10 -300\r\n", "Negative"}, + {"PROXY TCP4 1.1.1.1 2.2.2.2 notaport 10\r\n", "Did not consume"}, + {"PROXY TCP4 1.1.1.1 2.2.2.2 10 20garbage\r\n", "Did not consume"}, + + // Check TCP6 + {"PROXY TCP6 \r\n", "address string malformed"}, + {"PROXY TCP6 ::\r\n", "address string malformed"}, + {"PROXY TCP6 :: :: 10\r\n", "address string malformed"}, + + {"PROXY TCP6 1.1.1.1 2.2.2.2 10 10\r\n", "malformed IPv6"}, + {"PROXY TCP6 :: ::000g 10 10\r\n", "malformed IPv6"}, + {"PROXY TCP6 :: :: -10 10\r\n", "Negative"}, + {"PROXY TCP6 :: :: 10 -10\r\n", "Negative"}, + {"PROXY TCP6 :: notanip 10 300\r\n", "malformed IPv6"}, + {"PROXY TCP6 :: :: notaport 10\r\n", "Did not consume"}, + {"PROXY TCP6 :: :: 10 20garbage\r\n", "Did not consume"}}; + + for (const auto& testCase : testCases) { + parseStringExpectFailure(testCase.first, testCase.second); + } +} + +TEST(ProxyProtocolHeaderParser, WellFormedV1Headers) { + ASSERT_THAT(*parseStringExpectSuccess("PROXY TCP4 1.1.1.1 2.2.2.2 10 300\r\n"), + ProxiedEndpointsAre(Eq(SockAddr::create("1.1.1.1", 10, AF_INET)), + Eq(SockAddr::create("2.2.2.2", 300, AF_INET)))); + + ASSERT_THAT(*parseStringExpectSuccess("PROXY TCP4 0.0.0.128 0.0.1.44 1000 3000\r\n"), + ProxiedEndpointsAre(Eq(SockAddr::create("0.0.0.128", 1000, AF_INET)), + Eq(SockAddr::create("0.0.1.44", 3000, AF_INET)))); + + static constexpr StringData allFs = "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"_sd; + ASSERT_THAT(*parseStringExpectSuccess("PROXY TCP6 {} " + "{} 10000 30000\r\n"_format(allFs, allFs)), + ProxiedEndpointsAre(Eq(SockAddr::create(allFs, 10000, AF_INET6)), + Eq(SockAddr::create(allFs, 30000, AF_INET6)))); + + ASSERT_THAT(*parseStringExpectSuccess("PROXY TCP6 :: {} 1000 3000\r\n"_format(allFs)), + ProxiedEndpointsAre(Eq(SockAddr::create("::", 1000, AF_INET6)), + Eq(SockAddr::create(allFs, 3000, AF_INET6)))); + + ASSERT_THAT(*parseStringExpectSuccess("PROXY TCP6 2001:0db8:: 0064:ff9b::0000 1000 3000\r\n"), + ProxiedEndpointsAre(Eq(SockAddr::create("2001:db8::", 1000, AF_INET6)), + Eq(SockAddr::create("64:ff9b::0000", 3000, AF_INET6)))); + + // The shortest possible V1 header + ASSERT_FALSE(parseStringExpectSuccess("PROXY UNKNOWN\r\n")); + ASSERT_FALSE( + parseStringExpectSuccess("PROXY UNKNOWN 2001:db8:: 64:ff9b::0.0.0.0 1000 3000\r\n")); + ASSERT_FALSE(parseStringExpectSuccess("PROXY UNKNOWN hot garbage\r\n")); + // The longest possible V1 header + ASSERT_FALSE( + parseStringExpectSuccess("PROXY UNKNOWN {} {} 65535 65535\r\n"_format(allFs, allFs))); +} + +struct TestV2Header { + std::string header; + std::string versionAndCommand; + std::string addressFamilyAndProtocol; + std::string length; + std::string firstAddr; + std::string secondAddr; + std::string metadata; + + std::string toString() const { + return "{}{}{}{}{}{}{}"_format(header, + versionAndCommand, + addressFamilyAndProtocol, + length, + firstAddr, + secondAddr, + metadata); + } +}; + +TEST(ProxyProtocolHeaderParser, MalformedV2Headers) { + // These strings contain null characters in them, so we need string literals. + using namespace std::string_literals; + + TestV2Header header; + + // Specify an invalid header. + header.header = "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x48\x54\x0A"s; + parseStringExpectFailure(header.toString(), "header bytes invalid"); + + // Correct the header but break the version. + header.header = "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"s; + header.versionAndCommand = "\x30"; + parseStringExpectFailure(header.toString(), "Invalid version"); + header.versionAndCommand = "\x22"; + parseStringExpectFailure(header.toString(), "Invalid version"); + // Correct the version/command but break the address family. + header.versionAndCommand = "\x21"; + header.addressFamilyAndProtocol = "\x40"; + parseStringExpectFailure(header.toString(), "Invalid address"); + header.addressFamilyAndProtocol = "\x23"; + parseStringExpectFailure(header.toString(), "Invalid protocol"); + + // TCP4 + // Set the length to 1. + header.addressFamilyAndProtocol = "\x11"; + header.length = "\x00\x01"s; + header.firstAddr = std::string(1, '\0'); + parseStringExpectFailure(header.toString(), "too short"); + // Set to the longest non-valid length (11). + header.length = "\x00\x0B"s; + header.firstAddr = std::string(6, '\0'); + header.secondAddr = std::string(5, '\0'); + parseStringExpectFailure(header.toString(), "too short"); + + // TCP6 + // Set the length to 1. + header.addressFamilyAndProtocol = "\x21"; + header.length = "\x00\x01"s; + header.firstAddr = std::string(1, '\0'); + header.secondAddr = ""; + parseStringExpectFailure(header.toString(), "too short"); + // Set to the longest non-valid length (35). + header.length = "\x00\x23"s; + header.firstAddr = std::string(18, '\0'); + header.secondAddr = std::string(17, '\0'); + parseStringExpectFailure(header.toString(), "too short"); + + // UNIX + // Set the length to 1. + header.addressFamilyAndProtocol = "\x31"; + header.length = "\x00\x01"s; + header.firstAddr = std::string(1, '\0'); + header.secondAddr = ""; + parseStringExpectFailure(header.toString(), "too short"); + // Set to the longest non-valid length (35). + header.length = "\x00\xD7"s; + header.firstAddr = std::string(108, '\0'); + header.secondAddr = std::string(107, '\0'); + parseStringExpectFailure(header.toString(), "too short"); +} + +template <typename SockAddrUn = sockaddr_un> +std::pair<SockAddrUn, SockAddrUn> createTestSockAddrUn(std::string srcPath, std::string dstPath) { + std::pair<SockAddrUn, SockAddrUn> addrs; + addrs.first.sun_family = addrs.second.sun_family = AF_UNIX; + + memcpy(addrs.first.sun_path, + srcPath.c_str(), + std::min(srcPath.size(), sizeof(SockAddrUn::sun_path))); + memcpy(addrs.second.sun_path, + dstPath.c_str(), + std::min(dstPath.size(), sizeof(SockAddrUn::sun_path))); + + return addrs; +} + +std::string createTestUnixPathString(StringData path) { + std::string out{path}; + out.resize(proxy_protocol_details::kMaxUnixPathLength); + return out; +} + +TEST(ProxyProtocolHeaderParser, WellFormedV2Headers) { + // These strings contain null characters in them, so we need string literals. + using namespace std::string_literals; + + TestV2Header header; + // TCP4 + header.header = "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"s; + header.versionAndCommand = "\x21"s; + header.addressFamilyAndProtocol = "\x12"s; + header.length = "\x00\x0C"s; + header.firstAddr = "\x0C\x22\x38\x4e\xac\x10"s; + header.secondAddr = "\x00\x01\x04\xd2\x1f\x90"s; + ASSERT_THAT(*parseStringExpectSuccess(header.toString()), + ProxiedEndpointsAre(Eq(SockAddr::create("12.34.56.78", 1234, AF_INET)), + Eq(SockAddr::create("172.16.0.1", 8080, AF_INET)))); + + // We correctly ignore data at the end. + header.length = "\x00\x0E"s; + header.metadata = std::string(2, '\0'); + ASSERT_THAT(*parseStringExpectSuccess(header.toString()), + ProxiedEndpointsAre(Eq(SockAddr::create("12.34.56.78", 1234, AF_INET)), + Eq(SockAddr::create("172.16.0.1", 8080, AF_INET)))); + + // If this is a local connection, return nothing. + header.versionAndCommand = "\x20"s; + ASSERT_FALSE(parseStringExpectSuccess(header.toString())); + + // TCP6 + header.versionAndCommand = "\x21"s; + header.addressFamilyAndProtocol = "\x21"s; + header.length = "\x00\x24"s; + header.firstAddr = "\x20\x1\xd\xb8\x0\x0\x0\x0\x0\x0\x0\x0\x0\x0\x0\x0\x0\x64"s; + header.secondAddr = "\xff\x9b\x0\x0\x0\x0\x0\x0\x0\x0\x0\x0\x0\x0\x4\xd2\x1f\x90"s; + header.metadata = ""; + ASSERT_THAT(*parseStringExpectSuccess(header.toString()), + ProxiedEndpointsAre(Eq(SockAddr::create("2001:db8::", 1234, AF_INET6)), + Eq(SockAddr::create("64:ff9b::0.0.0.0", 8080, AF_INET6)))); + + // We correctly ignore data at the end. + header.length = "\x00\x55"s; + header.metadata = std::string(49, '\1'); + ASSERT_THAT(*parseStringExpectSuccess(header.toString()), + ProxiedEndpointsAre(Eq(SockAddr::create("2001:db8::", 1234, AF_INET6)), + Eq(SockAddr::create("64:ff9b::0.0.0.0", 8080, AF_INET6)))); + + // If this is a local connection, return nothing. + header.versionAndCommand = "\x20"s; + ASSERT_FALSE(parseStringExpectSuccess(header.toString())); + + // UNIX + header.versionAndCommand = "\x21"s; + header.addressFamilyAndProtocol = "\x31"s; + header.length = "\x00\xD8"s; + const std::string srcPath(sizeof(sockaddr_un::sun_path) / 2, '\1'); + const std::string dstPath(sizeof(sockaddr_un::sun_path) - 1, '\2'); + header.firstAddr = createTestUnixPathString(srcPath); + header.secondAddr = createTestUnixPathString(dstPath); + header.metadata = ""; + const auto addrs = createTestSockAddrUn(srcPath, dstPath); + ASSERT_THAT(*parseStringExpectSuccess(header.toString()), + ProxiedEndpointsAre(Eq(SockAddr((sockaddr*)&addrs.first, sizeof(sockaddr_un))), + Eq(SockAddr((sockaddr*)&addrs.second, sizeof(sockaddr_un))))); + + // We correctly ignore data at the end. + header.length = "\x01\x08"; + header.metadata = std::string(48, '\0'); + // Extraneous data at the end is correctly ingested and ignored. + ASSERT_THAT(*parseStringExpectSuccess(header.toString()), + ProxiedEndpointsAre(Eq(SockAddr((sockaddr*)&addrs.first, sizeof(sockaddr_un))), + Eq(SockAddr((sockaddr*)&addrs.second, sizeof(sockaddr_un))))); + + // If this is a local connection, return nothing. + header.versionAndCommand = "\x20"s; + ASSERT_FALSE(parseStringExpectSuccess(header.toString())); + + // The family is not parsed if the connection is local. + header.addressFamilyAndProtocol = "\xA2"; + ASSERT_FALSE(parseStringExpectSuccess(header.toString())); +} + +struct TestSockAddrUnLinux { + sa_family_t sun_family; + char sun_path[108]; + + friend bool operator==(const TestSockAddrUnLinux& a, const TestSockAddrUnLinux& b) { + return a.sun_family == b.sun_family && !memcmp(a.sun_path, b.sun_path, sizeof(sun_path)); + } +}; + +TEST(ProxyProtocolHeaderParser, LinuxSockAddrUnParsing) { + // Test the parser against a Linux-like sockaddr_un + { + const std::string srcPath(sizeof(TestSockAddrUnLinux::sun_path) / 2, '\1'); + const std::string dstPath(sizeof(TestSockAddrUnLinux::sun_path) - 1, '\2'); + const auto addrs = createTestSockAddrUn<TestSockAddrUnLinux>(srcPath, dstPath); + ASSERT_THAT(proxy_protocol_details::parseSockAddrUn<TestSockAddrUnLinux>( + createTestUnixPathString(srcPath)), + Eq(addrs.first)); + ASSERT_THAT(proxy_protocol_details::parseSockAddrUn<TestSockAddrUnLinux>( + createTestUnixPathString(dstPath)), + Eq(addrs.second)); + } + + { + const std::string srcPath(sizeof(TestSockAddrUnLinux::sun_path), '\0'); + const std::string dstPath(1, '\0'); + const auto addrs = createTestSockAddrUn<TestSockAddrUnLinux>(srcPath, dstPath); + ASSERT_THAT(proxy_protocol_details::parseSockAddrUn<TestSockAddrUnLinux>( + createTestUnixPathString(srcPath)), + Eq(addrs.first)); + ASSERT_THAT(proxy_protocol_details::parseSockAddrUn<TestSockAddrUnLinux>( + createTestUnixPathString(dstPath)), + Eq(addrs.second)); + } +} + +struct TestSockAddrUnMac { + unsigned char sun_len; + sa_family_t sun_family; + char sun_path[104]; + + friend bool operator==(const TestSockAddrUnMac& a, const TestSockAddrUnMac& b) { + return a.sun_family == b.sun_family && a.sun_len == b.sun_len && + !memcmp(a.sun_path, b.sun_path, a.sun_len); + } +}; + +TEST(ProxyProtocolHeaderParser, MacSockAddrUnParsing) { + // Test the parser against a Mac-like sockaddr_un + { + const std::string srcPath(sizeof(TestSockAddrUnMac::sun_path) / 2, '\1'); + const std::string dstPath(sizeof(TestSockAddrUnMac::sun_path) - 1, '\2'); + const auto addrs = createTestSockAddrUn<TestSockAddrUnMac>(srcPath, dstPath); + ASSERT_THAT(proxy_protocol_details::parseSockAddrUn<TestSockAddrUnMac>( + createTestUnixPathString(srcPath)), + Eq(addrs.first)); + ASSERT_THAT(proxy_protocol_details::parseSockAddrUn<TestSockAddrUnMac>( + createTestUnixPathString(dstPath)), + Eq(addrs.second)); + } + + { + const std::string srcPath(1, '\1'); + const std::string dstPath = ""; + const auto addrs = createTestSockAddrUn<TestSockAddrUnMac>(srcPath, dstPath); + ASSERT_THAT(proxy_protocol_details::parseSockAddrUn<TestSockAddrUnMac>( + createTestUnixPathString(srcPath)), + Eq(addrs.first)); + ASSERT_THAT(proxy_protocol_details::parseSockAddrUn<TestSockAddrUnMac>( + createTestUnixPathString(dstPath)), + Eq(addrs.second)); + } + + try { + const std::string srcPath(proxy_protocol_details::kMaxUnixPathLength - 1, '\1'); + const std::string dstPath(proxy_protocol_details::kMaxUnixPathLength - 1, '\2'); + proxy_protocol_details::parseSockAddrUn<TestSockAddrUnMac>( + createTestUnixPathString(srcPath) + createTestUnixPathString(dstPath)); + FAIL("Expected to throw"); + } catch (const DBException& ex) { + ASSERT_THAT(ex.toStatus(), + StatusIs(Eq(ErrorCodes::FailedToParse), ContainsRegex("longer than system"))); + } + + try { + const std::string srcPath(sizeof(TestSockAddrUnMac::sun_path), '\1'); + const std::string dstPath(sizeof(TestSockAddrUnMac::sun_path), '\2'); + proxy_protocol_details::parseSockAddrUn<TestSockAddrUnMac>( + createTestUnixPathString(srcPath) + createTestUnixPathString(dstPath)); + FAIL("Expected to throw"); + } catch (const DBException& ex) { + ASSERT_THAT(ex.toStatus(), + StatusIs(Eq(ErrorCodes::FailedToParse), ContainsRegex("longer than system"))); + } +} + +} // namespace +} // namespace mongo::transport |